parent
7dc9cba3be
commit
86eb718908
@ -1,41 +0,0 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
// A single recognized token plus its time span in the utterance.
// Timestamps are in milliseconds (see U2Recognizer::UpdateResult,
// which fills them from frame indices * FrameShiftInMs()); -1 means
// "not yet assigned".
struct WordPiece {
    std::string word;  // token text (e2e model unit)
    int start = -1;    // start time, ms
    int end = -1;      // end time, ms

    // Sink parameter: takes the word by value and moves it into place.
    WordPiece(std::string w, int s, int e)
        : word(std::move(w)), start(s), end(e) {}
};
|
||||
|
||||
struct DecodeResult {
|
||||
float score = -kBaseFloatMax;
|
||||
std::string sentence;
|
||||
std::vector<WordPiece> word_pieces;
|
||||
|
||||
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
|
||||
return a.score > b.score;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,209 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decoder/u2_recognizer.h"
|
||||
#include "nnet/u2_nnet.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::Vector;
|
||||
using kaldi::VectorBase;
|
||||
using kaldi::BaseFloat;
|
||||
using std::vector;
|
||||
using kaldi::SubVector;
|
||||
using std::unique_ptr;
|
||||
|
||||
// Builds the full recognition pipeline from `resource`:
// feature extraction -> nnet scoring (Decodable) -> CTC prefix beam search.
// Uses std::make_shared instead of reset(new ...): exception-safe and a
// single allocation per object.
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
    : opts_(resource) {
    const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
    feature_pipeline_ = std::make_shared<FeaturePipeline>(feature_opts);

    // The nnet is co-owned by the decodable wrapper below.
    std::shared_ptr<NnetInterface> nnet =
        std::make_shared<U2Nnet>(resource.model_opts);

    BaseFloat am_scale = resource.acoustic_scale;
    decodable_ = std::make_shared<Decodable>(nnet, feature_pipeline_, am_scale);

    decoder_.reset(new CTCPrefixBeamSearch(
        resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));

    // The e2e model unit table doubles as the output symbol table.
    unit_table_ = decoder_->VocabTable();
    symbol_table_ = unit_table_;

    input_finished_ = false;
}
|
||||
|
||||
// Restart decoding from scratch: frame bookkeeping back to zero, results
// dropped, and every pipeline stage reset to its initial state.
void U2Recognizer::Reset() {
    num_frames_ = 0;
    global_frame_offset_ = 0;
    result_.clear();

    decoder_->Reset();
    decodable_->Reset();
    feature_pipeline_->Reset();
}
|
||||
|
||||
// Restart for the next segment of a continuous stream: unlike Reset(),
// the frames decoded so far are folded into global_frame_offset_ so
// word timestamps stay monotonic across segments. The offset must be
// captured BEFORE num_frames_ is cleared.
void U2Recognizer::ResetContinuousDecoding() {
    global_frame_offset_ = num_frames_;
    num_frames_ = 0;
    result_.clear();

    decoder_->Reset();
    decodable_->Reset();
    feature_pipeline_->Reset();
}
|
||||
|
||||
|
||||
// Feeds a chunk of raw audio samples into the feature pipeline.
// Decoding itself happens later, in Decode().
void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
    feature_pipeline_->Accept(waves);
}
|
||||
|
||||
|
||||
// Advances first-pass decoding over whatever features the decodable
// currently has available.
void U2Recognizer::Decode() {
    decoder_->AdvanceDecode(decodable_);
}
|
||||
|
||||
// Runs the second decoding pass (attention rescoring of the n-best
// list) and logs how long it took.
void U2Recognizer::Rescoring() {
    // Do attention Rescoring
    kaldi::Timer timer;
    AttentionRescoring();
    VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec.";
}
|
||||
|
||||
void U2Recognizer::UpdateResult(bool finish) {
|
||||
const auto& hypotheses = decoder_->Outputs();
|
||||
const auto& inputs = decoder_->Inputs();
|
||||
const auto& likelihood = decoder_->Likelihood();
|
||||
const auto& times = decoder_->Times();
|
||||
result_.clear();
|
||||
|
||||
CHECK_EQ(hypotheses.size(), likelihood.size());
|
||||
for (size_t i = 0; i < hypotheses.size(); i++) {
|
||||
const std::vector<int>& hypothesis = hypotheses[i];
|
||||
|
||||
DecodeResult path;
|
||||
path.score = likelihood[i];
|
||||
for (size_t j = 0; j < hypothesis.size(); j++) {
|
||||
std::string word = symbol_table_->Find(hypothesis[j]);
|
||||
// A detailed explanation of this if-else branch can be found in
|
||||
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
|
||||
if (decoder_->Type() == kWfstBeamSearch) {
|
||||
path.sentence += (" " + word);
|
||||
} else {
|
||||
path.sentence += (word);
|
||||
}
|
||||
}
|
||||
|
||||
// TimeStamp is only supported in final result
|
||||
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
|
||||
// various FST operations when building the decoding graph. So here we use
|
||||
// time stamp of the input(e2e model unit), which is more accurate, and it
|
||||
// requires the symbol table of the e2e model used in training.
|
||||
if (unit_table_ != nullptr && finish) {
|
||||
int offset = global_frame_offset_ * FrameShiftInMs();
|
||||
|
||||
const std::vector<int>& input = inputs[i];
|
||||
const std::vector<int> time_stamp = times[i];
|
||||
CHECK_EQ(input.size(), time_stamp.size());
|
||||
|
||||
for (size_t j = 0; j < input.size(); j++) {
|
||||
std::string word = unit_table_->Find(input[j]);
|
||||
|
||||
int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
|
||||
? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_
|
||||
: 0;
|
||||
if (j > 0) {
|
||||
start = (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() <
|
||||
time_stamp_gap_
|
||||
? (time_stamp[j - 1] + time_stamp[j]) / 2 *
|
||||
FrameShiftInMs()
|
||||
: start;
|
||||
}
|
||||
|
||||
int end = time_stamp[j] * FrameShiftInMs();
|
||||
if (j < input.size() - 1) {
|
||||
end = (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() <
|
||||
time_stamp_gap_
|
||||
? (time_stamp[j + 1] + time_stamp[j]) / 2 *
|
||||
FrameShiftInMs()
|
||||
: end;
|
||||
}
|
||||
|
||||
WordPiece word_piece(word, offset + start, offset + end);
|
||||
path.word_pieces.emplace_back(word_piece);
|
||||
}
|
||||
}
|
||||
|
||||
// if (post_processor_ != nullptr) {
|
||||
// path.sentence = post_processor_->Process(path.sentence, finish);
|
||||
// }
|
||||
|
||||
result_.emplace_back(path);
|
||||
}
|
||||
|
||||
if (DecodedSomething()) {
|
||||
VLOG(1) << "Partial CTC result " << result_[0].sentence;
|
||||
}
|
||||
}
|
||||
|
||||
void U2Recognizer::AttentionRescoring() {
|
||||
decoder_->FinalizeSearch();
|
||||
UpdateResult(true);
|
||||
|
||||
// No need to do rescoring
|
||||
if (0.0 == opts_.decoder_opts.rescoring_weight) {
|
||||
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
|
||||
return;
|
||||
}
|
||||
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
|
||||
|
||||
// Inputs() returns N-best input ids, which is the basic unit for rescoring
|
||||
// In CtcPrefixBeamSearch, inputs are the same to outputs
|
||||
const auto& hypotheses = decoder_->Inputs();
|
||||
int num_hyps = hypotheses.size();
|
||||
if (num_hyps <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
kaldi::Timer timer;
|
||||
std::vector<float> rescoring_score;
|
||||
decodable_->AttentionRescoring(
|
||||
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
|
||||
VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";
|
||||
|
||||
// combine ctc score and rescoring score
|
||||
for (size_t i = 0; i < num_hyps; i++) {
|
||||
VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
|
||||
<< " ctc_score: " << result_[i].score;
|
||||
result_[i].score = opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
|
||||
opts_.decoder_opts.ctc_weight * result_[i].score;
|
||||
}
|
||||
|
||||
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
|
||||
VLOG(1) << "result: " << result_[0].sentence
|
||||
<< " score: " << result_[0].score;
|
||||
}
|
||||
|
||||
std::string U2Recognizer::GetFinalResult() {
|
||||
return result_[0].sentence;
|
||||
}
|
||||
|
||||
std::string U2Recognizer::GetPartialResult() {
|
||||
return result_[0].sentence;
|
||||
}
|
||||
|
||||
// Marks end-of-input: tells the feature pipeline that no more audio
// will arrive and records the fact for IsFinished().
void U2Recognizer::SetFinished() {
    feature_pipeline_->SetFinished();
    input_finished_ = true;
}
|
||||
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,164 @@
|
||||
|
||||
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "decoder/common.h"
|
||||
#include "decoder/ctc_beam_search_opt.h"
|
||||
#include "decoder/ctc_prefix_beam_search_decoder.h"
|
||||
#include "decoder/decoder_itf.h"
|
||||
#include "frontend/audio/feature_pipeline.h"
|
||||
#include "nnet/decodable.h"
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
#include "fst/symbol-table.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
|
||||
struct DecodeOptions {
|
||||
// chunk_size is the frame number of one chunk after subsampling.
|
||||
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
|
||||
// one chunk are 67=16*4 + 3, stride is 64=16*4
|
||||
int chunk_size;
|
||||
int num_left_chunks;
|
||||
|
||||
// final_score = rescoring_weight * rescoring_score + ctc_weight *
|
||||
// ctc_score;
|
||||
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
|
||||
// right_to_left_score * reverse_weight
|
||||
// Please note the concept of ctc_scores
|
||||
// in the following two search methods are different. For
|
||||
// CtcPrefixBeamSerch,
|
||||
// it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a
|
||||
// max(viterbi) path score + context score So we should carefully set
|
||||
// ctc_weight accroding to the search methods.
|
||||
float ctc_weight;
|
||||
float rescoring_weight;
|
||||
float reverse_weight;
|
||||
|
||||
// CtcEndpointConfig ctc_endpoint_opts;
|
||||
CTCBeamSearchOptions ctc_prefix_search_opts;
|
||||
|
||||
DecodeOptions()
|
||||
: chunk_size(16),
|
||||
num_left_chunks(-1),
|
||||
ctc_weight(0.5),
|
||||
rescoring_weight(1.0),
|
||||
reverse_weight(0.0) {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
std::string module = "DecoderConfig: ";
|
||||
opts->Register(
|
||||
"chunk-size",
|
||||
&chunk_size,
|
||||
module + "the frame number of one chunk after subsampling.");
|
||||
opts->Register("num-left-chunks",
|
||||
&num_left_chunks,
|
||||
module + "the left history chunks number.");
|
||||
opts->Register("ctc-weight",
|
||||
&ctc_weight,
|
||||
module +
|
||||
"ctc weight for rescore. final_score = "
|
||||
"rescoring_weight * rescoring_score + ctc_weight * "
|
||||
"ctc_score.");
|
||||
opts->Register("rescoring-weight",
|
||||
&rescoring_weight,
|
||||
module +
|
||||
"attention score weight for rescore. final_score = "
|
||||
"rescoring_weight * rescoring_score + ctc_weight * "
|
||||
"ctc_score.");
|
||||
opts->Register("reverse-weight",
|
||||
&reverse_weight,
|
||||
module +
|
||||
"reverse decoder weight. rescoring_score = "
|
||||
"left_to_right_score * (1 - reverse_weight) + "
|
||||
"right_to_left_score * reverse_weight.");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Everything a U2Recognizer needs at construction time: front-end,
// model, and decoder configuration plus the unit vocabulary.
struct U2RecognizerResource {
    // Front-end feature extraction configuration.
    FeaturePipelineOptions feature_pipeline_opts{};
    // Acoustic (e2e) model configuration.
    ModelOptions model_opts{};
    // First/second pass decoding configuration.
    DecodeOptions decoder_opts{};
    // CTCBeamSearchOptions beam_search_opts;
    // Scale applied to acoustic scores (passed to Decodable).
    kaldi::BaseFloat acoustic_scale{1.0};
    // Path to the e2e model unit vocabulary (used by the CTC decoder).
    std::string vocab_path{};
};
|
||||
|
||||
|
||||
class U2Recognizer {
|
||||
public:
|
||||
explicit U2Recognizer(const U2RecognizerResource& resouce);
|
||||
void Reset();
|
||||
void ResetContinuousDecoding();
|
||||
|
||||
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
|
||||
void Decode();
|
||||
void Rescoring();
|
||||
|
||||
|
||||
std::string GetFinalResult();
|
||||
std::string GetPartialResult();
|
||||
|
||||
void SetFinished();
|
||||
bool IsFinished() { return input_finished_; }
|
||||
|
||||
bool DecodedSomething() const {
|
||||
return !result_.empty() && !result_[0].sentence.empty();
|
||||
}
|
||||
|
||||
|
||||
int FrameShiftInMs() const {
|
||||
// one decoder frame length in ms
|
||||
return decodable_->Nnet()->SubsamplingRate() *
|
||||
feature_pipeline_->FrameShift();
|
||||
}
|
||||
|
||||
|
||||
const std::vector<DecodeResult>& Result() const { return result_; }
|
||||
|
||||
private:
|
||||
void AttentionRescoring();
|
||||
void UpdateResult(bool finish = false);
|
||||
|
||||
private:
|
||||
U2RecognizerResource opts_;
|
||||
|
||||
// std::shared_ptr<U2RecognizerResource> resource_;
|
||||
// U2RecognizerResource resource_;
|
||||
std::shared_ptr<FeaturePipeline> feature_pipeline_;
|
||||
std::shared_ptr<Decodable> decodable_;
|
||||
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
|
||||
|
||||
// e2e unit symbol table
|
||||
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
|
||||
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
|
||||
|
||||
std::vector<DecodeResult> result_;
|
||||
|
||||
// global decoded frame offset
|
||||
int global_frame_offset_;
|
||||
// cur decoded frame num
|
||||
int num_frames_;
|
||||
// timestamp gap between words in a sentence
|
||||
const int time_stamp_gap_ = 100;
|
||||
|
||||
bool input_finished_;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,137 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decoder/u2_recognizer.h"
|
||||
#include "decoder/param.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||
|
||||
|
||||
// Assembles a U2RecognizerResource from command-line flags plus the
// fixed decoder settings this test binary uses. Writes straight into
// the resource's members instead of building temporaries and copying.
ppspeech::U2RecognizerResource InitOpts() {
    ppspeech::U2RecognizerResource resource;
    resource.acoustic_scale = FLAGS_acoustic_scale;
    resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
    resource.model_opts.model_path = FLAGS_model_path;

    ppspeech::DecodeOptions& decoder_opts = resource.decoder_opts;
    decoder_opts.chunk_size = 16;
    decoder_opts.num_left_chunks = -1;
    decoder_opts.ctc_weight = 0.5;
    decoder_opts.rescoring_weight = 1.0;
    decoder_opts.reverse_weight = 0.3;

    ppspeech::CTCBeamSearchOptions& search_opts =
        decoder_opts.ctc_prefix_search_opts;
    search_opts.blank = 0;
    search_opts.first_beam_size = 10;
    search_opts.second_beam_size = 10;

    return resource;
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
double tot_wav_duration = 0.0;
|
||||
|
||||
ppspeech::U2RecognizerResource resource = InitOpts();
|
||||
ppspeech::U2Recognizer recognizer(resource);
|
||||
|
||||
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||
FLAGS_wav_rspecifier);
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
|
||||
int sample_rate = FLAGS_sample_rate;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
LOG(INFO) << "sr: " << sample_rate;
|
||||
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||
|
||||
kaldi::Timer timer;
|
||||
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
std::string utt = wav_reader.Key();
|
||||
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||
LOG(INFO) << "utt: " << utt;
|
||||
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||
tot_wav_duration += wave_data.Duration();
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk(i) = waveform(sample_offset + i);
|
||||
}
|
||||
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
|
||||
|
||||
recognizer.Accept(wav_chunk);
|
||||
if (cur_chunk_size < chunk_sample_size) {
|
||||
recognizer.SetFinished();
|
||||
}
|
||||
recognizer.Decode();
|
||||
LOG(INFO) << "Pratial result: " << recognizer.GetPartialResult();
|
||||
|
||||
// no overlap
|
||||
sample_offset += cur_chunk_size;
|
||||
}
|
||||
// second pass decoding
|
||||
recognizer.Rescoring();
|
||||
|
||||
std::string result = recognizer.GetFinalResult();
|
||||
|
||||
recognizer.Reset();
|
||||
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
LOG(INFO) << " the result of " << utt << " is empty";
|
||||
continue;
|
||||
}
|
||||
|
||||
LOG(INFO) << " the result of " << utt << " is " << result;
|
||||
|
||||
result_writer.Write(utt, result);
|
||||
|
||||
++num_done;
|
||||
}
|
||||
|
||||
double elapsed = timer.Elapsed();
|
||||
|
||||
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
||||
LOG(INFO) << "cost:" << elapsed << " sec";
|
||||
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
||||
LOG(INFO) << "the RTF is: " << elapsed / tot_wav_duration;
|
||||
}
|
Loading…
Reference in new issue