[engine]add recognizer api && clean params && make a shared decoder resource (#3165)

pull/3188/head
YangZhou 2 years ago committed by GitHub
parent 11ce08b260
commit b05ead51d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,51 +22,22 @@ namespace ppspeech {
struct CTCBeamSearchOptions {
// common
int blank;
// ds2
std::string dict_file;
std::string lm_path;
int beam_size;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int cutoff_top_n;
int num_proc_bsearch;
std::string word_symbol_table;
// u2
int first_beam_size;
int second_beam_size;
CTCBeamSearchOptions()
: blank(0),
dict_file("vocab.txt"),
lm_path(""),
beam_size(300),
alpha(1.9f),
beta(5.0),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(10),
word_symbol_table("vocab.txt"),
first_beam_size(10),
second_beam_size(10) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "Ds2BeamSearchConfig: ";
opts->Register("dict", &dict_file, module + "vocab file path.");
opts->Register(
"lm-path", &lm_path, module + "ngram language model path.");
opts->Register("alpha", &alpha, module + "alpha");
opts->Register("beta", &beta, module + "beta");
opts->Register("beam-size",
&beam_size,
module + "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
std::string module = "CTCBeamSearchOptions: ";
opts->Register("word_symbol_table", &word_symbol_table, module + "vocab file path.");
opts->Register("blank", &blank, "blank id, default is 0.");
module = "U2BeamSearchConfig: ";
opts->Register(
"first-beam-size", &first_beam_size, module + "first beam size.");
opts->Register("second-beam-size",

@ -30,11 +30,10 @@ using paddle::platform::TracerEventType;
namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts)
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(vocab_path));
fst::SymbolTable::ReadText(opts.word_symbol_table));
CHECK(unit_table_ != nullptr);
Reset();

@ -27,8 +27,7 @@ namespace ppspeech {
class ContextGraph;
class CTCPrefixBeamSearch : public DecoderBase {
public:
CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts);
CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {}
SearchType Type() const { return SearchType::kPrefixBeamSearch; }

@ -23,7 +23,7 @@
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(vocab_path, "", "vocab path");
DEFINE_string(word_symbol_table, "", "vocab path");
DEFINE_string(model_path, "", "paddle nnet model");
@ -52,10 +52,10 @@ int main(int argc, char* argv[]) {
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
CHECK_NE(FLAGS_vocab_path, "");
CHECK_NE(FLAGS_word_symbol_table, "");
CHECK_NE(FLAGS_model_path, "");
LOG(INFO) << "model path: " << FLAGS_model_path;
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
LOG(INFO) << "Reading vocab table " << FLAGS_word_symbol_table;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
@ -80,7 +80,8 @@ int main(int argc, char* argv[]) {
opts.blank = 0;
opts.first_beam_size = 10;
opts.second_beam_size = 10;
ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts);
opts.word_symbol_table = FLAGS_word_symbol_table;
ppspeech::CTCPrefixBeamSearch decoder(opts);
int32 chunk_size = FLAGS_receptive_field_length +

@ -13,12 +13,14 @@
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
fst_ = opts.fst_ptr;
CHECK(fst_ != nullptr);
CHECK(!opts.word_symbol_table.empty());
word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));

@ -18,6 +18,7 @@
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
#include "utils/file_utils.h"
DECLARE_string(word_symbol_table);
DECLARE_string(graph_path);
@ -33,9 +34,10 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_ptr;
int nbest;
TLGDecoderOptions() : word_symbol_table(""), fst_path(""), nbest(10) {}
TLGDecoderOptions() : word_symbol_table(""), fst_path(""), fst_ptr(nullptr), nbest(10) {}
static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts;
@ -44,6 +46,11 @@ struct TLGDecoderOptions {
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
if (!decoder_opts.fst_path.empty()) {
CHECK(FileExists(decoder_opts.fst_path));
decoder_opts.fst_ptr.reset(fst::Fst<fst::StdArc>::Read(FLAGS_graph_path));
}
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;

@ -37,28 +37,14 @@ DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
DEFINE_string(vocab_path, "", "nnet vocab path.");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
#ifdef USE_ONNX
DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path");
#endif
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
//DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "", "decoder graph");
DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");

@ -33,24 +33,12 @@ namespace ppspeech {
struct ModelOptions {
// common
int subsample_rate{1};
int thread_num{1}; // predictor thread pool size for ds2;
bool use_gpu{false};
std::string model_path;
#ifdef USE_ONNX
bool with_onnx_model{false};
#endif
std::string param_path;
// ds2 for inference
std::string input_names{};
std::string output_names{};
std::string cache_names{};
std::string cache_shape{};
bool switch_ir_optim{false};
bool enable_fc_padding{false};
bool enable_profile{false};
static ModelOptions InitFromFlags() {
ModelOptions opts;
opts.subsample_rate = FLAGS_subsampling_rate;
@ -61,19 +49,6 @@ struct ModelOptions {
opts.with_onnx_model = FLAGS_with_onnx_model;
LOG(INFO) << "with onnx model: " << opts.with_onnx_model;
#endif
opts.param_path = FLAGS_param_path;
LOG(INFO) << "param path: " << opts.param_path;
LOG(INFO) << "DS2 param: ";
opts.cache_names = FLAGS_model_cache_names;
LOG(INFO) << " cache names: " << opts.cache_names;
opts.cache_shape = FLAGS_model_cache_shapes;
LOG(INFO) << " cache shape: " << opts.cache_shape;
opts.input_names = FLAGS_model_input_names;
LOG(INFO) << " input names: " << opts.input_names;
opts.output_names = FLAGS_model_output_names;
LOG(INFO) << " output names: " << opts.output_names;
return opts;
}
};
@ -121,7 +96,7 @@ class NnetInterface {
// Base class layered on NnetInterface: exposes the stored subsampling rate and
// requires subclasses to provide Clone().
class NnetBase : public NnetInterface {
  public:
    // Returns the value held in subsampling_rate_.
    int SubsamplingRate() const { return subsampling_rate_; }

    // Pure virtual; subclasses return a shared_ptr<NnetBase>.
    // NOTE(review): presumably an independent copy of the network — confirm
    // against the concrete implementations.
    virtual std::shared_ptr<NnetBase> Clone() const = 0;

  protected:
    // Defaults to 1 until a subclass assigns another value.
    int subsampling_rate_ = 1;
};

@ -45,7 +45,7 @@ void NnetProducer::Acceptlikelihood(
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
LOG(INFO) << "nnet cache_ size: " << cache_.size();
VLOG(1) << "nnet cache_ size: " << cache_.size();
return flag;
}
@ -53,7 +53,6 @@ bool NnetProducer::Compute() {
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
LOG(INFO) << "no feat avalible";
if (frontend_->IsFinished() == true) {
finished_ = true;
}

@ -3,6 +3,8 @@ set(srcs)
list(APPEND srcs
recognizer_controller.cc
recognizer_controller_impl.cc
recognizer_instance.cc
recognizer.cc
)
add_library(recognizer STATIC ${srcs})
@ -10,6 +12,7 @@ target_link_libraries(recognizer PUBLIC decoder)
set(TEST_BINS
recognizer_batch_main
recognizer_batch_main2
recognizer_main
)

@ -11,3 +11,36 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer.h"
#include "recognizer/recognizer_instance.h"
// Forwards all four arguments to ppspeech::RecognizerInstance::Init on the
// singleton instance and returns whatever Init returns.
bool InitRecognizer(const std::string& model_file,
                    const std::string& word_symbol_table_file,
                    const std::string& fst_file,
                    int num_instance) {
    auto& instance = ppspeech::RecognizerInstance::GetInstance();
    return instance.Init(
        model_file, word_symbol_table_file, fst_file, num_instance);
}
// Returns the id handed out by the singleton's GetRecognizerInstanceId().
int GetRecognizerInstanceId() {
    auto& instance = ppspeech::RecognizerInstance::GetInstance();
    return instance.GetRecognizerInstanceId();
}
// Forwards to RecognizerInstance::InitDecoder for the given instance id.
void InitDecoder(int instance_id) {
    auto& instance = ppspeech::RecognizerInstance::GetInstance();
    instance.InitDecoder(instance_id);
}
// Passes `waves` to RecognizerInstance::Accept for the given instance id.
void AcceptData(const std::vector<float>& waves, int instance_id) {
    auto& instance = ppspeech::RecognizerInstance::GetInstance();
    instance.Accept(waves, instance_id);
}
// Forwards to RecognizerInstance::SetInputFinished for the given instance id.
void SetInputFinished(int instance_id) {
    auto& instance = ppspeech::RecognizerInstance::GetInstance();
    instance.SetInputFinished(instance_id);
}
// Returns the string produced by RecognizerInstance::GetResult for the given
// instance id.
std::string GetFinalResult(int instance_id) {
    auto& instance = ppspeech::RecognizerInstance::GetInstance();
    return instance.GetResult(instance_id);
}

@ -11,3 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
// Free-function recognizer API. The definitions in recognizer.cc forward each
// call to the ppspeech::RecognizerInstance singleton.

// Sets up the recognizer from a model file, a word symbol table file, an
// optional FST file and the number of instances to create.
// Returns the result of RecognizerInstance::Init.
bool InitRecognizer(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance);
// Returns an instance id to use with the calls below.
// NOTE(review): callers poll while this returns -1, so -1 presumably means
// "no instance available right now" — confirm in RecognizerController.
int GetRecognizerInstanceId();
// Initializes the decoder of the instance identified by `instance_id`.
void InitDecoder(int instance_id);
// Feeds a chunk of samples to the instance identified by `instance_id`.
void AcceptData(const std::vector<float>& waves, int instance_id);
// Tells the instance identified by `instance_id` that no more data follows.
void SetInputFinished(int instance_id);
// Returns the result string of the instance identified by `instance_id`.
std::string GetFinalResult(int instance_id);

@ -0,0 +1,168 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "njob");
using std::string;
using std::vector;
// Reads `wavlist_file` line by line and deals the entries round-robin into
// `njob` buckets.
//
// Each line must split on spaces/tabs into exactly two fields; the first is
// appended to `uttlists` and the second to `wavlists`, both at bucket
// `line_index % njob`.
//
// wavlist_file: path of the list file to read.
// uttlists:     output, resized to `njob`; receives the first field.
// wavlists:     output, resized to `njob`; receives the second field.
// njob:         number of buckets; must be positive.
void SplitUtt(string wavlist_file,
              vector<vector<string>>* uttlists,
              vector<vector<string>>* wavlists,
              int njob) {
    // A non-positive njob would make the resize and the modulo below invalid.
    CHECK_GT(njob, 0) << "njob must be positive";
    const size_t num_jobs = static_cast<size_t>(njob);

    wavlists->resize(num_jobs);
    uttlists->resize(num_jobs);

    vector<string> wavlist;
    if (!ppspeech::ReadFileToVector(wavlist_file, &wavlist)) {
        // Keep the original best-effort behaviour (empty buckets), but make
        // the failure visible instead of silently producing no work.
        LOG(WARNING) << "fail to read wav list: " << wavlist_file;
    }

    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        const vector<string> utt_wav = ppspeech::StrSplit(wavlist[idx], " \t");
        // Validate the field count before touching utt_wav[0] / utt_wav[1].
        CHECK_EQ(utt_wav.size(), size_t(2));
        LOG(INFO) << utt_wav[0];

        const size_t job = idx % num_jobs;
        uttlists->at(job).push_back(utt_wav[0]);
        wavlists->at(job).push_back(utt_wav[1]);
    }
}
void recognizer_func(std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id == -1) {
recog_id = GetRecognizerInstanceId();
}
InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
AcceptData(wav_chunk, recog_id);
// no overlap
sample_offset += cur_chunk_size;
}
SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}

@ -18,10 +18,9 @@
namespace ppspeech {
RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) {
nnet_ = std::make_shared<ppspeech::U2Nnet>(resource.model_opts);
recognizer_workers.resize(num_worker);
for (size_t i = 0; i < num_worker; ++i) {
recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource, nnet_->Clone()));
recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource));
waiting_workers.push(i);
}
}

