add skip blank

pull/3173/head
YangZhou 2 years ago
parent 6e1afbeca1
commit 85a1744ecc

@@ -71,7 +71,7 @@ int main(int argc, char* argv[]) {
std::shared_ptr<ppspeech::DataCache> raw_data =
std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nnet, raw_data);
std::make_shared<ppspeech::NnetProducer>(nnet, raw_data, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable =
std::make_shared<ppspeech::Decodable>(nnet_producer);

@@ -44,7 +44,7 @@ struct TLGDecoderOptions {
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table;
if (!decoder_opts.fst_path.empty()) {
CHECK(FileExists(decoder_opts.fst_path));

@@ -54,7 +54,7 @@ int main(int argc, char* argv[]) {
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr);
std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));

@@ -35,13 +35,11 @@ DEFINE_int32(subsampling_rate,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
#ifdef USE_ONNX
DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path");
#endif
//DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
@@ -50,10 +48,9 @@ DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_double(blank_threshold, 0.98, "blank skip threshold");
// DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight,
0.5,

@@ -22,8 +22,9 @@ using kaldi::BaseFloat;
using std::vector;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend)
: nnet_(nnet), frontend_(frontend) {
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold)
: nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) {
Reset();
}
@@ -70,7 +71,22 @@ bool NnetProducer::Compute() {
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob);
// skip frames whose blank probability exceeds the threshold
float blank_prob = std::exp(logprob[0]);
if (blank_prob > blank_threshold_) {
last_frame_logprob_ = logprob;
is_last_frame_skip_ = true;
continue;
} else {
int cur_max = std::max_element(logprob.begin(), logprob.end()) - logprob.begin();
// re-emit the skipped frame when the same non-blank argmax flanks the skipped run
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
cache_.push_back(last_frame_logprob_);
}
last_max_elem_ = cur_max;
is_last_frame_skip_ = false;
cache_.push_back(logprob);
}
}
return true;
}
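The hunk above implements blank-frame skipping for the CTC output: a frame whose blank probability (exp of the log-prob at token id 0) exceeds blank_threshold_ is not pushed to the decoder cache, and the most recently skipped frame is put back whenever the kept frames on either side of the skipped run share the same non-blank argmax, so repeated-token boundaries are not collapsed. Below is a minimal, self-contained sketch of that rule for reference; the names SkipBlankFrames and kBlankId and the toy frames in main are illustrative assumptions, not part of the PaddleSpeech sources.

    // Standalone sketch of the blank-skip rule; not part of the patch itself.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    std::vector<std::vector<float>> SkipBlankFrames(
        const std::vector<std::vector<float>>& frames, float blank_threshold) {
      const int kBlankId = 0;  // assumption: blank is token id 0, as in the patch
      std::vector<std::vector<float>> kept;
      std::vector<float> last_skipped;
      bool last_was_skipped = false;
      int last_emitted_argmax = -1;
      for (const auto& logprob : frames) {
        // Drop the frame when the model is confident it is blank.
        if (std::exp(logprob[kBlankId]) > blank_threshold) {
          last_skipped = logprob;
          last_was_skipped = true;
          continue;
        }
        int cur_max = static_cast<int>(
            std::max_element(logprob.begin(), logprob.end()) - logprob.begin());
        // Same non-blank argmax on both sides of a skipped run: restore the
        // skipped frame so the token boundary survives.
        if (cur_max == last_emitted_argmax && cur_max != kBlankId &&
            last_was_skipped) {
          kept.push_back(last_skipped);
        }
        last_emitted_argmax = cur_max;
        last_was_skipped = false;
        kept.push_back(logprob);
      }
      return kept;
    }

    int main() {
      // Toy example: a confident blank frame between two identical non-blank frames.
      std::vector<std::vector<float>> frames = {
          {std::log(0.10f), std::log(0.90f)},    // argmax 1, kept
          {std::log(0.995f), std::log(0.005f)},  // blank prob 0.995 > 0.98, skipped
          {std::log(0.10f), std::log(0.90f)},    // argmax 1 again: skipped frame restored
      };
      std::printf("kept %zu of %zu frames\n",
                  SkipBlankFrames(frames, 0.98f).size(), frames.size());  // 3 of 3
      return 0;
    }

Because the cache holds log-probabilities, the threshold test is equivalent to logprob[0] > std::log(blank_threshold_), which would avoid the per-frame exp; the patch keeps the exp form, which is equally correct.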

@@ -24,7 +24,8 @@ namespace ppspeech {
class NnetProducer {
public:
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend = NULL);
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold);
// Feed feats or waves
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
@@ -64,6 +65,10 @@ class NnetProducer {
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::vector<BaseFloat> last_frame_logprob_;  // most recently skipped frame
bool is_last_frame_skip_ = false;  // whether the previous frame was skipped as blank
int last_max_elem_ = -1;  // argmax token id of the last emitted frame
float blank_threshold_ = 0.0;  // skip a frame when exp(logprob[0]) exceeds this
bool finished_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);

@@ -21,6 +21,7 @@ namespace ppspeech {
RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource)
: opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale;
BaseFloat blank_threshold = resource.blank_threshold;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
std::shared_ptr<FeaturePipeline> feature_pipeline(
new FeaturePipeline(feature_opts));
@@ -34,7 +35,7 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
nnet = resource.nnet->Clone();
}
#endif
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline));
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold));
nnet_thread_ = std::thread(RunNnetEvaluation, this);
decodable_.reset(new Decodable(nnet_producer_, am_scale));

@@ -12,6 +12,7 @@ DECLARE_double(reverse_weight);
DECLARE_int32(nbest);
DECLARE_int32(blank);
DECLARE_double(acoustic_scale);
DECLARE_double(blank_threshold);
DECLARE_string(word_symbol_table);
namespace ppspeech {
@@ -71,6 +72,7 @@ struct DecodeOptions {
struct RecognizerResource {
// decodable opt
kaldi::BaseFloat acoustic_scale{1.0};
kaldi::BaseFloat blank_threshold{0.98};
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
@@ -80,6 +82,7 @@ struct RecognizerResource {
static RecognizerResource InitFromFlags() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.blank_threshold = FLAGS_blank_threshold;
LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale;
resource.feature_pipeline_opts =

@@ -11,5 +11,5 @@ fsttablecompose
foreach(binary IN LISTS BINS)
add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc)
target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog libgflags_nothreads.so fst dl)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl)
endforeach()

@@ -3,7 +3,7 @@ set -e
data=data
exp=exp
nj=40
nj=20
. utils/parse_options.sh
