refactor options

pull/2524/head
Hui Zhang 3 years ago
parent 17ea30e7ca
commit 616fc4594b

@ -14,7 +14,7 @@ ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \ --model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \ --nnet_decoder_chunk=16 \
--receptive_field_length=7 \ --receptive_field_length=7 \
--downsampling_rate=4 \ --subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \ --vocab_path=$model_dir/unit.txt \
--feature_rspecifier=ark,t:$exp/fbank.ark \ --feature_rspecifier=ark,t:$exp/fbank.ark \
--result_wspecifier=ark,t:$exp/result.ark --result_wspecifier=ark,t:$exp/result.ark

@ -15,7 +15,7 @@ u2_nnet_main \
--feature_rspecifier=ark,t:$exp/fbank.ark \ --feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \ --nnet_decoder_chunk=16 \
--receptive_field_length=7 \ --receptive_field_length=7 \
--downsampling_rate=4 \ --subsampling_rate=4 \
--acoustic_scale=1.0 \ --acoustic_scale=1.0 \
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \ --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark

@ -16,7 +16,7 @@ u2_recognizer_main \
--model_path=$model_dir/export.jit \ --model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \ --nnet_decoder_chunk=16 \
--receptive_field_length=7 \ --receptive_field_length=7 \
--downsampling_rate=4 \ --subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \ --vocab_path=$model_dir/unit.txt \
--wav_rspecifier=scp:$data/wav.scp \ --wav_rspecifier=scp:$data/wav.scp \
--result_wspecifier=ark,t:$exp/result.ark --result_wspecifier=ark,t:$exp/result.ark

@ -1,44 +1,61 @@
project(decoder) project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder STATIC
ctc_decoders/decoder_utils.cpp set(decoder_src )
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp if (USING_DS2)
ctc_beam_search_decoder.cc list(APPEND decoder_src
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
)
endif()
if (USING_U2)
list(APPEND decoder_src
ctc_prefix_beam_search_decoder.cc ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
u2_recognizer.cc u2_recognizer.cc
) )
endif()
add_library(decoder STATIC ${decoder_src})
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
# test # test
set(BINS if (USING_DS2)
set(BINS
ctc_beam_search_decoder_main ctc_beam_search_decoder_main
nnet_logprob_decoder_main nnet_logprob_decoder_main
recognizer_main recognizer_main
ctc_tlg_decoder_main ctc_tlg_decoder_main
) )
foreach(bin_name IN LISTS BINS) foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach() endforeach()
endif()
# u2 if (USING_U2)
set(TEST_BINS set(TEST_BINS
u2_recognizer_main
ctc_prefix_beam_search_decoder_main ctc_prefix_beam_search_decoder_main
) u2_recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS) foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach() endforeach()
endif()

@ -31,7 +31,7 @@ DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_string( DEFINE_string(
@ -81,13 +81,8 @@ int main(int argc, char* argv[]) {
opts.lm_path = lm_path; opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts); ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
model_opts.model_path = model_path;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
@ -95,8 +90,8 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data)); new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;

