pull/2866/head
YangZhou 3 years ago
parent cf43a69966
commit 2f29a0a461

@ -43,16 +43,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
input_finished_ = false; input_finished_ = false;
num_frames_ = 0; num_frames_ = 0;
result_.clear(); result_.clear();
} }
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
std::shared_ptr<NnetBase>nnet) std::shared_ptr<NnetBase> nnet)
: opts_(resource) { : opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale; BaseFloat am_scale = resource.acoustic_scale;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
std::shared_ptr<FeaturePipeline> feature_pipeline( std::shared_ptr<FeaturePipeline> feature_pipeline =
new FeaturePipeline(feature_opts)); std::make_shared<FeaturePipeline>(feature_opts);
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline));
decodable_.reset(new Decodable(nnet_producer_, am_scale)); decodable_.reset(new Decodable(nnet_producer_, am_scale));
@ -70,8 +69,8 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
} }
U2Recognizer::~U2Recognizer() { U2Recognizer::~U2Recognizer() {
SetInputFinished(); SetInputFinished();
WaitDecodeFinished(); WaitDecodeFinished();
} }
void U2Recognizer::WaitDecodeFinished() { void U2Recognizer::WaitDecodeFinished() {
@ -120,8 +119,8 @@ void U2Recognizer::RunDecoderSearchInternal() {
void U2Recognizer::Accept(const vector<BaseFloat>& waves) { void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
kaldi::Timer timer; kaldi::Timer timer;
nnet_producer_->Accept(waves); nnet_producer_->Accept(waves);
VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.size() VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. "
<< " samples."; << waves.size() << " samples.";
} }
void U2Recognizer::Decode() { void U2Recognizer::Decode() {

@ -111,9 +111,9 @@ struct U2RecognizerResource {
class U2Recognizer { class U2Recognizer {
public: public:
U2Recognizer(const U2RecognizerResource& resouce); explict U2Recognizer(const U2RecognizerResource& resouce);
U2Recognizer(const U2RecognizerResource& resource, explict U2Recognizer(const U2RecognizerResource& resource,
std::shared_ptr<NnetBase>nnet); std::shared_ptr<NnetBase> nnet);
~U2Recognizer(); ~U2Recognizer();
void InitDecoder(); void InitDecoder();
void ResetContinuousDecoding(); void ResetContinuousDecoding();
@ -145,7 +145,7 @@ class U2Recognizer {
void AttentionRescoring(); void AttentionRescoring();
private: private:
static void RunDecoderSearch(U2Recognizer *me); static void RunDecoderSearch(U2Recognizer* me);
void RunDecoderSearchInternal(); void RunDecoderSearchInternal();
void UpdateResult(bool finish = false); void UpdateResult(bool finish = false);

@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "recognizer/u2_recognizer.h"
#include "common/base/thread_pool.h" #include "common/base/thread_pool.h"
#include "common/utils/strings.h"
#include "common/utils/file_utils.h" #include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h" #include "decoder/param.h"
#include "frontend/wave-reader.h" #include "frontend/wave-reader.h"
#include "nnet/u2_nnet.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "recognizer/u2_recognizer.h" #include "nnet/u2_nnet.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
@ -31,10 +31,10 @@ using std::string;
using std::vector; using std::vector;
void SplitUtt(string wavlist_file, void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists, vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists, vector<vector<string>>* wavlists,
int njob) { int njob) {
vector<string> wavlist; vector<string> wavlist;
wavlists->resize(njob); wavlists->resize(njob);
uttlists->resize(njob); uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist); ppspeech::ReadFileToVector(wavlist_file, &wavlist);
@ -43,13 +43,13 @@ void SplitUtt(string wavlist_file,
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t"); vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0]; LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2)); CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]); uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]); wavlists->at(idx % njob).push_back(utt_wav[1]);
} }
} }
void recognizer_func(const ppspeech::U2RecognizerResource& resource, void recognizer_func(const ppspeech::U2RecognizerResource& resource,
std::shared_ptr<ppspeech::NnetBase> nnet, std::shared_ptr<ppspeech::NnetBase> nnet,
std::vector<string> wavlist, std::vector<string> wavlist,
std::vector<string> uttlist, std::vector<string> uttlist,
std::vector<string>* results) { std::vector<string>* results) {
@ -60,8 +60,8 @@ void recognizer_func(const ppspeech::U2RecognizerResource& resource,
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return; if (wavlist.empty()) return;
std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr( std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr =
new ppspeech::U2Recognizer(resource, nnet)); std::make_shared<ppspeech::U2Recognizer>(resource, nnet);
results->reserve(wavlist.size()); results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) { for (size_t idx = 0; idx < wavlist.size(); ++idx) {
@ -118,22 +118,22 @@ void recognizer_func(const ppspeech::U2RecognizerResource& resource,
result = " "; result = " ";
} }
tot_decode_time += local_timer.Elapsed(); tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result; LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed(); << " cost: " << local_timer.Elapsed();
results->push_back(result); results->push_back(result);
++num_done; ++num_done;
} }
recognizer_ptr->WaitFinished(); recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:"); gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -157,14 +157,19 @@ int main(int argc, char* argv[]) {
vector<vector<string>> uttlist; vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob); vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist; vector<std::future<void>> futurelist;
std::thread threads[njob]; std::shared_ptr<ppspeech::U2Nnet> nnet(
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(resource.model_opts)); new ppspeech::U2Nnet(resource.model_opts));
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) { for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func, resource, nnet->Clone(), wavlist[i], uttlist[i], &resultlist[i]); std::future<void> f = threadpool.enqueue(recognizer_func,
resource,
nnet->Clone(),
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f)); futurelist.push_back(std::move(f));
} }
for (size_t i = 0; i < njob; ++i) { for (size_t i = 0; i < njob; ++i) {
futurelist[i].get(); futurelist[i].get();
} }

Loading…
Cancel
Save