|
|
|
@ -33,11 +33,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
|
|
|
|
|
decodable_.reset(new Decodable(nnet_producer_, am_scale));
|
|
|
|
|
|
|
|
|
|
CHECK_NE(resource.vocab_path, "");
|
|
|
|
|
decoder_.reset(new CTCPrefixBeamSearch(
|
|
|
|
|
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
|
|
|
|
|
if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") {
|
|
|
|
|
LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path;
|
|
|
|
|
decoder_.reset(new CTCPrefixBeamSearch(
|
|
|
|
|
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
|
|
|
|
|
} else {
|
|
|
|
|
decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unit_table_ = decoder_->VocabTable();
|
|
|
|
|
symbol_table_ = unit_table_;
|
|
|
|
|
symbol_table_ = decoder_->WordSymbolTable();
|
|
|
|
|
|
|
|
|
|
global_frame_offset_ = 0;
|
|
|
|
|
input_finished_ = false;
|
|
|
|
@ -56,11 +60,14 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
|
|
|
|
|
decodable_.reset(new Decodable(nnet_producer_, am_scale));
|
|
|
|
|
|
|
|
|
|
CHECK_NE(resource.vocab_path, "");
|
|
|
|
|
decoder_.reset(new CTCPrefixBeamSearch(
|
|
|
|
|
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
|
|
|
|
|
if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") {
|
|
|
|
|
decoder_.reset(new CTCPrefixBeamSearch(
|
|
|
|
|
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
|
|
|
|
|
} else {
|
|
|
|
|
decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unit_table_ = decoder_->VocabTable();
|
|
|
|
|
symbol_table_ = unit_table_;
|
|
|
|
|
symbol_table_ = decoder_->WordSymbolTable();
|
|
|
|
|
|
|
|
|
|
global_frame_offset_ = 0;
|
|
|
|
|
input_finished_ = false;
|
|
|
|
@ -109,10 +116,11 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
|
|
|
|
|
void U2Recognizer::RunDecoderSearchInternal() {
|
|
|
|
|
LOG(INFO) << "DecoderSearchInteral begin";
|
|
|
|
|
while (!nnet_producer_->IsFinished()) {
|
|
|
|
|
nnet_producer_->UnLock();
|
|
|
|
|
nnet_producer_->WaitProduce();
|
|
|
|
|
decoder_->AdvanceDecode(decodable_);
|
|
|
|
|
}
|
|
|
|
|
Decode();
|
|
|
|
|
decoder_->AdvanceDecode(decodable_);
|
|
|
|
|
UpdateResult(false);
|
|
|
|
|
LOG(INFO) << "DecoderSearchInteral exit";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -140,7 +148,7 @@ void U2Recognizer::UpdateResult(bool finish) {
|
|
|
|
|
const auto& times = decoder_->Times();
|
|
|
|
|
result_.clear();
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(hypotheses.size(), likelihood.size());
|
|
|
|
|
CHECK_EQ(inputs.size(), likelihood.size());
|
|
|
|
|
for (size_t i = 0; i < hypotheses.size(); i++) {
|
|
|
|
|
const std::vector<int>& hypothesis = hypotheses[i];
|
|
|
|
|
|
|
|
|
@ -148,13 +156,9 @@ void U2Recognizer::UpdateResult(bool finish) {
|
|
|
|
|
path.score = likelihood[i];
|
|
|
|
|
for (size_t j = 0; j < hypothesis.size(); j++) {
|
|
|
|
|
std::string word = symbol_table_->Find(hypothesis[j]);
|
|
|
|
|
// A detailed explanation of this if-else branch can be found in
|
|
|
|
|
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
|
|
|
|
|
if (decoder_->Type() == kWfstBeamSearch) {
|
|
|
|
|
path.sentence += (" " + word);
|
|
|
|
|
} else {
|
|
|
|
|
path.sentence += (word);
|
|
|
|
|
}
|
|
|
|
|
// path.sentence += (" " + word); // todo SmileGoat: add blank
|
|
|
|
|
// processor
|
|
|
|
|
path.sentence += word; // todo SmileGoat: add blank processor
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TimeStamp is only supported in final result
|
|
|
|
@ -162,7 +166,7 @@ void U2Recognizer::UpdateResult(bool finish) {
|
|
|
|
|
// various FST operations when building the decoding graph. So here we
|
|
|
|
|
// use time stamp of the input(e2e model unit), which is more accurate,
|
|
|
|
|
// and it requires the symbol table of the e2e model used in training.
|
|
|
|
|
if (unit_table_ != nullptr && finish) {
|
|
|
|
|
if (symbol_table_ != nullptr && finish) {
|
|
|
|
|
int offset = global_frame_offset_ * FrameShiftInMs();
|
|
|
|
|
|
|
|
|
|
const std::vector<int>& input = inputs[i];
|
|
|
|
@ -170,7 +174,7 @@ void U2Recognizer::UpdateResult(bool finish) {
|
|
|
|
|
CHECK_EQ(input.size(), time_stamp.size());
|
|
|
|
|
|
|
|
|
|
for (size_t j = 0; j < input.size(); j++) {
|
|
|
|
|
std::string word = unit_table_->Find(input[j]);
|
|
|
|
|
std::string word = symbol_table_->Find(input[j]);
|
|
|
|
|
|
|
|
|
|
int start =
|
|
|
|
|
time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
|
|
|
|
@ -214,7 +218,7 @@ void U2Recognizer::UpdateResult(bool finish) {
|
|
|
|
|
|
|
|
|
|
void U2Recognizer::AttentionRescoring() {
|
|
|
|
|
decoder_->FinalizeSearch();
|
|
|
|
|
UpdateResult(true);
|
|
|
|
|
UpdateResult(false);
|
|
|
|
|
|
|
|
|
|
// No need to do rescoring
|
|
|
|
|
if (0.0 == opts_.decoder_opts.rescoring_weight) {
|
|
|
|
|