@ -18,7 +18,6 @@
#include <memory>
#include "recognizer/recognizer_controller_impl.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
@ -34,7 +33,6 @@ class RecognizerController {
private:
std::queue<int> waiting_workers;
std::shared_ptr<ppspeech::U2Nnet> nnet_;
std::mutex mutex_;
std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;

@ -26,24 +26,24 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
new FeaturePipeline(feature_opts));
std::shared_ptr<NnetBase> nnet;
#ifndef USE_ONNX
nnet.reset(new U2Nnet(resource.model_opts));
nnet = resource.nnet->Clone();
#else
if (resource.model_opts.with_onnx_model){
nnet.reset(new U2OnnxNnet(resource.model_opts));
} else {
nnet.reset(new U2Nnet(resource.model_opts));
nnet = resource.nnet->Clone();
}
#endif
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline));
nnet_thread_ = std::thread(RunNnetEvaluation, this);
decodable_.reset(new Decodable(nnet_producer_, am_scale));
CHECK_NE(resource.vocab_path, "");
if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) {
LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path;
LOG(INFO) << "Init PrefixBeamSearch Decoder";
decoder_ = std::make_unique<CTCPrefixBeamSearch>(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts);
resource.decoder_opts.ctc_prefix_search_opts);
} else {
LOG(INFO) << "Init TLGDecoder";
decoder_ = std::make_unique<TLGDecoder>(
resource.decoder_opts.tlg_decoder_opts);
}
@ -55,33 +55,6 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
result_.clear();
}
RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource,
std::shared_ptr<NnetBase> nnet)
:opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
std::shared_ptr<FeaturePipeline> feature_pipeline =
std::make_shared<FeaturePipeline>(feature_opts);
nnet_producer_ = std::make_shared<NnetProducer>(nnet, feature_pipeline);
nnet_thread_ = std::thread(RunNnetEvaluation, this);
decodable_.reset(new Decodable(nnet_producer_, am_scale));
CHECK_NE(resource.vocab_path, "");
if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") {
decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
} else {
decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts));
}
symbol_table_ = decoder_->WordSymbolTable();
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
result_.clear();
}
// Destructor: calls WaitFinished() before the members are destroyed.
// NOTE(review): presumably this lets the nnet evaluation thread started in the
// constructor complete first — confirm in WaitFinished().
RecognizerControllerImpl::~RecognizerControllerImpl() {
WaitFinished();
}

