diff --git a/speechx/speechx/asr/nnet/u2_nnet.cc b/speechx/speechx/asr/nnet/u2_nnet.cc index e3277a38..0795c836 100644 --- a/speechx/speechx/asr/nnet/u2_nnet.cc +++ b/speechx/speechx/asr/nnet/u2_nnet.cc @@ -118,27 +118,38 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) { // shallow copy U2Nnet::U2Nnet(const U2Nnet& other) { // copy meta - right_context_ = other.right_context_; - subsampling_rate_ = other.subsampling_rate_; - sos_ = other.sos_; - eos_ = other.eos_; - is_bidecoder_ = other.is_bidecoder_; chunk_size_ = other.chunk_size_; num_left_chunks_ = other.num_left_chunks_; - - forward_encoder_chunk_ = other.forward_encoder_chunk_; - forward_attention_decoder_ = other.forward_attention_decoder_; - ctc_activation_ = other.ctc_activation_; - offset_ = other.offset_; // copy model ptr - model_ = other.model_; + model_ = other.model_->Clone(); + ctc_activation_ = model_->Function("ctc_activation"); + subsampling_rate_ = model_->Attribute("subsampling_rate"); + right_context_ = model_->Attribute("right_context"); + sos_ = model_->Attribute("sos_symbol"); + eos_ = model_->Attribute("eos_symbol"); + is_bidecoder_ = model_->Attribute("is_bidirectional_decoder"); + + forward_encoder_chunk_ = model_->Function("forward_encoder_chunk"); + forward_attention_decoder_ = model_->Function("forward_attention_decoder"); + ctc_activation_ = model_->Function("ctc_activation"); + CHECK(forward_encoder_chunk_.IsValid()); + CHECK(forward_attention_decoder_.IsValid()); + CHECK(ctc_activation_.IsValid()); + + LOG(INFO) << "Paddle Model Info: "; + LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_; + LOG(INFO) << "\tright context " << right_context_; + LOG(INFO) << "\tsos " << sos_; + LOG(INFO) << "\teos " << eos_; + LOG(INFO) << "\tis bidecoder " << is_bidecoder_ << std::endl; + // ignore inner states } -std::shared_ptr U2Nnet::Copy() const { +std::shared_ptr U2Nnet::Clone() const { auto asr_model = std::make_shared(*this); // reset inner state for new decoding asr_model->Reset(); diff --git a/speechx/speechx/asr/nnet/u2_nnet.h b/speechx/speechx/asr/nnet/u2_nnet.h index 127d84db..35a15707 100644 --- a/speechx/speechx/asr/nnet/u2_nnet.h +++ b/speechx/speechx/asr/nnet/u2_nnet.h @@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase { num_left_chunks_ = num_left_chunks; } - virtual std::shared_ptr Copy() const = 0; + virtual std::shared_ptr Clone() const = 0; protected: virtual void ForwardEncoderChunkImpl( @@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase { std::shared_ptr model() const { return model_; } - std::shared_ptr Copy() const override; + std::shared_ptr Clone() const override; void ForwardEncoderChunkImpl( const std::vector& chunk_feats, diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 8f9117e4..f28c5fea 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -10,6 +10,7 @@ target_link_libraries(recognizer PUBLIC decoder) set(TEST_BINS u2_recognizer_main u2_recognizer_thread_main + u2_recognizer_batch_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index 0c5a8941..30595d79 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -43,12 +43,34 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) input_finished_ = false; num_frames_ = 0; result_.clear(); +} + +U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, + std::shared_ptr nnet) + : opts_(resource) { + BaseFloat am_scale = resource.acoustic_scale; + const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; + std::shared_ptr feature_pipeline = + std::make_shared(feature_opts); + nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); + decodable_.reset(new Decodable(nnet_producer_, am_scale)); + + CHECK_NE(resource.vocab_path, ""); + decoder_.reset(new CTCPrefixBeamSearch( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + + unit_table_ = decoder_->VocabTable(); + symbol_table_ = unit_table_; + global_frame_offset_ = 0; + input_finished_ = false; + num_frames_ = 0; + result_.clear(); } U2Recognizer::~U2Recognizer() { - SetInputFinished(); - WaitDecodeFinished(); + SetInputFinished(); + WaitDecodeFinished(); } void U2Recognizer::WaitDecodeFinished() { @@ -97,8 +119,8 @@ void U2Recognizer::RunDecoderSearchInternal() { void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; nnet_producer_->Accept(waves); - VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.size() - << " samples."; + VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " + << waves.size() << " samples."; } void U2Recognizer::Decode() { diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 57f2c9c5..5d628e3a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -112,6 +112,8 @@ struct U2RecognizerResource { class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); + explicit U2Recognizer(const U2RecognizerResource& resource, + std::shared_ptr nnet); ~U2Recognizer(); void InitDecoder(); void ResetContinuousDecoding(); @@ -143,7 +145,7 @@ class U2Recognizer { void AttentionRescoring(); private: - static void RunDecoderSearch(U2Recognizer *me); + static void RunDecoderSearch(U2Recognizer* me); void RunDecoderSearchInternal(); void UpdateResult(bool finish = false); diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc new file mode 100644 index 00000000..709e5aa6 --- /dev/null +++ b/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc @@ -0,0 +1,185 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/u2_recognizer.h" +#include "common/base/thread_pool.h" +#include "common/utils/file_utils.h" +#include "common/utils/strings.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/u2_nnet.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); +DEFINE_int32(njob, 3, "njob"); + +using std::string; +using std::vector; + +void SplitUtt(string wavlist_file, + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; + wavlists->resize(njob); + uttlists->resize(njob); + ppspeech::ReadFileToVector(wavlist_file, &wavlist); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + string utt_str = wavlist[idx]; + vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); + LOG(INFO) << utt_wav[0]; + CHECK_EQ(utt_wav.size(), size_t(2)); + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); + } +} + +void recognizer_func(const ppspeech::U2RecognizerResource& resource, + std::shared_ptr nnet, + std::vector wavlist, + std::vector uttlist, + std::vector* results) { + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; + double tot_decode_time = 0.0; + int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; + if (wavlist.empty()) return; + + std::shared_ptr recognizer_ptr = + std::make_shared(resource, nnet); + + results->reserve(wavlist.size()); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + std::string utt = uttlist[idx]; + std::string wav_file = wavlist[idx]; + std::ifstream infile; + infile.open(wav_file, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + recognizer_ptr->InitDecoder(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + recognizer_ptr->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + recognizer_ptr->SetInputFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + recognizer_ptr->WaitDecodeFinished(); + + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + tot_attention_rescore_time += timer.Elapsed(); + + std::string result = recognizer_ptr->GetFinalResult(); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + result = " "; + } + + tot_decode_time += local_timer.Elapsed(); + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + results->push_back(result); + ++num_done; + } + recognizer_ptr->WaitFinished(); + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int njob = FLAGS_njob; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + ppspeech::U2RecognizerResource resource = + ppspeech::U2RecognizerResource::InitFromFlags(); + ThreadPool threadpool(njob); + vector> wavlist; + vector> uttlist; + vector> resultlist(njob); + vector> futurelist; + std::shared_ptr nnet( + new ppspeech::U2Nnet(resource.model_opts)); + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + for (size_t i = 0; i < njob; ++i) { + std::future f = threadpool.enqueue(recognizer_func, + resource, + nnet->Clone(), + wavlist[i], + uttlist[i], + &resultlist[i]); + futurelist.push_back(std::move(f)); + } + + for (size_t i = 0; i < njob; ++i) { + futurelist[i].get(); + } + + for (size_t idx = 0; idx < njob; ++idx) { + for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) { + string utt = uttlist[idx][utt_idx]; + string result = resultlist[idx][utt_idx]; + result_writer.Write(utt, result); + } + } + return 0; +} diff --git a/speechx/speechx/common/base/common.h b/speechx/speechx/common/base/common.h index 2a066ee6..06fcd9fd 100644 --- a/speechx/speechx/common/base/common.h +++ b/speechx/speechx/common/base/common.h @@ -42,6 +42,8 @@ #include #include #include +#include +#include #include "base/basic_types.h" #include "base/flags.h"