[engine] add recognizer_controller && fix build bugs (#3086)

* fix asr compile
pull/3112/head
YangZhou 2 years ago committed by GitHub
parent 2be7e5725f
commit 767f6dd4e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,11 +14,6 @@ set(PPS_VERSION_MINOR 0)
set(PPS_VERSION_PATCH 0)
set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}")
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
# compiler option
# Keep the same with openfst, -fPIC or -fpic
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
@ -50,11 +45,18 @@ set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install)
include(FetchContent)
include(ExternalProject)
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# Option Configurations
###############################################################################
# https://github.com/google/brotli/pull/655
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(NDEBUG "debug option" OFF)
option(WITH_ASR "build asr" ON)
option(WITH_CLS "build cls" ON)

@ -9,5 +9,6 @@ FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)
link_directories(${gflags_BINARY_DIR})
install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)
#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)

@ -30,7 +30,7 @@ ExternalProject_Add(openfst
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread"
"LIBS=-lgflags_nothreads -lglog -lpthread -fPIC"
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
BUILD_COMMAND make -j 4
)

@ -21,4 +21,4 @@ if(WITH_VAD)
add_subdirectory(vad)
endif()
add_subdirectory(codelab)
add_subdirectory(codelab)

@ -16,9 +16,9 @@ set(TEST_BINS
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_link_libraries(${bin_name} nnet decoder fst utils libgflags_nothreads.so glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()

@ -17,6 +17,7 @@
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc
#include "nnet/u2_nnet.h"
#include <type_traits>
#ifdef WITH_PROFILING
#include "paddle/fluid/platform/profiler.h"
@ -214,7 +215,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
// not cache feature in nnet
CHECK_EQ(cached_feats_.size(), 0);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
CHECK_EQ((std::is_same<float, kaldi::BaseFloat>::value), true);
std::memcpy(feats_ptr,
chunk_feats.data(),
chunk_feats.size() * sizeof(kaldi::BaseFloat));

@ -2,6 +2,7 @@ set(srcs)
list(APPEND srcs
u2_recognizer.cc
recognizer_controller.cc
)
add_library(recognizer STATIC ${srcs})
@ -11,13 +12,14 @@ set(TEST_BINS
u2_recognizer_main
u2_recognizer_thread_main
u2_recognizer_batch_main
recognizer_batch_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils libgflags_nothreads.so glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()

@ -12,48 +12,66 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/u2_recognizer.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "njob");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr(
new ppspeech::U2Recognizer(resource));
for (; !wav_reader.Done(); wav_reader.Next()) {
recognizer_ptr->InitDecoder();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id != -1) {
recog_id = recognizer_controller->GetRecognizerInstanceId();
}
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
@ -77,27 +95,21 @@ int main(int argc, char* argv[]) {
wav_chunk[i] = waveform(sample_offset + i);
}
recognizer_ptr->Accept(wav_chunk);
recognizer_controller->Accept(wav_chunk, recog_id);
if (cur_chunk_size < chunk_sample_size) {
recognizer_ptr->SetInputFinished();
recognizer_controller->SetInputFinished(recog_id);
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
recognizer_ptr->WaitDecodeFinished();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
tot_attention_rescore_time += timer.Elapsed();
std::string result = recognizer_ptr->GetFinalResult();
std::string result = recognizer_controller->GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
continue;
result = " ";
}
tot_decode_time += local_timer.Elapsed();
@ -105,15 +117,59 @@ int main(int argc, char* argv[]) {
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
result_writer.Write(utt, result);
results->push_back(result);
++num_done;
}
recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::RecognizerController recognizer_controller(njob, resource);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
&recognizer_controller,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}

@ -0,0 +1,71 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_controller.h"
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
RecognizerController::RecognizerController(int num_worker, U2RecognizerResource resource) {
nnet_ = std::make_shared<ppspeech::U2Nnet>(resource.model_opts);
recognizer_workers.resize(num_worker);
for (size_t i = 0; i < num_worker; ++i) {
recognizer_workers[i].reset(new ppspeech::U2Recognizer(resource, nnet_->Clone()));
recognizer_workers[i]->InitDecoder();
waiting_workers.push(i);
}
}
int RecognizerController::GetRecognizerInstanceId() {
if (waiting_workers.empty()) {
return -1;
}
int idx = -1;
{
std::unique_lock<std::mutex> lock(mutex_);
idx = waiting_workers.front();
waiting_workers.pop();
}
return idx;
}
RecognizerController::~RecognizerController() {
for (size_t i = 0; i < recognizer_workers.size(); ++i) {
recognizer_workers[i]->SetInputFinished();
recognizer_workers[i]->WaitDecodeFinished();
}
}
std::string RecognizerController::GetFinalResult(int idx) {
recognizer_workers[idx]->WaitDecodeFinished();
recognizer_workers[idx]->AttentionRescoring();
std::string result = recognizer_workers[idx]->GetFinalResult();
recognizer_workers[idx]->InitDecoder();
{
std::unique_lock<std::mutex> lock(mutex_);
waiting_workers.push(idx);
}
return result;
}
void RecognizerController::Accept(std::vector<float> data, int idx) {
recognizer_workers[idx]->Accept(data);
}
void RecognizerController::SetInputFinished(int idx) {
recognizer_workers[idx]->SetInputFinished();
}
}

