diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index fcda097a3..30595d79f 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -43,16 +43,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) input_finished_ = false; num_frames_ = 0; result_.clear(); - } -U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, - std::shared_ptrnnet) +U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, + std::shared_ptr nnet) : opts_(resource) { BaseFloat am_scale = resource.acoustic_scale; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - std::shared_ptr feature_pipeline( - new FeaturePipeline(feature_opts)); + std::shared_ptr feature_pipeline = + std::make_shared(feature_opts); nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); decodable_.reset(new Decodable(nnet_producer_, am_scale)); @@ -70,8 +69,8 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, } U2Recognizer::~U2Recognizer() { - SetInputFinished(); - WaitDecodeFinished(); + SetInputFinished(); + WaitDecodeFinished(); } void U2Recognizer::WaitDecodeFinished() { @@ -120,8 +119,8 @@ void U2Recognizer::RunDecoderSearchInternal() { void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; nnet_producer_->Accept(waves); - VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.size() - << " samples."; + VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " + << waves.size() << " samples."; } void U2Recognizer::Decode() { diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index cff9a253f..5cd407000 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -111,9 +111,9 @@ struct U2RecognizerResource { class U2Recognizer { public: - U2Recognizer(const U2RecognizerResource& resouce); - U2Recognizer(const U2RecognizerResource& resource, - std::shared_ptrnnet); + explict U2Recognizer(const U2RecognizerResource& resouce); + explict U2Recognizer(const U2RecognizerResource& resource, + std::shared_ptr nnet); ~U2Recognizer(); void InitDecoder(); void ResetContinuousDecoding(); @@ -145,7 +145,7 @@ class U2Recognizer { void AttentionRescoring(); private: - static void RunDecoderSearch(U2Recognizer *me); + static void RunDecoderSearch(U2Recognizer* me); void RunDecoderSearchInternal(); void UpdateResult(bool finish = false); diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc index 19887cdb0..709e5aa62 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "recognizer/u2_recognizer.h" #include "common/base/thread_pool.h" -#include "common/utils/strings.h" #include "common/utils/file_utils.h" +#include "common/utils/strings.h" #include "decoder/param.h" #include "frontend/wave-reader.h" -#include "nnet/u2_nnet.h" #include "kaldi/util/table-types.h" -#include "recognizer/u2_recognizer.h" +#include "nnet/u2_nnet.h" DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); @@ -31,10 +31,10 @@ using std::string; using std::vector; void SplitUtt(string wavlist_file, - vector>* uttlists, - vector>* wavlists, - int njob) { - vector wavlist; + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; wavlists->resize(njob); uttlists->resize(njob); ppspeech::ReadFileToVector(wavlist_file, &wavlist); @@ -43,13 +43,13 @@ void SplitUtt(string wavlist_file, vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); LOG(INFO) << utt_wav[0]; CHECK_EQ(utt_wav.size(), size_t(2)); - uttlists->at(idx % njob).push_back(utt_wav[0]); - wavlists->at(idx % njob).push_back(utt_wav[1]); + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); } } void recognizer_func(const ppspeech::U2RecognizerResource& resource, - std::shared_ptr nnet, + std::shared_ptr nnet, std::vector wavlist, std::vector uttlist, std::vector* results) { @@ -60,8 +60,8 @@ void recognizer_func(const ppspeech::U2RecognizerResource& resource, int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; if (wavlist.empty()) return; - std::shared_ptr recognizer_ptr( - new ppspeech::U2Recognizer(resource, nnet)); + std::shared_ptr recognizer_ptr = + std::make_shared(resource, nnet); results->reserve(wavlist.size()); for (size_t idx = 0; idx < wavlist.size(); ++idx) { @@ -118,22 +118,22 @@ void recognizer_func(const ppspeech::U2RecognizerResource& resource, result = " "; } - tot_decode_time += local_timer.Elapsed(); + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur - << " cost: " << local_timer.Elapsed(); + << " cost: " << local_timer.Elapsed(); results->push_back(result); ++num_done; - } - recognizer_ptr->WaitFinished(); - LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); - LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; - LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; - LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; - LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; + } + recognizer_ptr->WaitFinished(); + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; } - + int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -157,14 +157,19 @@ int main(int argc, char* argv[]) { vector> uttlist; vector> resultlist(njob); vector> futurelist; - std::thread threads[njob]; - std::shared_ptr nnet(new ppspeech::U2Nnet(resource.model_opts)); - SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + std::shared_ptr nnet( + new ppspeech::U2Nnet(resource.model_opts)); + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); for (size_t i = 0; i < njob; ++i) { - std::future f = threadpool.enqueue(recognizer_func, resource, nnet->Clone(), wavlist[i], uttlist[i], &resultlist[i]); + std::future f = threadpool.enqueue(recognizer_func, + resource, + nnet->Clone(), + wavlist[i], + uttlist[i], + &resultlist[i]); futurelist.push_back(std::move(f)); } - + for (size_t i = 0; i < njob; ++i) { futurelist[i].get(); }