From fa352a551ec228e92bd5af1c99cdfe1665299264 Mon Sep 17 00:00:00 2001 From: YangZhou Date: Fri, 24 Mar 2023 22:44:05 +0800 Subject: [PATCH] fix --- runtime/engine/asr/decoder/CMakeLists.txt | 2 +- runtime/engine/asr/nnet/u2_nnet.cc | 16 +-- .../asr/recognizer/u2_recognizer_main.cc | 36 +++--- .../recognizer/u2_recognizer_thread_main.cc | 119 ------------------ 4 files changed, 24 insertions(+), 149 deletions(-) delete mode 100644 runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc diff --git a/runtime/engine/asr/decoder/CMakeLists.txt b/runtime/engine/asr/decoder/CMakeLists.txt index 0cd9fc48a..086dc1a4b 100644 --- a/runtime/engine/asr/decoder/CMakeLists.txt +++ b/runtime/engine/asr/decoder/CMakeLists.txt @@ -16,7 +16,7 @@ set(TEST_BINS foreach(bin_name IN LISTS TEST_BINS) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_link_libraries(${bin_name} nnet decoder fst utils libgflags_nothreads.so glog kaldi-base kaldi-matrix kaldi-util) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl) diff --git a/runtime/engine/asr/nnet/u2_nnet.cc b/runtime/engine/asr/nnet/u2_nnet.cc index d6b394729..ae088413c 100644 --- a/runtime/engine/asr/nnet/u2_nnet.cc +++ b/runtime/engine/asr/nnet/u2_nnet.cc @@ -214,7 +214,7 @@ void U2Nnet::ForwardEncoderChunkImpl( // not cache feature in nnet CHECK_EQ(cached_feats_.size(), 0); - // CHECK_EQ(std::is_same::value, true); + CHECK_EQ(std::is_same::value, true); std::memcpy(feats_ptr, chunk_feats.data(), chunk_feats.size() * sizeof(kaldi::BaseFloat)); @@ -594,7 +594,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } else { // dump r_probs CHECK_EQ(r_probs_shape.size(), 1); - //CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0]; + CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0]; } // compute rescoring score @@ -604,15 +604,15 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, VLOG(2) << "split prob: " << probs_v.size() << " " << probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0] << ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2]; - //CHECK(static_cast(probs_v.size()) == num_hyps) - // << ": is " << probs_v.size() << " expect: " << num_hyps; + CHECK(static_cast(probs_v.size()) == num_hyps) + << ": is " << probs_v.size() << " expect: " << num_hyps; std::vector r_probs_v; if (is_bidecoder_ && reverse_weight > 0) { r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0); - //CHECK(static_cast(r_probs_v.size()) == num_hyps) - // << "r_probs_v size: is " << r_probs_v.size() - // << " expect: " << num_hyps; + CHECK(static_cast(r_probs_v.size()) == num_hyps) + << "r_probs_v size: is " << r_probs_v.size() + << " expect: " << num_hyps; } for (int i = 0; i < num_hyps; ++i) { @@ -654,7 +654,7 @@ void U2Nnet::EncoderOuts( const int& B = shape[0]; const int& T = shape[1]; const int& D = shape[2]; - //CHECK(B == 1) << "Only support batch one."; + CHECK(B == 1) << "Only support batch one."; VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << "," << D << ")"; diff --git a/runtime/engine/asr/recognizer/u2_recognizer_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_main.cc index 178c91db1..e0d4b3883 100644 --- a/runtime/engine/asr/recognizer/u2_recognizer_main.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer_main.cc @@ -31,6 +31,7 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; double tot_decode_time = 0.0; kaldi::SequentialTableReader wav_reader( @@ -46,10 +47,11 @@ int main(int argc, char* argv[]) { ppspeech::U2RecognizerResource resource = ppspeech::U2RecognizerResource::InitFromFlags(); - ppspeech::U2Recognizer recognizer(resource); + std::shared_ptr recognizer_ptr( + new ppspeech::U2Recognizer(resource)); for (; !wav_reader.Done(); wav_reader.Next()) { - recognizer.InitDecoder(); + recognizer_ptr->InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -64,8 +66,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "wav len (sample): " << tot_samples; int sample_offset = 0; - int cnt = 0; - kaldi::Timer timer; kaldi::Timer local_timer; while (sample_offset < tot_samples) { @@ -76,32 +76,23 @@ int main(int argc, char* argv[]) { for (int i = 0; i < cur_chunk_size; ++i) { wav_chunk[i] = waveform(sample_offset + i); } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); - recognizer.Accept(wav_chunk); - if (cur_chunk_size < chunk_sample_size) { - recognizer.SetInputFinished(); - } - recognizer.Decode(); - if (recognizer.DecodedSomething()) { - LOG(INFO) << "Pratial result: " << cnt << " " - << recognizer.GetPartialResult(); + recognizer_ptr->Accept(wav_chunk); + if (cur_chunk_size == (tot_samples - sample_offset)) { + recognizer_ptr->SetInputFinished(); } // no overlap sample_offset += cur_chunk_size; - cnt++; } CHECK(sample_offset == tot_samples); + recognizer_ptr->WaitDecodeFinished(); - // second pass decoding - recognizer.Rescoring(); - - tot_decode_time += timer.Elapsed(); - - std::string result = recognizer.GetFinalResult(); - + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + tot_attention_rescore_time += timer.Elapsed(); + std::string result = recognizer_ptr->GetFinalResult(); if (result.empty()) { // the TokenWriter can not write empty string. ++num_err; @@ -109,6 +100,7 @@ int main(int argc, char* argv[]) { continue; } + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur << " cost: " << local_timer.Elapsed(); @@ -117,9 +109,11 @@ int main(int argc, char* argv[]) { ++num_done; } + recognizer_ptr->WaitFinished(); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; } diff --git a/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc deleted file mode 100644 index 272defc60..000000000 --- a/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "decoder/param.h" -#include "frontend/wave-reader.h" -#include "kaldi/util/table-types.h" -#include "recognizer/u2_recognizer.h" - -DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); -DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); -DEFINE_int32(sample_rate, 16000, "sample rate"); - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - int32 num_done = 0, num_err = 0; - double tot_wav_duration = 0.0; - double tot_attention_rescore_time = 0.0; - double tot_decode_time = 0.0; - - kaldi::SequentialTableReader wav_reader( - FLAGS_wav_rspecifier); - kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - - int sample_rate = FLAGS_sample_rate; - float streaming_chunk = FLAGS_streaming_chunk; - int chunk_sample_size = streaming_chunk * sample_rate; - LOG(INFO) << "sr: " << sample_rate; - LOG(INFO) << "chunk size (s): " << streaming_chunk; - LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - - ppspeech::U2RecognizerResource resource = - ppspeech::U2RecognizerResource::InitFromFlags(); - std::shared_ptr recognizer_ptr( - new ppspeech::U2Recognizer(resource)); - - for (; !wav_reader.Done(); wav_reader.Next()) { - recognizer_ptr->InitDecoder(); - std::string utt = wav_reader.Key(); - const kaldi::WaveData& wave_data = wav_reader.Value(); - LOG(INFO) << "utt: " << utt; - LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; - double dur = wave_data.Duration(); - tot_wav_duration += dur; - - int32 this_channel = 0; - kaldi::SubVector waveform(wave_data.Data(), - this_channel); - int tot_samples = waveform.Dim(); - LOG(INFO) << "wav len (sample): " << tot_samples; - - int sample_offset = 0; - kaldi::Timer local_timer; - - while (sample_offset < tot_samples) { - int cur_chunk_size = - std::min(chunk_sample_size, tot_samples - sample_offset); - - std::vector wav_chunk(cur_chunk_size); - for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk[i] = waveform(sample_offset + i); - } - - recognizer_ptr->Accept(wav_chunk); - if (cur_chunk_size < chunk_sample_size) { - recognizer_ptr->SetInputFinished(); - } - - // no overlap - sample_offset += cur_chunk_size; - } - CHECK(sample_offset == tot_samples); - recognizer_ptr->WaitDecodeFinished(); - - kaldi::Timer timer; - recognizer_ptr->AttentionRescoring(); - tot_attention_rescore_time += timer.Elapsed(); - - std::string result = recognizer_ptr->GetFinalResult(); - if (result.empty()) { - // the TokenWriter can not write empty string. - ++num_err; - LOG(INFO) << " the result of " << utt << " is empty"; - continue; - } - - tot_decode_time += local_timer.Elapsed(); - LOG(INFO) << utt << " " << result; - LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur - << " cost: " << local_timer.Elapsed(); - - result_writer.Write(utt, result); - - ++num_done; - } - recognizer_ptr->WaitFinished(); - - LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); - LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; - LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; - LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; - LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; -}