add skip blank

pull/3173/head
YangZhou 2 years ago
parent 6e1afbeca1
commit 85a1744ecc

@ -71,7 +71,7 @@ int main(int argc, char* argv[]) {
std::shared_ptr<ppspeech::DataCache> raw_data = std::shared_ptr<ppspeech::DataCache> raw_data =
std::make_shared<ppspeech::DataCache>(); std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer = std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nnet, raw_data); std::make_shared<ppspeech::NnetProducer>(nnet, raw_data, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable = std::shared_ptr<ppspeech::Decodable> decodable =
std::make_shared<ppspeech::Decodable>(nnet_producer); std::make_shared<ppspeech::Decodable>(nnet_producer);

@ -44,7 +44,7 @@ struct TLGDecoderOptions {
decoder_opts.word_symbol_table = FLAGS_word_symbol_table; decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path; decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path; LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table; LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table;
if (!decoder_opts.fst_path.empty()) { if (!decoder_opts.fst_path.empty()) {
CHECK(FileExists(decoder_opts.fst_path)); CHECK(FileExists(decoder_opts.fst_path));

@ -54,7 +54,7 @@ int main(int argc, char* argv[]) {
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer = std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr); std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale)); new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));

@ -35,13 +35,11 @@ DEFINE_int32(subsampling_rate,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet // nnet
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
#ifdef USE_ONNX #ifdef USE_ONNX
DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path"); DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path");
#endif #endif
//DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
// decoder // decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
@ -50,10 +48,9 @@ DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active"); DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_double(blank_threshold, 0.98, "blank skip threshold");
// DecodeOptions flags // DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding"); DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, DEFINE_double(ctc_weight,
0.5, 0.5,

@ -22,8 +22,9 @@ using kaldi::BaseFloat;
using std::vector; using std::vector;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet, NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend) std::shared_ptr<FrontendInterface> frontend,
: nnet_(nnet), frontend_(frontend) { float blank_threshold)
: nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) {
Reset(); Reset();
} }
@ -70,7 +71,22 @@ bool NnetProducer::Compute() {
std::vector<BaseFloat> logprob( std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim, out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim); out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob); // process blank prob
float blank_prob = std::exp(logprob[0]);
if (blank_prob > blank_threshold_) {
last_frame_logprob_ = logprob;
is_last_frame_skip_ = true;
continue;
} else {
int cur_max = std::max(logprob.begin(), logprob.end()) - logprob.begin();
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
cache_.push_back(last_frame_logprob_);
last_max_elem_ = cur_max;
}
last_max_elem_ = cur_max;
is_last_frame_skip_ = false;
cache_.push_back(logprob);
}
} }
return true; return true;
} }

@ -24,7 +24,8 @@ namespace ppspeech {
class NnetProducer { class NnetProducer {
public: public:
explicit NnetProducer(std::shared_ptr<NnetBase> nnet, explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend = NULL); std::shared_ptr<FrontendInterface> frontend,
float blank_threshold);
// Feed feats or waves // Feed feats or waves
void Accept(const std::vector<kaldi::BaseFloat>& inputs); void Accept(const std::vector<kaldi::BaseFloat>& inputs);
@ -64,6 +65,10 @@ class NnetProducer {
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_; std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_; SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::vector<BaseFloat> last_frame_logprob_;
bool is_last_frame_skip_ = false;
int last_max_elem_ = -1;
float blank_threshold_ = 0.0;
bool finished_; bool finished_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer); DISALLOW_COPY_AND_ASSIGN(NnetProducer);

@ -21,6 +21,7 @@ namespace ppspeech {
RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource) RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource)
: opts_(resource) { : opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale; BaseFloat am_scale = resource.acoustic_scale;
BaseFloat blank_threshold = resource.blank_threshold;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
std::shared_ptr<FeaturePipeline> feature_pipeline( std::shared_ptr<FeaturePipeline> feature_pipeline(
new FeaturePipeline(feature_opts)); new FeaturePipeline(feature_opts));
@ -34,7 +35,7 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
nnet = resource.nnet->Clone(); nnet = resource.nnet->Clone();
} }
#endif #endif
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold));
nnet_thread_ = std::thread(RunNnetEvaluation, this); nnet_thread_ = std::thread(RunNnetEvaluation, this);
decodable_.reset(new Decodable(nnet_producer_, am_scale)); decodable_.reset(new Decodable(nnet_producer_, am_scale));

@ -12,6 +12,7 @@ DECLARE_double(reverse_weight);
DECLARE_int32(nbest); DECLARE_int32(nbest);
DECLARE_int32(blank); DECLARE_int32(blank);
DECLARE_double(acoustic_scale); DECLARE_double(acoustic_scale);
DECLARE_double(blank_threshold);
DECLARE_string(word_symbol_table); DECLARE_string(word_symbol_table);
namespace ppspeech { namespace ppspeech {
@ -71,6 +72,7 @@ struct DecodeOptions {
struct RecognizerResource { struct RecognizerResource {
// decodable opt // decodable opt
kaldi::BaseFloat acoustic_scale{1.0}; kaldi::BaseFloat acoustic_scale{1.0};
kaldi::BaseFloat blank_threshold{0.98};
FeaturePipelineOptions feature_pipeline_opts{}; FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{}; ModelOptions model_opts{};
@ -80,6 +82,7 @@ struct RecognizerResource {
static RecognizerResource InitFromFlags() { static RecognizerResource InitFromFlags() {
RecognizerResource resource; RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale; resource.acoustic_scale = FLAGS_acoustic_scale;
resource.blank_threshold = FLAGS_blank_threshold;
LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale; LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale;
resource.feature_pipeline_opts = resource.feature_pipeline_opts =

@ -11,5 +11,5 @@ fsttablecompose
foreach(binary IN LISTS BINS) foreach(binary IN LISTS BINS)
add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc) add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc)
target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog libgflags_nothreads.so fst dl) target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl)
endforeach() endforeach()

@ -3,7 +3,7 @@ set -e
data=data data=data
exp=exp exp=exp
nj=40 nj=20
. utils/parse_options.sh . utils/parse_options.sh

Loading…
Cancel
Save