fix nnet thread crash && rescore cost time

pull/2839/head
YangZhou 3 years ago
parent f9c7b1bcc2
commit e7d8ecf30c

@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() {
times_.emplace_back(empty); times_.emplace_back(empty);
} }
void CTCPrefixBeamSearch::InitDecoder() { Reset(); } void CTCPrefixBeamSearch::InitDecoder() {
Reset();
}
void CTCPrefixBeamSearch::AdvanceDecode( void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
feat_nnet_cost += timer.Elapsed(); feat_nnet_cost += timer.Elapsed();
if (flag == false) { if (flag == false) {
VLOG(3) << "decoder advance decode exit." << frame_prob.size(); VLOG(2) << "decoder advance decode exit." << frame_prob.size();
break; break;
} }
@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
AdvanceDecoding(likelihood); AdvanceDecoding(likelihood);
search_cost += timer.Elapsed(); search_cost += timer.Elapsed();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_; VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
} }
VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
<< " sec."; << " sec.";

@ -33,6 +33,14 @@ void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
condition_variable_.notify_one(); condition_variable_.notify_one();
} }
void NnetProducer::UnLock() {
std::unique_lock<std::mutex> lock(read_mutex_);
while (frontend_->IsFinished() == false && cache_.empty()) {
condition_read_ready_.wait(lock);
}
return;
}
void NnetProducer::RunNnetEvaluation(NnetProducer *me) { void NnetProducer::RunNnetEvaluation(NnetProducer *me) {
me->RunNnetEvaluationInteral(); me->RunNnetEvaluationInteral();
} }
@ -47,7 +55,7 @@ void NnetProducer::RunNnetEvaluationInteral() {
result = Compute(); result = Compute();
} while (result); } while (result);
if (frontend_->IsFinished() == true) { if (frontend_->IsFinished() == true) {
Compute(); //Compute();
if (cache_.empty()) finished_ = true; if (cache_.empty()) finished_ = true;
} }
} }
@ -61,8 +69,8 @@ void NnetProducer::Acceptlikelihood(
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
for (size_t col = 0; col < likelihood.NumCols(); ++col) { for (size_t col = 0; col < likelihood.NumCols(); ++col) {
prob[col] = likelihood(idx, col); prob[col] = likelihood(idx, col);
cache_.push_back(prob);
} }
cache_.push_back(prob);
} }
} }
@ -100,6 +108,7 @@ bool NnetProducer::Compute() {
out.logprobs.data() + idx * vocab_dim, out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim); out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob); cache_.push_back(logprob);
condition_read_ready_.notify_one();
} }
return true; return true;
} }

@ -36,6 +36,7 @@ class NnetProducer {
bool ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob); bool ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob);
static void RunNnetEvaluation(NnetProducer *me); static void RunNnetEvaluation(NnetProducer *me);
void RunNnetEvaluationInteral(); void RunNnetEvaluationInteral();
void UnLock();
void Wait() { void Wait() {
abort_ = true; abort_ = true;
@ -49,7 +50,6 @@ class NnetProducer {
LOG(INFO) << "set finished"; LOG(INFO) << "set finished";
frontend_->SetFinished(); frontend_->SetFinished();
condition_variable_.notify_one(); condition_variable_.notify_one();
LOG(INFO) << "compute last feats done.";
} }
// the compute thread exit // the compute thread exit
@ -60,13 +60,11 @@ class NnetProducer {
} }
void Reset() { void Reset() {
//if (thread_.joinable()) thread_.join();
frontend_->Reset(); frontend_->Reset();
nnet_->Reset(); nnet_->Reset();
VLOG(3) << "feature cache reset: cache size: " << cache_.size(); VLOG(3) << "feature cache reset: cache size: " << cache_.size();
cache_.clear(); cache_.clear();
finished_ = false; finished_ = false;
//thread_ = std::thread(RunNnetEvaluation, this);
} }
void AttentionRescoring(const std::vector<std::vector<int>>& hyps, void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
@ -80,7 +78,9 @@ class NnetProducer {
std::shared_ptr<NnetBase> nnet_; std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_; SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::mutex mutex_; std::mutex mutex_;
std::mutex read_mutex_;
std::condition_variable condition_variable_; std::condition_variable condition_variable_;
std::condition_variable condition_read_ready_;
std::thread thread_; std::thread thread_;
bool finished_; bool finished_;
bool abort_; bool abort_;

@ -85,11 +85,12 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
} }
void U2Recognizer::RunDecoderSearchInternal() { void U2Recognizer::RunDecoderSearchInternal() {
LOG(INFO) << "DecoderSearchInteral begin";
while (!nnet_producer_->IsFinished()) { while (!nnet_producer_->IsFinished()) {
Decode(); nnet_producer_->UnLock();
decoder_->AdvanceDecode(decodable_);
} }
Decode(); Decode();
Rescoring();
LOG(INFO) << "DecoderSearchInteral exit"; LOG(INFO) << "DecoderSearchInteral exit";
} }

@ -140,18 +140,16 @@ class U2Recognizer {
} }
const std::vector<DecodeResult>& Result() const { return result_; } const std::vector<DecodeResult>& Result() const { return result_; }
void AttentionRescoring();
private: private:
static void RunDecoderSearch(U2Recognizer *me); static void RunDecoderSearch(U2Recognizer *me);
void RunDecoderSearchInternal(); void RunDecoderSearchInternal();
void AttentionRescoring();
void UpdateResult(bool finish = false); void UpdateResult(bool finish = false);
private: private:
U2RecognizerResource opts_; U2RecognizerResource opts_;
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std::shared_ptr<NnetProducer> nnet_producer_; std::shared_ptr<NnetProducer> nnet_producer_;
std::shared_ptr<Decodable> decodable_; std::shared_ptr<Decodable> decodable_;
std::unique_ptr<CTCPrefixBeamSearch> decoder_; std::unique_ptr<CTCPrefixBeamSearch> decoder_;