@ -0,0 +1,39 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <memory>
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
class RecognizerController {
public:
explicit RecognizerController(int num_worker, U2RecognizerResource resource);
~RecognizerController();
int GetRecognizerInstanceId();
void Accept(std::vector<float> data, int idx);
void SetInputFinished(int idx);
std::string GetFinalResult(int idx);
private:
std::queue<int> waiting_workers;
std::shared_ptr<ppspeech::U2Nnet> nnet_;
std::mutex mutex_;
std::vector<std::unique_ptr<ppspeech::U2Recognizer>> recognizer_workers;
};
}

@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
@ -46,10 +47,11 @@ int main(int argc, char* argv[]) {
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource);
std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr(
new ppspeech::U2Recognizer(resource));
for (; !wav_reader.Done(); wav_reader.Next()) {
recognizer.InitDecoder();
recognizer_ptr->InitDecoder();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
@ -64,8 +66,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
int cnt = 0;
kaldi::Timer timer;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
@ -76,32 +76,23 @@ int main(int argc, char* argv[]) {
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetInputFinished();
}
recognizer.Decode();
if (recognizer.DecodedSomething()) {
LOG(INFO) << "Pratial result: " << cnt << " "
<< recognizer.GetPartialResult();
recognizer_ptr->Accept(wav_chunk);
if (cur_chunk_size == (tot_samples - sample_offset)) {
recognizer_ptr->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
cnt++;
}
CHECK(sample_offset == tot_samples);
recognizer_ptr->WaitDecodeFinished();
// second pass decoding
recognizer.Rescoring();
tot_decode_time += timer.Elapsed();
std::string result = recognizer.GetFinalResult();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
tot_attention_rescore_time += timer.Elapsed();
std::string result = recognizer_ptr->GetFinalResult();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
@ -109,6 +100,7 @@ int main(int argc, char* argv[]) {
continue;
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
@ -117,9 +109,11 @@ int main(int argc, char* argv[]) {
++num_done;
}
recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}

@ -26,5 +26,5 @@ foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
target_link_libraries(${bin_name} PUBLIC frontend base utils kaldi-util gflags Threads::Threads extern_glog)
endforeach()
target_link_libraries(${bin_name} PUBLIC frontend base utils kaldi-util libgflags_nothreads.so Threads::Threads extern_glog)
endforeach()

@ -11,5 +11,5 @@ fsttablecompose
foreach(binary IN LISTS BINS)
add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc)
target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog libgflags_nothreads.so fst dl)
endforeach()

Loading…
Cancel
Save