@ -32,8 +32,6 @@ namespace ppspeech {
class RecognizerControllerImpl {
public:
explicit RecognizerControllerImpl(const RecognizerResource& resource);
explicit RecognizerControllerImpl(const RecognizerResource& resource,
std::shared_ptr<NnetBase> nnet);
~RecognizerControllerImpl();
void Accept(std::vector<float> data);
void InitDecoder();

@ -1,13 +0,0 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -1,13 +0,0 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,66 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_instance.h"
namespace ppspeech {

// Returns the single process-wide RecognizerInstance (function-local static).
RecognizerInstance& RecognizerInstance::GetInstance() {
    static RecognizerInstance instance;
    return instance;
}

// Builds the recognizer resource from the command-line flags, overrides the
// model / graph / symbol-table paths with the arguments, and creates the
// controller with `num_instance` workers.
//
// model_file:             stored into resource.model_opts.model_path.
// word_symbol_table_file: symbol table of whichever decoder is selected.
// fst_file:               when non-empty, the TLG decoder options are filled
//                         in; when empty, the CTC prefix beam search options
//                         are filled in instead.
// num_instance:           passed to the RecognizerController constructor.
//
// Always returns true.
bool RecognizerInstance::Init(const std::string& model_file,
                              const std::string& word_symbol_table_file,
                              const std::string& fst_file,
                              int num_instance) {
    RecognizerResource resource = RecognizerResource::InitFromFlags();
    // NOTE(review): InitFromFlags() already constructs resource.nnet from the
    // flag-provided model options, so this override may not reach that nnet —
    // confirm.
    resource.model_opts.model_path = model_file;

    if (!fst_file.empty()) {
        auto& tlg_opts = resource.decoder_opts.tlg_decoder_opts;
        tlg_opts.fst_path = fst_file;
        // The symbol table belongs in word_symbol_table; it was previously
        // assigned to fst_path as well, clobbering the graph path.
        tlg_opts.word_symbol_table = word_symbol_table_file;
        // NOTE(review): TLGDecoder reads the graph from tlg_opts.fst_ptr,
        // which InitFromFlags() loads from FLAGS_graph_path; if fst_file
        // differs from that flag the pointer is not refreshed here — confirm.
    } else {
        resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
            word_symbol_table_file;
    }

    recognizer_controller_ =
        std::make_unique<RecognizerController>(num_instance, resource);
    return true;
}

// The methods below forward to the controller created by Init(); Init() must
// have been called first, otherwise recognizer_controller_ is null.

// Forwards to RecognizerController::InitDecoder(idx).
void RecognizerInstance::InitDecoder(int idx) {
    recognizer_controller_->InitDecoder(idx);
}

// Returns RecognizerController::GetRecognizerInstanceId().
int RecognizerInstance::GetRecognizerInstanceId() {
    return recognizer_controller_->GetRecognizerInstanceId();
}

// Forwards `waves` to RecognizerController::Accept for instance `idx`.
void RecognizerInstance::Accept(const std::vector<float>& waves,
                                int idx) const {
    recognizer_controller_->Accept(waves, idx);
}

// Forwards to RecognizerController::SetInputFinished(idx).
void RecognizerInstance::SetInputFinished(int idx) const {
    recognizer_controller_->SetInputFinished(idx);
}

// Returns RecognizerController::GetFinalResult(idx).
std::string RecognizerInstance::GetResult(int idx) const {
    return recognizer_controller_->GetFinalResult(idx);
}

}  // namespace ppspeech

@ -0,0 +1,42 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "recognizer/recognizer_controller.h"
namespace ppspeech {
// Facade that owns one RecognizerController and is reachable through
// GetInstance(). The free functions declared in recognizer/recognizer.h
// forward to it.
class RecognizerInstance {
  public:
    // Returns the shared instance.
    static RecognizerInstance& GetInstance();

    RecognizerInstance() = default;
    ~RecognizerInstance() = default;

    // Creates the controller from the given model, symbol table and FST
    // files with `num_instance` instances.
    bool Init(const std::string& model_file,
              const std::string& word_symbol_table_file,
              const std::string& fst_file,
              int num_instance);

    // Returns an instance id obtained from the controller.
    int GetRecognizerInstanceId();

    // Per-instance operations; `idx` is the instance id.
    void InitDecoder(int idx);
    void Accept(const std::vector<float>& waves, int idx) const;
    void SetInputFinished(int idx) const;
    std::string GetResult(int idx) const;

  private:
    // Unset until Init() assigns it.
    std::unique_ptr<RecognizerController> recognizer_controller_;
};
} // namespace ppspeech

@ -12,7 +12,6 @@ DECLARE_double(reverse_weight);
DECLARE_int32(nbest);
DECLARE_int32(blank);
DECLARE_double(acoustic_scale);
DECLARE_string(vocab_path);
DECLARE_string(word_symbol_table);
namespace ppspeech {
@ -52,6 +51,8 @@ struct DecodeOptions {
decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank;
decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.word_symbol_table =
FLAGS_word_symbol_table;
decoder_opts.tlg_decoder_opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
@ -68,18 +69,17 @@ struct DecodeOptions {
};
struct RecognizerResource {
// decodable opt
kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{};
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
DecodeOptions decoder_opts{};
std::shared_ptr<NnetBase> nnet;
static RecognizerResource InitFromFlags() {
RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale;
LOG(INFO) << "vocab path: " << resource.vocab_path;
LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale;
resource.feature_pipeline_opts =
@ -89,6 +89,15 @@ struct RecognizerResource {
<< resource.feature_pipeline_opts.assembler_opts.fill_zero;
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags();
#ifndef USE_ONNX
resource.nnet.reset(new U2Nnet(resource.model_opts));
#else
if (resource.model_opts.with_onnx_model){
resource.nnet.reset(new U2OnnxNnet(resource.model_opts));
} else {
resource.nnet.reset(new U2Nnet(resource.model_opts));
}
#endif
return resource;
}
};

@ -14,6 +14,8 @@
#include "utils/file_utils.h"
#include <sys/stat.h>
namespace ppspeech {
bool ReadFileToVector(const std::string& filename,
@ -40,4 +42,31 @@ std::string ReadFile2String(const std::string& path) {
return std::string((std::istreambuf_iterator<char>(input_file)),
std::istreambuf_iterator<char>());
}
// Reports whether the attributes of `strFilename` can be read with stat().
//
// Returns true when stat() succeeds. Returns false when it fails, which can
// also mean that the containing directory is not accessible; inspect the
// stat() error codes if that distinction matters.
//
// This function is adapted from:
// https://github.com/kaldi-asr/kaldi/blob/master/src/fstext/deterministic-fst-test.cc
bool FileExists(const std::string& strFilename) {
    struct stat file_info;
    // stat() returns 0 exactly when the file attributes were retrieved.
    return stat(strFilename.c_str(), &file_info) == 0;
}
} // namespace ppspeech

@ -20,4 +20,7 @@ bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
std::string ReadFile2String(const std::string& path);
bool FileExists(const std::string& filename);
} // namespace ppspeech

@ -14,7 +14,7 @@ text=$data/test/text
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.log \
ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \
--word_symbol_table=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \

@ -21,7 +21,7 @@ recognizer_main \
--num_bins=80 \
--cmvn_file=$model_dir/mean_std.json \
--model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \
--word_symbol_table=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \

@ -21,7 +21,7 @@ u2_recognizer_main \
--num_bins=80 \
--cmvn_file=$model_dir/mean_std.json \
--model_path=$model_dir/export \
--vocab_path=$model_dir/unit.txt \
--word_symbol_table=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \

@ -0,0 +1,41 @@
#!/bin/bash
# Runs recognizer_main with a TLG graph over the split aishell test lists,
# then gathers the per-job results and computes the error rate.
set -e

data=data
exp=exp
nj=40

# Allows the variables above to be overridden from the command line.
. utils/parse_options.sh

mkdir -p "$exp"
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text

./local/split_data.sh "$data" "$data/$aishell_wav_scp" "$aishell_wav_scp" "$nj"

lang_dir=./data/lang_test/
graph=$lang_dir/TLG.fst
word_table=$lang_dir/words.txt

# One decoding job per split; JOB is substituted by run.pl.
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer_wfst.log \
recognizer_main \
    --use_fbank=true \
    --num_bins=80 \
    --cmvn_file=$model_dir/mean_std.json \
    --model_path=$model_dir/export.jit \
    --graph_path=$graph \
    --word_symbol_table=$word_table \
    --nnet_decoder_chunk=16 \
    --receptive_field_length=7 \
    --subsampling_rate=4 \
    --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
    --result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer_wfst.ark

# The glob must stay unquoted so that it expands to every job directory.
cat "$data/split${nj}"/*/result_recognizer_wfst.ark > "$exp/aishell_recognizer_wfst"
utils/compute-wer.py --char=1 --v=1 "$text" "$exp/aishell_recognizer_wfst" > "$exp/aishell.recognizer_wfst.err"

echo "recognizer test have finished!!!"
echo "please checkout in $exp/aishell.recognizer_wfst.err"
tail -n 7 "$exp/aishell.recognizer_wfst.err"
Loading…
Cancel
Save