@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate"); DEFINE_int32(sample_rate, 16000, "sample rate");
/*void decode_func(std::shared_ptr<ppspeech::U2Recognizer> recognizer) {
while (!recognizer->IsFinished()) {
recognizer->Decode();
}
recognizer->Decode();
recognizer->Rescoring();
LOG(INFO) << "decode thread exit!!!";
}*/
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:"); gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -40,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0; double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0; double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
@ -74,7 +66,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "wav len (sample): " << tot_samples; LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0; int sample_offset = 0;
kaldi::Timer timer;
kaldi::Timer local_timer; kaldi::Timer local_timer;
while (sample_offset < tot_samples) { while (sample_offset < tot_samples) {
@ -85,7 +76,6 @@ int main(int argc, char* argv[]) {
for (int i = 0; i < cur_chunk_size; ++i) { for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i); wav_chunk[i] = waveform(sample_offset + i);
} }
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer_ptr->Accept(wav_chunk); recognizer_ptr->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) { if (cur_chunk_size < chunk_sample_size) {
@ -97,6 +87,11 @@ int main(int argc, char* argv[]) {
} }
CHECK(sample_offset == tot_samples); CHECK(sample_offset == tot_samples);
recognizer_ptr->WaitDecodeFinished(); recognizer_ptr->WaitDecodeFinished();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
tot_attention_rescore_time += timer.Elapsed();
std::string result = recognizer_ptr->GetFinalResult(); std::string result = recognizer_ptr->GetFinalResult();
if (result.empty()) { if (result.empty()) {
// the TokenWriter can not write empty string. // the TokenWriter can not write empty string.
@ -119,5 +114,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
} }

@ -73,8 +73,7 @@ int main(int argc, char* argv[]) {
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
// the feature cache output feature chunk by chunk. // the feature cache output feature chunk by chunk.
ppspeech::FeatureCacheOptions feat_cache_opts; ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "fbank: " << true; LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();

@ -20,10 +20,9 @@ using kaldi::BaseFloat;
using std::unique_ptr; using std::unique_ptr;
using std::vector; using std::vector;
FeatureCache::FeatureCache(FeatureCacheOptions opts, FeatureCache::FeatureCache(size_t max_size,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
max_size_ = opts.max_size; max_size_ = max_size;
timeout_ = opts.timeout; // ms
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim(); dim_ = base_extractor_->Dim();
} }
@ -44,23 +43,13 @@ bool FeatureCache::Read(std::vector<kaldi::BaseFloat>* feats) {
result = Compute(); result = Compute();
} while (result); } while (result);
} }
//while (cache_.empty() && base_extractor_->IsFinished() == false) {
//// todo refactor: wait
//// ready_read_condition_.wait(lock);
//int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000); // ms
//if (elapsed > timeout_) {
//return false;
//}
//usleep(100); // sleep 0.1 ms
//}
if (cache_.empty()) return false; if (cache_.empty()) return false;
// read from cache // read from cache
*feats = cache_.front(); *feats = cache_.front();
cache_.pop(); cache_.pop();
//ready_feed_condition_.notify_one(); //ready_feed_condition_.notify_one();
VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; VLOG(2) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
return true; return true;
} }
@ -80,17 +69,9 @@ bool FeatureCache::Compute() {
int32 start = chunk_idx * dim_; int32 start = chunk_idx * dim_;
vector<BaseFloat> feature_chunk(feature.data() + start, vector<BaseFloat> feature_chunk(feature.data() + start,
feature.data() + start + dim_); feature.data() + start + dim_);
// std::unique_lock<std::mutex> lock(mutex_);
//while (cache_.size() >= max_size_) {
// cache full, wait
// ready_feed_condition_.wait(lock);
//}
// feed cache // feed cache
cache_.push(feature_chunk); cache_.push(feature_chunk);
++nframe_; ++nframe_;
//ready_read_condition_.notify_one();
} }
VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. "

@ -19,16 +19,10 @@
namespace ppspeech { namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 timeout; // ms
FeatureCacheOptions() : max_size(kint16max), timeout(1) {}
};
class FeatureCache : public FrontendInterface { class FeatureCache : public FrontendInterface {
public: public:
explicit FeatureCache( explicit FeatureCache(
FeatureCacheOptions opts, size_t max_size = kint16max,
std::unique_ptr<FrontendInterface> base_extractor = NULL); std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves // Feed feats or waves
@ -64,11 +58,8 @@ class FeatureCache : public FrontendInterface {
int32 dim_; int32 dim_;
size_t max_size_; // cache capacity size_t max_size_; // cache capacity
int32 frame_chunk_size_; // window
int32 frame_chunk_stride_; // stride
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::int32 timeout_; // ms
std::queue<std::vector<BaseFloat>> cache_; // feature cache std::queue<std::vector<BaseFloat>> cache_; // feature cache
std::mutex mutex_; std::mutex mutex_;

@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
unique_ptr<FrontendInterface> cache( unique_ptr<FrontendInterface> cache(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); new ppspeech::FeatureCache(kint16max, std::move(cmvn)));
base_extractor_.reset( base_extractor_.reset(
new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));

@ -39,7 +39,6 @@ namespace ppspeech {
struct FeaturePipelineOptions { struct FeaturePipelineOptions {
std::string cmvn_file{}; std::string cmvn_file{};
knf::FbankOptions fbank_opts{}; knf::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts{}; AssemblerOptions assembler_opts{};
static FeaturePipelineOptions InitFromFlags() { static FeaturePipelineOptions InitFromFlags() {

Loading…
Cancel
Save