parent
7dc9cba3be
commit
86eb718908
@ -1,41 +0,0 @@
|
|||||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
|
||||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "base/common.h"
|
|
||||||
|
|
||||||
namespace ppspeech {
|
|
||||||
|
|
||||||
struct WordPiece {
|
|
||||||
std::string word;
|
|
||||||
int start = -1;
|
|
||||||
int end = -1;
|
|
||||||
|
|
||||||
WordPiece(std::string word, int start, int end)
|
|
||||||
: word(std::move(word)), start(start), end(end) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DecodeResult {
|
|
||||||
float score = -kBaseFloatMax;
|
|
||||||
std::string sentence;
|
|
||||||
std::vector<WordPiece> word_pieces;
|
|
||||||
|
|
||||||
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
|
|
||||||
return a.score > b.score;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace ppspeech
|
|
@ -0,0 +1,209 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "decoder/u2_recognizer.h"
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
using kaldi::Vector;
|
||||||
|
using kaldi::VectorBase;
|
||||||
|
using kaldi::BaseFloat;
|
||||||
|
using std::vector;
|
||||||
|
using kaldi::SubVector;
|
||||||
|
using std::unique_ptr;
|
||||||
|
|
||||||
|
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource) {
|
||||||
|
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
|
||||||
|
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
|
||||||
|
|
||||||
|
std::shared_ptr<NnetInterface> nnet(new U2Nnet(resource.model_opts));
|
||||||
|
|
||||||
|
BaseFloat am_scale = resource.acoustic_scale;
|
||||||
|
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
|
||||||
|
|
||||||
|
decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
|
||||||
|
|
||||||
|
unit_table_ = decoder_->VocabTable();
|
||||||
|
symbol_table_ = unit_table_;
|
||||||
|
|
||||||
|
input_finished_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2Recognizer::Reset() {
|
||||||
|
global_frame_offset_ = 0;
|
||||||
|
num_frames_ = 0;
|
||||||
|
result_.clear();
|
||||||
|
|
||||||
|
feature_pipeline_->Reset();
|
||||||
|
decodable_->Reset();
|
||||||
|
decoder_->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2Recognizer::ResetContinuousDecoding() {
|
||||||
|
global_frame_offset_ = num_frames_;
|
||||||
|
num_frames_ = 0;
|
||||||
|
result_.clear();
|
||||||
|
|
||||||
|
feature_pipeline_->Reset();
|
||||||
|
decodable_->Reset();
|
||||||
|
decoder_->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
|
||||||
|
feature_pipeline_->Accept(waves);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void U2Recognizer::Decode() {
|
||||||
|
decoder_->AdvanceDecode(decodable_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2Recognizer::Rescoring() {
|
||||||
|
// Do attention Rescoring
|
||||||
|
kaldi::Timer timer;
|
||||||
|
AttentionRescoring();
|
||||||
|
VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec.";
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2Recognizer::UpdateResult(bool finish) {
|
||||||
|
const auto& hypotheses = decoder_->Outputs();
|
||||||
|
const auto& inputs = decoder_->Inputs();
|
||||||
|
const auto& likelihood = decoder_->Likelihood();
|
||||||
|
const auto& times = decoder_->Times();
|
||||||
|
result_.clear();
|
||||||
|
|
||||||
|
CHECK_EQ(hypotheses.size(), likelihood.size());
|
||||||
|
for (size_t i = 0; i < hypotheses.size(); i++) {
|
||||||
|
const std::vector<int>& hypothesis = hypotheses[i];
|
||||||
|
|
||||||
|
DecodeResult path;
|
||||||
|
path.score = likelihood[i];
|
||||||
|
for (size_t j = 0; j < hypothesis.size(); j++) {
|
||||||
|
std::string word = symbol_table_->Find(hypothesis[j]);
|
||||||
|
// A detailed explanation of this if-else branch can be found in
|
||||||
|
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
|
||||||
|
if (decoder_->Type() == kWfstBeamSearch) {
|
||||||
|
path.sentence += (" " + word);
|
||||||
|
} else {
|
||||||
|
path.sentence += (word);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeStamp is only supported in final result
|
||||||
|
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
|
||||||
|
// various FST operations when building the decoding graph. So here we use
|
||||||
|
// time stamp of the input(e2e model unit), which is more accurate, and it
|
||||||
|
// requires the symbol table of the e2e model used in training.
|
||||||
|
if (unit_table_ != nullptr && finish) {
|
||||||
|
int offset = global_frame_offset_ * FrameShiftInMs();
|
||||||
|
|
||||||
|
const std::vector<int>& input = inputs[i];
|
||||||
|
const std::vector<int> time_stamp = times[i];
|
||||||
|
CHECK_EQ(input.size(), time_stamp.size());
|
||||||
|
|
||||||
|
for (size_t j = 0; j < input.size(); j++) {
|
||||||
|
std::string word = unit_table_->Find(input[j]);
|
||||||
|
|
||||||
|
int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
|
||||||
|
? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_
|
||||||
|
: 0;
|
||||||
|
if (j > 0) {
|
||||||
|
start = (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() <
|
||||||
|
time_stamp_gap_
|
||||||
|
? (time_stamp[j - 1] + time_stamp[j]) / 2 *
|
||||||
|
FrameShiftInMs()
|
||||||
|
: start;
|
||||||
|
}
|
||||||
|
|
||||||
|
int end = time_stamp[j] * FrameShiftInMs();
|
||||||
|
if (j < input.size() - 1) {
|
||||||
|
end = (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() <
|
||||||
|
time_stamp_gap_
|
||||||
|
? (time_stamp[j + 1] + time_stamp[j]) / 2 *
|
||||||
|
FrameShiftInMs()
|
||||||
|
: end;
|
||||||
|
}
|
||||||
|
|
||||||
|
WordPiece word_piece(word, offset + start, offset + end);
|
||||||
|
path.word_pieces.emplace_back(word_piece);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if (post_processor_ != nullptr) {
|
||||||
|
// path.sentence = post_processor_->Process(path.sentence, finish);
|
||||||
|
// }
|
||||||
|
|
||||||
|
result_.emplace_back(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DecodedSomething()) {
|
||||||
|
VLOG(1) << "Partial CTC result " << result_[0].sentence;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2Recognizer::AttentionRescoring() {
|
||||||
|
decoder_->FinalizeSearch();
|
||||||
|
UpdateResult(true);
|
||||||
|
|
||||||
|
// No need to do rescoring
|
||||||
|
if (0.0 == opts_.decoder_opts.rescoring_weight) {
|
||||||
|
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
|
||||||
|
|
||||||
|
// Inputs() returns N-best input ids, which is the basic unit for rescoring
|
||||||
|
// In CtcPrefixBeamSearch, inputs are the same to outputs
|
||||||
|
const auto& hypotheses = decoder_->Inputs();
|
||||||
|
int num_hyps = hypotheses.size();
|
||||||
|
if (num_hyps <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
kaldi::Timer timer;
|
||||||
|
std::vector<float> rescoring_score;
|
||||||
|
decodable_->AttentionRescoring(
|
||||||
|
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
|
||||||
|
VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";
|
||||||
|
|
||||||
|
// combine ctc score and rescoring score
|
||||||
|
for (size_t i = 0; i < num_hyps; i++) {
|
||||||
|
VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
|
||||||
|
<< " ctc_score: " << result_[i].score;
|
||||||
|
result_[i].score = opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
|
||||||
|
opts_.decoder_opts.ctc_weight * result_[i].score;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
|
||||||
|
VLOG(1) << "result: " << result_[0].sentence
|
||||||
|
<< " score: " << result_[0].score;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string U2Recognizer::GetFinalResult() {
|
||||||
|
return result_[0].sentence;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string U2Recognizer::GetPartialResult() {
|
||||||
|
return result_[0].sentence;
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2Recognizer::SetFinished() {
|
||||||
|
feature_pipeline_->SetFinished();
|
||||||
|
input_finished_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,164 @@
|
|||||||
|
|
||||||
|
|
||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "decoder/common.h"
|
||||||
|
#include "decoder/ctc_beam_search_opt.h"
|
||||||
|
#include "decoder/ctc_prefix_beam_search_decoder.h"
|
||||||
|
#include "decoder/decoder_itf.h"
|
||||||
|
#include "frontend/audio/feature_pipeline.h"
|
||||||
|
#include "nnet/decodable.h"
|
||||||
|
|
||||||
|
#include "fst/fstlib.h"
|
||||||
|
#include "fst/symbol-table.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
|
||||||
|
struct DecodeOptions {
|
||||||
|
// chunk_size is the frame number of one chunk after subsampling.
|
||||||
|
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
|
||||||
|
// one chunk are 67=16*4 + 3, stride is 64=16*4
|
||||||
|
int chunk_size;
|
||||||
|
int num_left_chunks;
|
||||||
|
|
||||||
|
// final_score = rescoring_weight * rescoring_score + ctc_weight *
|
||||||
|
// ctc_score;
|
||||||
|
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
|
||||||
|
// right_to_left_score * reverse_weight
|
||||||
|
// Please note the concept of ctc_scores
|
||||||
|
// in the following two search methods are different. For
|
||||||
|
// CtcPrefixBeamSerch,
|
||||||
|
// it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a
|
||||||
|
// max(viterbi) path score + context score So we should carefully set
|
||||||
|
// ctc_weight accroding to the search methods.
|
||||||
|
float ctc_weight;
|
||||||
|
float rescoring_weight;
|
||||||
|
float reverse_weight;
|
||||||
|
|
||||||
|
// CtcEndpointConfig ctc_endpoint_opts;
|
||||||
|
CTCBeamSearchOptions ctc_prefix_search_opts;
|
||||||
|
|
||||||
|
DecodeOptions()
|
||||||
|
: chunk_size(16),
|
||||||
|
num_left_chunks(-1),
|
||||||
|
ctc_weight(0.5),
|
||||||
|
rescoring_weight(1.0),
|
||||||
|
reverse_weight(0.0) {}
|
||||||
|
|
||||||
|
void Register(kaldi::OptionsItf* opts) {
|
||||||
|
std::string module = "DecoderConfig: ";
|
||||||
|
opts->Register(
|
||||||
|
"chunk-size",
|
||||||
|
&chunk_size,
|
||||||
|
module + "the frame number of one chunk after subsampling.");
|
||||||
|
opts->Register("num-left-chunks",
|
||||||
|
&num_left_chunks,
|
||||||
|
module + "the left history chunks number.");
|
||||||
|
opts->Register("ctc-weight",
|
||||||
|
&ctc_weight,
|
||||||
|
module +
|
||||||
|
"ctc weight for rescore. final_score = "
|
||||||
|
"rescoring_weight * rescoring_score + ctc_weight * "
|
||||||
|
"ctc_score.");
|
||||||
|
opts->Register("rescoring-weight",
|
||||||
|
&rescoring_weight,
|
||||||
|
module +
|
||||||
|
"attention score weight for rescore. final_score = "
|
||||||
|
"rescoring_weight * rescoring_score + ctc_weight * "
|
||||||
|
"ctc_score.");
|
||||||
|
opts->Register("reverse-weight",
|
||||||
|
&reverse_weight,
|
||||||
|
module +
|
||||||
|
"reverse decoder weight. rescoring_score = "
|
||||||
|
"left_to_right_score * (1 - reverse_weight) + "
|
||||||
|
"right_to_left_score * reverse_weight.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct U2RecognizerResource {
|
||||||
|
FeaturePipelineOptions feature_pipeline_opts{};
|
||||||
|
ModelOptions model_opts{};
|
||||||
|
DecodeOptions decoder_opts{};
|
||||||
|
// CTCBeamSearchOptions beam_search_opts;
|
||||||
|
kaldi::BaseFloat acoustic_scale{1.0};
|
||||||
|
std::string vocab_path{};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class U2Recognizer {
|
||||||
|
public:
|
||||||
|
explicit U2Recognizer(const U2RecognizerResource& resouce);
|
||||||
|
void Reset();
|
||||||
|
void ResetContinuousDecoding();
|
||||||
|
|
||||||
|
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
|
||||||
|
void Decode();
|
||||||
|
void Rescoring();
|
||||||
|
|
||||||
|
|
||||||
|
std::string GetFinalResult();
|
||||||
|
std::string GetPartialResult();
|
||||||
|
|
||||||
|
void SetFinished();
|
||||||
|
bool IsFinished() { return input_finished_; }
|
||||||
|
|
||||||
|
bool DecodedSomething() const {
|
||||||
|
return !result_.empty() && !result_[0].sentence.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int FrameShiftInMs() const {
|
||||||
|
// one decoder frame length in ms
|
||||||
|
return decodable_->Nnet()->SubsamplingRate() *
|
||||||
|
feature_pipeline_->FrameShift();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const std::vector<DecodeResult>& Result() const { return result_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void AttentionRescoring();
|
||||||
|
void UpdateResult(bool finish = false);
|
||||||
|
|
||||||
|
private:
|
||||||
|
U2RecognizerResource opts_;
|
||||||
|
|
||||||
|
// std::shared_ptr<U2RecognizerResource> resource_;
|
||||||
|
// U2RecognizerResource resource_;
|
||||||
|
std::shared_ptr<FeaturePipeline> feature_pipeline_;
|
||||||
|
std::shared_ptr<Decodable> decodable_;
|
||||||
|
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
|
||||||
|
|
||||||
|
// e2e unit symbol table
|
||||||
|
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
|
||||||
|
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
|
||||||
|
|
||||||
|
std::vector<DecodeResult> result_;
|
||||||
|
|
||||||
|
// global decoded frame offset
|
||||||
|
int global_frame_offset_;
|
||||||
|
// cur decoded frame num
|
||||||
|
int num_frames_;
|
||||||
|
// timestamp gap between words in a sentence
|
||||||
|
const int time_stamp_gap_ = 100;
|
||||||
|
|
||||||
|
bool input_finished_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,137 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "decoder/u2_recognizer.h"
|
||||||
|
#include "decoder/param.h"
|
||||||
|
#include "kaldi/feat/wave-reader.h"
|
||||||
|
#include "kaldi/util/table-types.h"
|
||||||
|
|
||||||
|
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||||
|
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||||
|
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||||
|
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||||
|
|
||||||
|
|
||||||
|
ppspeech::U2RecognizerResource InitOpts() {
|
||||||
|
ppspeech::U2RecognizerResource resource;
|
||||||
|
resource.acoustic_scale = FLAGS_acoustic_scale;
|
||||||
|
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
|
||||||
|
|
||||||
|
ppspeech::ModelOptions model_opts;
|
||||||
|
model_opts.model_path = FLAGS_model_path;
|
||||||
|
|
||||||
|
resource.model_opts = model_opts;
|
||||||
|
|
||||||
|
ppspeech::DecodeOptions decoder_opts;
|
||||||
|
decoder_opts.chunk_size=16;
|
||||||
|
decoder_opts.num_left_chunks = -1;
|
||||||
|
decoder_opts.ctc_weight = 0.5;
|
||||||
|
decoder_opts.rescoring_weight = 1.0;
|
||||||
|
decoder_opts.reverse_weight = 0.3;
|
||||||
|
decoder_opts.ctc_prefix_search_opts.blank = 0;
|
||||||
|
decoder_opts.ctc_prefix_search_opts.first_beam_size = 10;
|
||||||
|
decoder_opts.ctc_prefix_search_opts.second_beam_size = 10;
|
||||||
|
|
||||||
|
resource.decoder_opts = decoder_opts;
|
||||||
|
return resource;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::SetUsageMessage("Usage:");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
google::InstallFailureSignalHandler();
|
||||||
|
FLAGS_logtostderr = 1;
|
||||||
|
|
||||||
|
int32 num_done = 0, num_err = 0;
|
||||||
|
double tot_wav_duration = 0.0;
|
||||||
|
|
||||||
|
ppspeech::U2RecognizerResource resource = InitOpts();
|
||||||
|
ppspeech::U2Recognizer recognizer(resource);
|
||||||
|
|
||||||
|
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||||
|
FLAGS_wav_rspecifier);
|
||||||
|
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||||
|
|
||||||
|
int sample_rate = FLAGS_sample_rate;
|
||||||
|
float streaming_chunk = FLAGS_streaming_chunk;
|
||||||
|
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||||
|
LOG(INFO) << "sr: " << sample_rate;
|
||||||
|
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||||
|
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||||
|
|
||||||
|
kaldi::Timer timer;
|
||||||
|
|
||||||
|
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||||
|
std::string utt = wav_reader.Key();
|
||||||
|
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||||
|
LOG(INFO) << "utt: " << utt;
|
||||||
|
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||||
|
tot_wav_duration += wave_data.Duration();
|
||||||
|
|
||||||
|
int32 this_channel = 0;
|
||||||
|
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||||
|
this_channel);
|
||||||
|
int tot_samples = waveform.Dim();
|
||||||
|
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||||
|
|
||||||
|
int sample_offset = 0;
|
||||||
|
while (sample_offset < tot_samples) {
|
||||||
|
int cur_chunk_size =
|
||||||
|
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||||
|
|
||||||
|
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||||
|
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||||
|
wav_chunk(i) = waveform(sample_offset + i);
|
||||||
|
}
|
||||||
|
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
|
||||||
|
|
||||||
|
recognizer.Accept(wav_chunk);
|
||||||
|
if (cur_chunk_size < chunk_sample_size) {
|
||||||
|
recognizer.SetFinished();
|
||||||
|
}
|
||||||
|
recognizer.Decode();
|
||||||
|
LOG(INFO) << "Pratial result: " << recognizer.GetPartialResult();
|
||||||
|
|
||||||
|
// no overlap
|
||||||
|
sample_offset += cur_chunk_size;
|
||||||
|
}
|
||||||
|
// second pass decoding
|
||||||
|
recognizer.Rescoring();
|
||||||
|
|
||||||
|
std::string result = recognizer.GetFinalResult();
|
||||||
|
|
||||||
|
recognizer.Reset();
|
||||||
|
|
||||||
|
if (result.empty()) {
|
||||||
|
// the TokenWriter can not write empty string.
|
||||||
|
++num_err;
|
||||||
|
LOG(INFO) << " the result of " << utt << " is empty";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG(INFO) << " the result of " << utt << " is " << result;
|
||||||
|
|
||||||
|
result_writer.Write(utt, result);
|
||||||
|
|
||||||
|
++num_done;
|
||||||
|
}
|
||||||
|
|
||||||
|
double elapsed = timer.Elapsed();
|
||||||
|
|
||||||
|
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
||||||
|
LOG(INFO) << "cost:" << elapsed << " sec";
|
||||||
|
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
||||||
|
LOG(INFO) << "the RTF is: " << elapsed / tot_wav_duration;
|
||||||
|
}
|
Loading…
Reference in new issue