@ -30,7 +30,7 @@ DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
@ -81,8 +81,8 @@ int main(int argc, char* argv[]) {
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;

@ -20,15 +20,37 @@
#include "kaldi/decoder/lattice-faster-online-decoder.h" #include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h" #include "util/parse-options.h"
DECLARE_string(graph_path);
DECLARE_string(word_symbol_table);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
namespace ppspeech { namespace ppspeech {
struct TLGDecoderOptions { struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts; kaldi::LatticeFasterDecoderConfig opts{};
// todo remove later, add into decode resource // todo remove later, add into decode resource
std::string word_symbol_table; std::string word_symbol_table{};
std::string fst_path; std::string fst_path{};
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} static TLGDecoderOptions InitFromFlags(){
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active ;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam ;
LOG(INFO) << "LatticeFasterDecoder lattice_beam: " << decoder_opts.opts.lattice_beam ;
return decoder_opts;
}
}; };
class TLGDecoder : public DecoderInterface { class TLGDecoder : public DecoderInterface {

@ -19,6 +19,7 @@
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/ds2_nnet.h" #include "nnet/ds2_nnet.h"
#include "decoder/param.h"
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_tlg_decoder.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
@ -26,30 +27,7 @@
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
@ -66,32 +44,16 @@ int main(int argc, char* argv[]) {
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table;
std::string graph_path = FLAGS_graph_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "word symbol path: " << word_symbol_table;
LOG(INFO) << "graph path: " << graph_path;
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts; ppspeech::TLGDecoderOptions opts = ppspeech::TLGDecoderOptions::InitFromFlags();
opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active;
opts.opts.beam = 15.0; opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5; opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts); ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
model_opts.model_path = model_graph;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
@ -99,12 +61,13 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length; LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder(); decoder.InitDecoder();
kaldi::Timer timer; kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {

@ -17,8 +17,6 @@
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
// feature // feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
@ -27,18 +25,18 @@ DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
DEFINE_int32(num_bins, 161, "num bins of mel"); DEFINE_int32(num_bins, 161, "num bins of mel");
DEFINE_string(cmvn_file, "", "read cmvn"); DEFINE_string(cmvn_file, "", "read cmvn");
// feature sliding window // feature sliding window
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet // nnet
DEFINE_string(vocab_path, "", "nnet vocab path.");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string( DEFINE_string(
@ -52,10 +50,11 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box", "chunk_state_h_box,chunk_state_c_box",
"model cache names"); "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_string(vocab_path, "", "nnet vocab path.");
// decoder // decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_int32(max_active, 7500, "max active"); DEFINE_int32(max_active, 7500, "max active");
@ -63,37 +62,20 @@ DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.dither = 0.0;
frame_opts.frame_shift_ms = 10;
opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.frame_opts = frame_opts;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate;
opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
return opts;
}
} // namespace ppspeech // DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight,
0.5,
"ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight,
1.0,
"rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight,
0.3,
"used for bitransformer rescoring. it must be 0.0 if decoder is"
"conventional transformer decoder, and only reverse_weight > 0.0"
"dose the right to left decoder will be calculated and used");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
DEFINE_int32(blank, 0, "blank id in vocab");

@ -22,14 +22,26 @@
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/ds2_nnet.h" #include "nnet/ds2_nnet.h"
DECLARE_double(acoustic_scale);
namespace ppspeech { namespace ppspeech {
struct RecognizerResource { struct RecognizerResource {
kaldi::BaseFloat acoustic_scale{1.0};
FeaturePipelineOptions feature_pipeline_opts{}; FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{}; ModelOptions model_opts{};
TLGDecoderOptions tlg_opts{}; TLGDecoderOptions tlg_opts{};
// CTCBeamSearchOptions beam_search_opts; // CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale{1.0};
static RecognizerResource InitFromFlags(){
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ModelOptions::InitFromFlags();
resource.tlg_opts = TLGDecoderOptions::InitFromFlags();
return resource;
}
}; };
class Recognizer { class Recognizer {

@ -25,27 +25,9 @@ DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource; ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale; resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::ModelOptions model_opts; resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
model_opts.subsample_rate = FLAGS_downsampling_rate;
resource.model_opts = model_opts;
ppspeech::TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
resource.tlg_opts = decoder_opts;
return resource; return resource;
} }

@ -26,15 +26,25 @@
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "fst/symbol-table.h" #include "fst/symbol-table.h"
namespace ppspeech { DECLARE_int32(nnet_decoder_chunk);
DECLARE_int32(num_left_chunks);
DECLARE_double(ctc_weight);
DECLARE_double(rescoring_weight);
DECLARE_double(reverse_weight);
DECLARE_int32(nbest);
DECLARE_int32(blank);
DECLARE_double(acoustic_scale);
DECLARE_string(vocab_path);
namespace ppspeech {
struct DecodeOptions { struct DecodeOptions {
// chunk_size is the frame number of one chunk after subsampling. // chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in // e.g. if subsample rate is 4 and chunk_size = 16, the frames in
// one chunk are 67=16*4 + 3, stride is 64=16*4 // one chunk are 67=16*4 + 3, stride is 64=16*4
int chunk_size; int chunk_size{16};
int num_left_chunks; int num_left_chunks{-1};
// final_score = rescoring_weight * rescoring_score + ctc_weight * // final_score = rescoring_weight * rescoring_score + ctc_weight *
// ctc_score; // ctc_score;
@ -46,51 +56,27 @@ struct DecodeOptions {
// it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a // it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a
// max(viterbi) path score + context score So we should carefully set // max(viterbi) path score + context score So we should carefully set
// ctc_weight accroding to the search methods. // ctc_weight accroding to the search methods.
float ctc_weight; float ctc_weight{0.0};
float rescoring_weight; float rescoring_weight{1.0};
float reverse_weight; float reverse_weight{0.0};
// CtcEndpointConfig ctc_endpoint_opts; // CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts; CTCBeamSearchOptions ctc_prefix_search_opts{};
DecodeOptions() static DecodeOptions InitFromFlags(){
: chunk_size(16), DecodeOptions decoder_opts;
num_left_chunks(-1), decoder_opts.chunk_size=FLAGS_nnet_decoder_chunk;
ctc_weight(0.5), decoder_opts.num_left_chunks = FLAGS_num_left_chunks;
rescoring_weight(1.0), decoder_opts.ctc_weight = FLAGS_ctc_weight;
reverse_weight(0.0) {} decoder_opts.rescoring_weight = FLAGS_rescoring_weight;
decoder_opts.reverse_weight = FLAGS_reverse_weight;
void Register(kaldi::OptionsItf* opts) { decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank;
std::string module = "DecoderConfig: "; decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
opts->Register( decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
"chunk-size", return decoder_opts;
&chunk_size,
module + "the frame number of one chunk after subsampling.");
opts->Register("num-left-chunks",
&num_left_chunks,
module + "the left history chunks number.");
opts->Register("ctc-weight",
&ctc_weight,
module +
"ctc weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("rescoring-weight",
&rescoring_weight,
module +
"attention score weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score.");
opts->Register("reverse-weight",
&reverse_weight,
module +
"reverse decoder weight. rescoring_score = "
"left_to_right_score * (1 - reverse_weight) + "
"right_to_left_score * reverse_weight.");
} }
}; };
struct U2RecognizerResource { struct U2RecognizerResource {
kaldi::BaseFloat acoustic_scale{1.0}; kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{}; std::string vocab_path{};
@ -98,7 +84,17 @@ struct U2RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts{}; FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{}; ModelOptions model_opts{};
DecodeOptions decoder_opts{}; DecodeOptions decoder_opts{};
// CTCBeamSearchOptions beam_search_opts;
static U2RecognizerResource InitFromFlags() {
U2RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags();
return resource;
}
}; };

@ -22,35 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate"); DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::U2RecognizerResource InitOpts() {
ppspeech::U2RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
LOG(INFO) << "feature!";
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
resource.model_opts = model_opts;
LOG(INFO) << "model!";
ppspeech::DecodeOptions decoder_opts;
decoder_opts.chunk_size=16;
decoder_opts.num_left_chunks = -1;
decoder_opts.ctc_weight = 0.5;
decoder_opts.rescoring_weight = 1.0;
decoder_opts.reverse_weight = 0.3;
decoder_opts.ctc_prefix_search_opts.blank = 0;
decoder_opts.ctc_prefix_search_opts.first_beam_size = 10;
decoder_opts.ctc_prefix_search_opts.second_beam_size = 10;
resource.decoder_opts = decoder_opts;
LOG(INFO) << "decoder!";
return resource;
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:"); gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -72,7 +43,7 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource = InitOpts(); ppspeech::U2RecognizerResource resource = ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource); ppspeech::U2Recognizer recognizer(resource);
kaldi::Timer timer; kaldi::Timer timer;

@ -25,26 +25,71 @@
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
// feature
DECLARE_bool(use_fbank);
DECLARE_int32(num_bins);
DECLARE_string(cmvn_file);
// feature sliding window
DECLARE_int32(receptive_field_length);
DECLARE_int32(subsampling_rate);
DECLARE_int32(nnet_decoder_chunk);
namespace ppspeech { namespace ppspeech {
struct FeaturePipelineOptions { struct FeaturePipelineOptions {
std::string cmvn_file; std::string cmvn_file{};
bool to_float32; // true, only for linear feature bool to_float32{false}; // true, only for linear feature
bool use_fbank; bool use_fbank{true};
LinearSpectrogramOptions linear_spectrogram_opts; LinearSpectrogramOptions linear_spectrogram_opts{};
kaldi::FbankOptions fbank_opts; kaldi::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts; FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts; AssemblerOptions assembler_opts{};
FeaturePipelineOptions() static FeaturePipelineOptions InitFromFlags(){
: cmvn_file(""), FeaturePipelineOptions opts;
to_float32(false), // true, only for linear feature opts.cmvn_file = FLAGS_cmvn_file;
use_fbank(true), LOG(INFO) << "cmvn file: " << opts.cmvn_file;
linear_spectrogram_opts(),
fbank_opts(), // frame options
feature_cache_opts(), kaldi::FrameExtractionOptions frame_opts;
assembler_opts() {} frame_opts.dither = 0.0;
LOG(INFO) << "dither: " << frame_opts.dither;
frame_opts.frame_shift_ms = 10;
LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms;
opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
opts.fbank_opts.frame_opts = frame_opts;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms;
// assembler opts
opts.assembler_opts.subsampling_rate = FLAGS_subsampling_rate;
LOG(INFO) << "subsampling rate: " << opts.assembler_opts.subsampling_rate;
opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length;
LOG(INFO) << "nnet receptive filed length: " << opts.assembler_opts.receptive_filed_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
LOG(INFO) << "nnet chunk size: " << opts.assembler_opts.nnet_decoder_chunk;
return opts;
}
}; };
class FeaturePipeline : public FrontendInterface { class FeaturePipeline : public FrontendInterface {
public: public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts); explicit FeaturePipeline(const FeaturePipelineOptions& opts);

@ -14,6 +14,7 @@
#include "nnet/ds2_nnet.h" #include "nnet/ds2_nnet.h"
#include "base/common.h" #include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/assembler.h" #include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
@ -21,27 +22,6 @@
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
@ -64,13 +44,8 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
model_opts.model_path = model_graph;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
@ -78,8 +53,8 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;

@ -20,53 +20,54 @@
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h" #include "kaldi/util/options-itf.h"
namespace ppspeech { DECLARE_int32(subsampling_rate);
DECLARE_string(model_path);
DECLARE_string(param_path);
DECLARE_string(model_input_names);
DECLARE_string(model_output_names);
DECLARE_string(model_cache_names);
DECLARE_string(model_cache_shapes);
namespace ppspeech {
struct ModelOptions { struct ModelOptions {
// common
int subsample_rate{1};
int thread_num{1}; // predictor thread pool size for ds2;
bool use_gpu{false};
std::string model_path; std::string model_path;
std::string param_path; std::string param_path;
int thread_num; // predictor thread pool size for ds2;
bool use_gpu; // ds2 for inference
bool switch_ir_optim; std::string input_names{};
std::string input_names; std::string output_names{};
std::string output_names; std::string cache_names{};
std::string cache_names; std::string cache_shape{};
std::string cache_shape; bool switch_ir_optim{false};
bool enable_fc_padding; bool enable_fc_padding{false};
bool enable_profile; bool enable_profile{false};
int subsample_rate;
ModelOptions() static ModelOptions InitFromFlags(){
: model_path(""), ModelOptions opts;
param_path(""), opts.subsample_rate = FLAGS_subsampling_rate;
thread_num(1), LOG(INFO) << "subsampling rate: " << opts.subsample_rate;
use_gpu(false), opts.model_path = FLAGS_model_path;
input_names(""), LOG(INFO) << "model path: " << opts.model_path ;
output_names(""),
cache_names(""), opts.param_path = FLAGS_param_path;
cache_shape(""), LOG(INFO) << "param path: " << opts.param_path ;
switch_ir_optim(false),
enable_fc_padding(false), LOG(INFO) << "DS2 param: ";
enable_profile(false), opts.cache_names = FLAGS_model_cache_names;
subsample_rate(0) {} LOG(INFO) << " cache names: " << opts.cache_names;
opts.cache_shape = FLAGS_model_cache_shapes;
void Register(kaldi::OptionsItf* opts) { LOG(INFO) << " cache shape: " << opts.cache_shape;
opts->Register("model-path", &model_path, "model file path"); opts.input_names = FLAGS_model_input_names;
opts->Register("model-param", &param_path, "params model file path"); LOG(INFO) << " input names: " << opts.input_names;
opts->Register("thread-num", &thread_num, "thread num"); opts.output_names = FLAGS_model_output_names;
opts->Register("use-gpu", &use_gpu, "if use gpu"); LOG(INFO) << " output names: " << opts.output_names;
opts->Register("input-names", &input_names, "paddle input names"); return opts;
opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom",
&switch_ir_optim,
"paddle SwitchIrOptim option");
opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
} }
}; };

@ -17,7 +17,6 @@
#include "base/common.h" #include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h" #include "nnet/nnet_itf.h"
#include "paddle/extension.h" #include "paddle/extension.h"
#include "paddle/jit/all.h" #include "paddle/jit/all.h"

@ -12,28 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "nnet/u2_nnet.h"
#include "base/common.h" #include "base/common.h"
#include "frontend/audio/assembler.h" #include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "decoder/param.h"
#include "nnet/u2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier"); DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier");
DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_int32(nnet_decoder_chunk, 16, "nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
@ -58,13 +50,12 @@ int main(int argc, char* argv[]) {
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier); kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
model_opts.model_path = FLAGS_model_path;
int32 chunk_size = int32 chunk_size =
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
FLAGS_receptive_field_length; FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;

@ -20,27 +20,9 @@ DEFINE_int32(port, 8082, "websocket listening port");
ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource; ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale; resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::ModelOptions model_opts; resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
model_opts.subsample_rate = FLAGS_downsampling_rate;
resource.model_opts = model_opts;
ppspeech::TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
resource.tlg_opts = decoder_opts;
return resource; return resource;
} }

Loading…
Cancel
Save