add u2 recg

pull/2524/head
Hui Zhang 3 years ago
parent 7dc9cba3be
commit 86eb718908

@@ -1,5 +1,5 @@
#!/bin/bash
set -x
set +x
set -e
. path.sh

@@ -9,6 +9,7 @@ add_library(decoder STATIC
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
u2_recognizer.cc
)
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
@@ -28,10 +29,16 @@ endforeach()
# u2
set(bin_name ctc_prefix_beam_search_decoder_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
set(TEST_BINS
u2_recognizer_main
ctc_prefix_beam_search_decoder_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach()

@@ -1,3 +1,4 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,10 +13,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h"
#pragma once
#include "base/common.h"
struct DecoderResult {
BaseFloat acoustic_score;
std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp;
std::vector<std::pair<int32, int32>> time_stamp;
};
namespace ppspeech {
struct WordPiece {
std::string word;
int start = -1;
int end = -1;
WordPiece(std::string word, int start, int end)
: word(std::move(word)), start(start), end(end) {}
};
struct DecodeResult {
float score = -kBaseFloatMax;
std::string sentence;
std::vector<WordPiece> word_pieces;
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
return a.score > b.score;
}
};
} // namespace ppspeech
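
CompareFunc orders hypotheses best-first; the recognizer sorts its n-best list with it (see U2Recognizer::AttentionRescoring further down). A minimal usage sketch:

    #include <algorithm>
    #include <vector>
    std::vector<ppspeech::DecodeResult> nbest;  // filled by the decoder
    std::sort(nbest.begin(), nbest.end(), ppspeech::DecodeResult::CompareFunc);
    // nbest[0] now holds the highest-scoring hypothesis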

@@ -76,68 +76,4 @@ struct CTCBeamSearchOptions {
}
};
// used by u2 model
struct CTCBeamSearchDecoderOptions {
// chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
// one chunk are 67=16*4 + 3, stride is 64=16*4
int chunk_size;
int num_left_chunks;
// final_score = rescoring_weight * rescoring_score + ctc_weight *
// ctc_score;
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
// right_to_left_score * reverse_weight
// Please note the concept of ctc_score differs between the two search
// methods. For CtcPrefixBeamSearch it is a sum(prefix) score plus context
// score; for CtcWfstBeamSearch it is a max(viterbi) path score plus context
// score. So ctc_weight should be set carefully according to the search
// method.
float ctc_weight;
float rescoring_weight;
float reverse_weight;
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts;
CTCBeamSearchDecoderOptions()
: chunk_size(16),
num_left_chunks(-1),
ctc_weight(0.5),
rescoring_weight(1.0),
reverse_weight(0.0) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "DecoderConfig: ";
opts->Register(
"chunk-size",
&chunk_size,
module + "the frame number of one chunk after subsampling.");
opts->Register("num-left-chunks",
&num_left_chunks,
module + "the left history chunks number.");
opts->Register("ctc-weight",
&ctc_weight,
module +
"ctc weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("rescoring-weight",
&rescoring_weight,
module +
"attention score weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("reverse-weight",
&reverse_weight,
module +
"reverse decoder weight. rescoring_score = "
"left_to_right_score * (1 - reverse_weight) + "
"right_to_left_score * reverse_weight.");
}
};
} // namespace ppspeech

@@ -30,8 +30,14 @@ using paddle::platform::TracerEventType;
namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts)
CTCPrefixBeamSearch::CTCPrefixBeamSearch(
const std::string vocab_path,
const CTCBeamSearchOptions& opts)
: opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(fst::SymbolTable::ReadText(vocab_path));
CHECK(unit_table_ != nullptr);
Reset();
}
@@ -322,7 +328,11 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
CHECK(n_hyps > 0);
CHECK(index < n_hyps);
std::vector<int> one = Outputs()[index];
return std::string(absl::StrJoin(one, kSpaceSymbol));
std::string sentence;
for (int i = 0; i < one.size(); i++) {
sentence += unit_table_->Find(one[i]);
}
return sentence;
}
std::string CTCPrefixBeamSearch::GetBestPath() {

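GetBestPath(index) now maps unit ids through the OpenFst symbol table loaded in the constructor, instead of joining raw ids with a space symbol. A minimal sketch of the lookup, with an illustrative path:

    #include <memory>
    #include <string>
    #include <vector>
    #include "fst/symbol-table.h"
    std::shared_ptr<fst::SymbolTable> table(
        fst::SymbolTable::ReadText("unit.txt"));  // path illustrative
    std::vector<int> ids = {/* token ids, e.g. from Outputs()[index] */};
    std::string sentence;
    for (int id : ids) sentence += table->Find(id);  // unit id -> unit string
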
@@ -15,17 +15,21 @@
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_result.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "fst/symbol-table.h"
namespace ppspeech {
class ContextGraph;
class CTCPrefixBeamSearch : public DecoderInterface {
public:
explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
explicit CTCPrefixBeamSearch(const std::string vocab_path,
const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {}
SearchType Type() const { return SearchType::kPrefixBeamSearch; }
void InitDecoder() override;
void Reset() override;
@@ -38,10 +42,9 @@ class CTCPrefixBeamSearch : public DecoderInterface {
void FinalizeSearch();
protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath(int n) override;
const std::shared_ptr<fst::SymbolTable> VocabTable() const {
return unit_table_;
}
const std::vector<std::vector<int>>& Inputs() const { return hypotheses_; }
const std::vector<std::vector<int>>& Outputs() const { return outputs_; }
@@ -52,6 +55,11 @@ class CTCPrefixBeamSearch : public DecoderInterface {
const std::vector<std::vector<int>>& Times() const { return times_; }
protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath(int n) override;
private:
std::string GetBestPath(int index);
@@ -66,6 +74,7 @@ class CTCPrefixBeamSearch : public DecoderInterface {
private:
CTCBeamSearchOptions opts_;
std::shared_ptr<fst::SymbolTable> unit_table_;
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
cur_hyps_;
@@ -86,28 +95,4 @@ class CTCPrefixBeamSearch : public DecoderInterface {
};
class CTCPrefixBeamSearchDecoder : public CTCPrefixBeamSearch {
public:
explicit CTCPrefixBeamSearchDecoder(const CTCBeamSearchDecoderOptions& opts)
: CTCPrefixBeamSearch(opts.ctc_prefix_search_opts), opts_(opts) {}
~CTCPrefixBeamSearchDecoder() {}
private:
CTCBeamSearchDecoderOptions opts_;
// cache feature
bool start_ = false; // false, this is first frame.
// for continues decoding
int num_frames_ = 0;
int global_frame_offset_ = 0;
const int time_stamp_gap_ =
100; // timestamp gap between words in a sentence
// std::unique_ptr<CtcEndpoint> ctc_endpointer_;
int num_frames_in_current_chunk_ = 0;
std::vector<DecodeResult> result_;
};
} // namespace ppspeech
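
With the CTCPrefixBeamSearchDecoder wrapper and its options struct gone, callers build the search directly from the unit-table path and the plain CTCBeamSearchOptions, as ctc_prefix_beam_search_decoder_main.cc does below. A minimal sketch (path illustrative):

    ppspeech::CTCBeamSearchOptions opts;
    opts.blank = 0;             // CTC blank id
    opts.first_beam_size = 10;
    opts.second_beam_size = 10;
    ppspeech::CTCPrefixBeamSearch decoder("unit.txt", opts);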

@@ -55,14 +55,12 @@ int main(int argc, char* argv[]) {
CHECK(FLAGS_vocab_path != "");
CHECK(FLAGS_model_path != "");
LOG(INFO) << "model path: " << FLAGS_model_path;
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
fst::SymbolTable* unit_table = fst::SymbolTable::ReadText(FLAGS_vocab_path);
// nnet
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
@@ -75,16 +73,11 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data));
// decoder
ppspeech::CTCBeamSearchDecoderOptions opts;
opts.chunk_size = 16;
opts.num_left_chunks = -1;
opts.ctc_weight = 0.5;
opts.rescoring_weight = 1.0;
opts.reverse_weight = 0.3;
opts.ctc_prefix_search_opts.blank = 0;
opts.ctc_prefix_search_opts.first_beam_size = 10;
opts.ctc_prefix_search_opts.second_beam_size = 10;
ppspeech::CTCPrefixBeamSearchDecoder decoder(opts);
ppspeech::CTCBeamSearchOptions opts;
opts.blank = 0;
opts.first_beam_size = 10;
opts.second_beam_size = 10;
ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts);
int32 chunk_size = FLAGS_receptive_field_length +
@@ -150,17 +143,14 @@ int main(int argc, char* argv[]) {
// forward nnet
decoder.AdvanceDecode(decodable);
LOG(INFO) << "Partial result: " << decoder.GetPartialResult();
}
decoder.FinalizeSearch();
// get 1-best result
std::string result_ints = decoder.GetFinalBestPath();
std::vector<std::string> tokenids = absl::StrSplit(result_ints, ppspeech::kSpaceSymbol);
std::string result;
for (int i = 0; i < tokenids.size(); i++){
result += unit_table->Find(std::stoi(tokenids[i]));
}
std::string result = decoder.GetFinalBestPath();
// after process one utt, then reset state.
decodable->Reset();

@@ -1,41 +0,0 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
namespace ppspeech {
struct WordPiece {
std::string word;
int start = -1;
int end = -1;
WordPiece(std::string word, int start, int end)
: word(std::move(word)), start(start), end(end) {}
};
struct DecodeResult {
float score = -kBaseFloatMax;
std::string sentence;
std::vector<WordPiece> word_pieces;
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
return a.score > b.score;
}
};
} // namespace ppspeech

@@ -20,6 +20,10 @@
namespace ppspeech {
enum SearchType {
kPrefixBeamSearch = 0,
kWfstBeamSearch = 1,
};
class DecoderInterface {
public:
virtual ~DecoderInterface() {}
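
The new SearchType enum lets callers branch on the active search; U2Recognizer::UpdateResult (further down) uses it when assembling a sentence, since WFST output is word-level while prefix beam search emits model units. A minimal sketch of that branch (helper name hypothetical):

    #include <string>
    std::string AppendToken(const std::string& sentence, const std::string& word,
                            ppspeech::SearchType type) {
        // word-level WFST output is space-delimited; unit output is concatenated
        return type == ppspeech::kWfstBeamSearch ? sentence + " " + word
                                                 : sentence + word;
    }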

@@ -19,12 +19,15 @@
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
// feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
// feature, or fbank");
DEFINE_int32(num_bins, 161, "num bins of mel");
DEFINE_string(cmvn_file, "", "read cmvn");
// feature sliding window
DEFINE_int32(receptive_field_length,
7,
@@ -33,6 +36,8 @@ DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
@@ -89,34 +94,4 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = InitFeaturePipelineOptions();
resource.model_opts = InitModelOptions();
resource.tlg_opts = InitDecoderOptions();
return resource;
}
} // namespace ppspeech

@@ -14,6 +14,7 @@
#include "decoder/recognizer.h"
namespace ppspeech {
using kaldi::Vector;
@@ -23,14 +24,19 @@ using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
Recognizer::Recognizer(const RecognizerResource& resource) {
// resource_ = resource;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
BaseFloat ac_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
decoder_.reset(new TLGDecoder(resource.tlg_opts));
input_finished_ = false;
}

@@ -25,16 +25,11 @@
namespace ppspeech {
struct RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts;
ModelOptions model_opts;
TLGDecoderOptions tlg_opts;
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
TLGDecoderOptions tlg_opts{};
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale;
RecognizerResource()
: acoustic_scale(1.0),
feature_pipeline_opts(),
model_opts(),
tlg_opts() {}
kaldi::BaseFloat acoustic_scale{1.0};
};
class Recognizer {

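The hand-written RecognizerResource constructor is replaced by default member initializers, so value construction yields the same defaults with less code. A quick check, assuming the header is included:

    #include <cassert>
    ppspeech::RecognizerResource resource;   // defaults come from the initializers
    assert(resource.acoustic_scale == 1.0);  // previously set by the constructor
    resource.acoustic_scale = 0.5;           // still overridable per run
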
@@ -22,6 +22,33 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
model_opts.subsample_rate = FLAGS_downsampling_rate;
resource.model_opts = model_opts;
ppspeech::TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
resource.tlg_opts = decoder_opts;
return resource;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
@@ -29,7 +56,7 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::RecognizerResource resource = InitRecognizerResoure();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(

@@ -0,0 +1,209 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource) {
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<NnetInterface> nnet(new U2Nnet(resource.model_opts));
BaseFloat am_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_;
input_finished_ = false;
}
void U2Recognizer::Reset() {
global_frame_offset_ = 0;
num_frames_ = 0;
result_.clear();
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
void U2Recognizer::ResetContinuousDecoding() {
global_frame_offset_ = num_frames_;
num_frames_ = 0;
result_.clear();
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
feature_pipeline_->Accept(waves);
}
void U2Recognizer::Decode() {
decoder_->AdvanceDecode(decodable_);
}
void U2Recognizer::Rescoring() {
// Do attention Rescoring
kaldi::Timer timer;
AttentionRescoring();
VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec.";
}
void U2Recognizer::UpdateResult(bool finish) {
const auto& hypotheses = decoder_->Outputs();
const auto& inputs = decoder_->Inputs();
const auto& likelihood = decoder_->Likelihood();
const auto& times = decoder_->Times();
result_.clear();
CHECK_EQ(hypotheses.size(), likelihood.size());
for (size_t i = 0; i < hypotheses.size(); i++) {
const std::vector<int>& hypothesis = hypotheses[i];
DecodeResult path;
path.score = likelihood[i];
for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if (decoder_->Type() == kWfstBeamSearch) {
path.sentence += (" " + word);
} else {
path.sentence += (word);
}
}
// TimeStamp is only supported in final result
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
// various FST operations when building the decoding graph. So here we use
// time stamp of the input (e2e model unit), which is more accurate, and it
// requires the symbol table of the e2e model used in training.
if (unit_table_ != nullptr && finish) {
int offset = global_frame_offset_ * FrameShiftInMs();
const std::vector<int>& input = inputs[i];
const std::vector<int> time_stamp = times[i];
CHECK_EQ(input.size(), time_stamp.size());
for (size_t j = 0; j < input.size(); j++) {
std::string word = unit_table_->Find(input[j]);
int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_
: 0;
if (j > 0) {
start = (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() <
time_stamp_gap_
? (time_stamp[j - 1] + time_stamp[j]) / 2 *
FrameShiftInMs()
: start;
}
int end = time_stamp[j] * FrameShiftInMs();
if (j < input.size() - 1) {
end = (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() <
time_stamp_gap_
? (time_stamp[j + 1] + time_stamp[j]) / 2 *
FrameShiftInMs()
: end;
}
WordPiece word_piece(word, offset + start, offset + end);
path.word_pieces.emplace_back(word_piece);
}
}
// if (post_processor_ != nullptr) {
// path.sentence = post_processor_->Process(path.sentence, finish);
// }
result_.emplace_back(path);
}
if (DecodedSomething()) {
VLOG(1) << "Partial CTC result " << result_[0].sentence;
}
}
void U2Recognizer::AttentionRescoring() {
decoder_->FinalizeSearch();
UpdateResult(true);
// No need to do rescoring
if (0.0 == opts_.decoder_opts.rescoring_weight) {
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
return;
}
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const auto& hypotheses = decoder_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
kaldi::Timer timer;
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score;
result_[i].score = opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
VLOG(1) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::string U2Recognizer::GetFinalResult() {
return result_[0].sentence;
}
std::string U2Recognizer::GetPartialResult() {
return result_[0].sentence;
}
void U2Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
} // namespace ppspeech
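
The start/end clamping in UpdateResult is easiest to see with numbers. A hedged standalone rendering of the start-time rule (helper name hypothetical):

    #include <algorithm>
    // A word nominally starts time_stamp_gap_ ms before its emission frame,
    // clamped at 0; if the previous word emitted closer than the gap, the
    // midpoint between the two frames is used instead.
    int WordStartMs(int cur_frame, int prev_frame, int shift_ms, int gap_ms) {
        int start = std::max(cur_frame * shift_ms - gap_ms, 0);
        if (prev_frame >= 0 && (cur_frame - prev_frame) * shift_ms < gap_ms)
            start = (prev_frame + cur_frame) / 2 * shift_ms;
        return start;
    }
    // e.g. shift_ms = 40 (subsample 4 x 10 ms), gap_ms = 100, frames {10, 12}:
    // the second word starts at (10 + 12) / 2 * 40 = 440 ms, not 12 * 40 - 100 = 380 ms.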

@@ -0,0 +1,164 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/decoder_itf.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
namespace ppspeech {
struct DecodeOptions {
// chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
// one chunk are 67=16*4 + 3, stride is 64=16*4
int chunk_size;
int num_left_chunks;
// final_score = rescoring_weight * rescoring_score + ctc_weight *
// ctc_score;
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
// right_to_left_score * reverse_weight
// Please note the concept of ctc_score differs between the two search
// methods. For CtcPrefixBeamSearch it is a sum(prefix) score plus context
// score; for CtcWfstBeamSearch it is a max(viterbi) path score plus context
// score. So ctc_weight should be set carefully according to the search
// method.
float ctc_weight;
float rescoring_weight;
float reverse_weight;
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts;
DecodeOptions()
: chunk_size(16),
num_left_chunks(-1),
ctc_weight(0.5),
rescoring_weight(1.0),
reverse_weight(0.0) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "DecoderConfig: ";
opts->Register(
"chunk-size",
&chunk_size,
module + "the frame number of one chunk after subsampling.");
opts->Register("num-left-chunks",
&num_left_chunks,
module + "the left history chunks number.");
opts->Register("ctc-weight",
&ctc_weight,
module +
"ctc weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("rescoring-weight",
&rescoring_weight,
module +
"attention score weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("reverse-weight",
&reverse_weight,
module +
"reverse decoder weight. rescoring_score = "
"left_to_right_score * (1 - reverse_weight) + "
"right_to_left_score * reverse_weight.");
}
};
struct U2RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
DecodeOptions decoder_opts{};
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{};
};
class U2Recognizer {
public:
explicit U2Recognizer(const U2RecognizerResource& resource);
void Reset();
void ResetContinuousDecoding();
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
void Decode();
void Rescoring();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
bool IsFinished() { return input_finished_; }
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
}
int FrameShiftInMs() const {
// one decoder frame length in ms
return decodable_->Nnet()->SubsamplingRate() *
feature_pipeline_->FrameShift();
}
const std::vector<DecodeResult>& Result() const { return result_; }
private:
void AttentionRescoring();
void UpdateResult(bool finish = false);
private:
U2RecognizerResource opts_;
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;
// global decoded frame offset
int global_frame_offset_;
// cur decoded frame num
int num_frames_;
// timestamp gap between words in a sentence
const int time_stamp_gap_ = 100;
bool input_finished_;
};
} // namespace ppspeech
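
The DecodeOptions comments above fix the two-pass scoring scheme; a hedged sketch of the arithmetic with the weights passed explicitly (function name hypothetical):

    float FinalScore(float ctc_score, float l2r_score, float r2l_score,
                     float ctc_weight, float rescoring_weight,
                     float reverse_weight) {
        // blend the two attention decoder directions
        float rescoring_score =
            l2r_score * (1.0f - reverse_weight) + r2l_score * reverse_weight;
        // mix attention rescoring with the first-pass CTC score
        return rescoring_weight * rescoring_score + ctc_weight * ctc_score;
    }

With the defaults (ctc_weight 0.5, rescoring_weight 1.0, reverse_weight 0.0) this reduces to rescoring_score + 0.5 * ctc_score, matching U2Recognizer::AttentionRescoring above.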

@@ -0,0 +1,137 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/u2_recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::U2RecognizerResource InitOpts() {
ppspeech::U2RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
resource.model_opts = model_opts;
ppspeech::DecodeOptions decoder_opts;
decoder_opts.chunk_size = 16;
decoder_opts.num_left_chunks = -1;
decoder_opts.ctc_weight = 0.5;
decoder_opts.rescoring_weight = 1.0;
decoder_opts.reverse_weight = 0.3;
decoder_opts.ctc_prefix_search_opts.blank = 0;
decoder_opts.ctc_prefix_search_opts.first_beam_size = 10;
decoder_opts.ctc_prefix_search_opts.second_beam_size = 10;
resource.decoder_opts = decoder_opts;
return resource;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
ppspeech::U2RecognizerResource resource = InitOpts();
ppspeech::U2Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
kaldi::Timer timer;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
tot_wav_duration += wave_data.Duration();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
LOG(INFO) << "Pratial result: " << recognizer.GetPartialResult();
// no overlap
sample_offset += cur_chunk_size;
}
// second pass decoding
recognizer.Rescoring();
std::string result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
// the TokenWriter cannot write an empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
continue;
}
LOG(INFO) << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "cost:" << elapsed << " sec";
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "the RTF is: " << elapsed / tot_wav_duration;
}

@@ -18,7 +18,7 @@ namespace ppspeech {
using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));

@@ -26,7 +26,6 @@
#include "frontend/audio/normalizer.h"
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool to_float32; // true, only for linear feature
@@ -60,7 +59,21 @@ class FeaturePipeline : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
const FeaturePipelineOptions& Config() { return opts_; }
const BaseFloat FrameShift() const {
return opts_.fbank_opts.frame_opts.frame_shift_ms;
}
const BaseFloat FrameLength() const {
return opts_.fbank_opts.frame_opts.frame_length_ms;
}
const BaseFloat SampleRate() const {
return opts_.fbank_opts.frame_opts.samp_freq;
}
private:
FeaturePipelineOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_;
};
}
} // namespace ppspeech
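
These accessors feed U2Recognizer::FrameShiftInMs() above: one decoder frame spans the nnet subsampling rate times the feature frame shift. A worked sketch (values illustrative):

    int FrameShiftInMs(int subsampling_rate, float frame_shift_ms) {
        // e.g. subsample 4 * 10 ms fbank shift = 40 ms per decoder frame
        return static_cast<int>(subsampling_rate * frame_shift_ms);
    }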

@@ -48,6 +48,7 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
}
PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
subsampling_rate_ = opts.subsample_rate;
paddle_infer::Config config;
config.SetModel(opts.model_path, opts.param_path);
if (opts.use_gpu) {

@@ -67,6 +67,7 @@ class PaddleNnet : public NnetInterface {
bool IsLogProb() override { return false; }
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
@@ -85,6 +86,7 @@ class PaddleNnet : public NnetInterface {
std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
std::map<std::string, int> cache_names_idx_;
std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;
ModelOptions opts_;
public:

@@ -35,6 +35,7 @@ struct ModelOptions {
std::string cache_shape;
bool enable_fc_padding;
bool enable_profile;
int subsample_rate;
ModelOptions()
: model_path(""),
param_path(""),
@@ -46,7 +47,8 @@ struct ModelOptions {
cache_shape(""),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {}
enable_profile(false),
subsample_rate(0) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
@@ -102,9 +104,14 @@ class NnetInterface {
// true, nnet output is logprob; otherwise is prob,
virtual bool IsLogProb() = 0;
int SubsamplingRate() const { return subsampling_rate_; }
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
protected:
int subsampling_rate_{1};
};
} // namespace ppspeech

@@ -30,7 +30,7 @@ class U2NnetBase : public NnetInterface {
public:
virtual int context() const { return right_context_ + 1; }
virtual int right_context() const { return right_context_; }
virtual int subsampling_rate() const { return subsampling_rate_; }
virtual int eos() const { return eos_; }
virtual int sos() const { return sos_; }
virtual int is_bidecoder() const { return is_bidecoder_; }
@@ -64,7 +64,6 @@ class U2NnetBase : public NnetInterface {
protected:
// model specification
int right_context_{0};
int subsampling_rate_{1};
int sos_{0};
int eos_{0};

@@ -1,5 +1,3 @@
# project(websocket)
add_library(websocket STATIC
websocket_server.cc
websocket_client.cc

@@ -17,11 +17,38 @@
DEFINE_int32(port, 8082, "websocket listening port");
ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
model_opts.subsample_rate = FLAGS_downsampling_rate;
resource.model_opts = model_opts;
ppspeech::TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
resource.tlg_opts = decoder_opts;
return resource;
}
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::RecognizerResource resource = InitRecognizerResoure();
ppspeech::WebSocketServer server(FLAGS_port, resource);
LOG(INFO) << "Listening at port " << FLAGS_port;
