[engine] add wfst recognizer in example (#3173)

* update wfst script

* add skip blank
YangZhou 2 years ago committed by GitHub
parent 5e2251afda
commit 8c2196ea0c

@@ -87,9 +87,9 @@ void CTCPrefixBeamSearch::AdvanceDecode(
VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
}
VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
VLOG(2) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
<< " sec.";
VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec.";
VLOG(2) << "AdvanceDecode search cost: " << search_cost << " sec.";
}
static bool PrefixScoreCompare(

@@ -71,7 +71,7 @@ int main(int argc, char* argv[]) {
std::shared_ptr<ppspeech::DataCache> raw_data =
std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nnet, raw_data);
std::make_shared<ppspeech::NnetProducer>(nnet, raw_data, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable =
std::make_shared<ppspeech::Decodable>(nnet_producer);

@@ -44,7 +44,7 @@ struct TLGDecoderOptions {
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table;
if (!decoder_opts.fst_path.empty()) {
CHECK(FileExists(decoder_opts.fst_path));

@@ -54,7 +54,7 @@ int main(int argc, char* argv[]) {
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr);
std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));

@@ -35,13 +35,11 @@ DEFINE_int32(subsampling_rate,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
#ifdef USE_ONNX
DEFINE_bool(with_onnx_model, false, "True means the model path is an ONNX model path");
#endif
//DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
@@ -50,10 +48,9 @@ DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_double(blank_threshold, 0.98, "blank skip threshold");
// DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight,
0.5,

@ -22,8 +22,9 @@ using kaldi::BaseFloat;
using std::vector;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend)
: nnet_(nnet), frontend_(frontend) {
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold)
: nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) {
Reset();
}
@@ -45,7 +46,6 @@ void NnetProducer::Acceptlikelihood(
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
VLOG(1) << "nnet cache_ size: " << cache_.size();
return flag;
}
@@ -70,7 +70,22 @@ bool NnetProducer::Compute() {
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob);
// process blank prob
float blank_prob = std::exp(logprob[0]);
if (blank_prob > blank_threshold_) {
last_frame_logprob_ = logprob;
is_last_frame_skip_ = true;
continue;
} else {
int cur_max = std::max_element(logprob.begin(), logprob.end()) - logprob.begin();
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
cache_.push_back(last_frame_logprob_);
last_max_elem_ = cur_max;
}
last_max_elem_ = cur_max;
is_last_frame_skip_ = false;
cache_.push_back(logprob);
}
}
return true;
}
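This new branch is the "add skip blank" change from the commit message: frames whose blank posterior `std::exp(logprob[0])` exceeds `blank_threshold_` are dropped instead of being pushed to `cache_`, and when the next kept frame peaks on the same non-blank token as the previously kept frame, the last skipped frame is re-inserted so CTC decoding still sees a blank between the two repeats rather than merging them into one token. A minimal standalone sketch of the same rule (hypothetical function name, assuming index 0 is the CTC blank and the input holds per-frame log-probabilities):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Sketch only; mirrors the blank-skip rule in NnetProducer::Compute(),
// names here are illustrative, not engine API.
std::vector<std::vector<float>> SkipBlankFrames(
    const std::vector<std::vector<float>>& logprobs, float blank_threshold) {
    std::vector<std::vector<float>> kept;
    std::vector<float> last_skipped;  // most recent frame dropped as blank
    bool last_was_skipped = false;
    int last_max = -1;  // argmax of the last emitted frame
    for (const auto& frame : logprobs) {
        if (std::exp(frame[0]) > blank_threshold) {  // blank-dominated frame
            last_skipped = frame;
            last_was_skipped = true;
            continue;
        }
        int cur_max = static_cast<int>(
            std::max_element(frame.begin(), frame.end()) - frame.begin());
        // Re-insert one skipped blank frame between two identical non-blank
        // peaks, so CTC decoding does not collapse them into a single token.
        if (cur_max == last_max && cur_max != 0 && last_was_skipped) {
            kept.push_back(last_skipped);
        }
        last_max = cur_max;
        last_was_skipped = false;
        kept.push_back(frame);
    }
    return kept;
}
```

The default `blank_threshold` of 0.98 (see the new flag and `RecognizerResource`) keeps the skipping conservative; lowering it skips more frames and mainly trades decoder work for a higher chance of dropping informative frames.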

@@ -24,7 +24,8 @@ namespace ppspeech {
class NnetProducer {
public:
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend = NULL);
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold);
// Feed feats or waves
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
@@ -64,6 +65,10 @@ class NnetProducer {
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::vector<BaseFloat> last_frame_logprob_;
bool is_last_frame_skip_ = false;
int last_max_elem_ = -1;
float blank_threshold_ = 0.0;
bool finished_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);

@@ -124,7 +124,15 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
offset_ = other.offset_;
// copy model ptr
model_ = other.model_->Clone();
// model_ = other.model_->Clone();
// hack, fix later
#ifdef WITH_GPU
dev_ = phi::GPUPlace();
#else
dev_ = phi::CPUPlace();
#endif
paddle::jit::Layer model = paddle::jit::Load(other.opts_.model_path, dev_);
model_ = std::make_shared<paddle::jit::Layer>(std::move(model));
ctc_activation_ = model_->Function("ctc_activation");
subsampling_rate_ = model_->Attribute<int>("subsampling_rate");
right_context_ = model_->Attribute<int>("right_context");
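As the new comment says, the copy constructor no longer calls `model_->Clone()`: it re-loads the `paddle::jit::Layer` from `opts_.model_path` onto the selected device (GPU under `WITH_GPU`, otherwise CPU). The diff marks this as a temporary workaround ("hack, fix later"); one consequence is that every copied `U2Nnet` pays a full model load from disk instead of an in-memory clone.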
@@ -166,6 +174,7 @@ void U2Nnet::Reset() {
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
encoder_outs_.clear();
VLOG(1) << "FeedForward cost: " << cost_time_ << " sec. ";
VLOG(3) << "u2nnet reset";
}
@@ -185,8 +194,10 @@ void U2Nnet::FeedForward(const std::vector<BaseFloat>& features,
std::vector<kaldi::BaseFloat> ctc_probs;
ForwardEncoderChunkImpl(
features, feature_dim, &out->logprobs, &out->vocab_dim);
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
float forward_chunk_time = timer.Elapsed();
VLOG(1) << "FeedForward cost: " << forward_chunk_time << " sec. "
<< features.size() / feature_dim << " frames.";
cost_time_ += forward_chunk_time;
}

@@ -113,8 +113,8 @@ class U2Nnet : public U2NnetBase {
void EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
ModelOptions opts_; // hack, fix later
private:
ModelOptions opts_;
phi::Place dev_;
std::shared_ptr<paddle::jit::Layer> model_{nullptr};
@@ -127,6 +127,7 @@ class U2Nnet : public U2NnetBase {
paddle::jit::Function forward_encoder_chunk_;
paddle::jit::Function forward_attention_decoder_;
paddle::jit::Function ctc_activation_;
float cost_time_ = 0.0;
};
} // namespace ppspeech

@@ -21,6 +21,7 @@ namespace ppspeech {
RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource)
: opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale;
BaseFloat blank_threshold = resource.blank_threshold;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
std::shared_ptr<FeaturePipeline> feature_pipeline(
new FeaturePipeline(feature_opts));
@@ -34,7 +35,7 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res
nnet = resource.nnet->Clone();
}
#endif
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline));
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold));
nnet_thread_ = std::thread(RunNnetEvaluation, this);
decodable_.reset(new Decodable(nnet_producer_, am_scale));

@@ -88,7 +88,8 @@ int main(int argc, char* argv[]) {
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
tot_attention_rescore_time += timer.Elapsed();
float rescore_time = timer.Elapsed();
tot_attention_rescore_time += rescore_time;
std::string result = recognizer_ptr->GetFinalResult();
if (result.empty()) {
@@ -101,7 +102,7 @@ int main(int argc, char* argv[]) {
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
<< " cost: " << local_timer.Elapsed() << " rescore:" << rescore_time;
result_writer.Write(utt, result);

@@ -12,6 +12,7 @@ DECLARE_double(reverse_weight);
DECLARE_int32(nbest);
DECLARE_int32(blank);
DECLARE_double(acoustic_scale);
DECLARE_double(blank_threshold);
DECLARE_string(word_symbol_table);
namespace ppspeech {
@@ -71,6 +72,7 @@ struct DecodeOptions {
struct RecognizerResource {
// decodable opt
kaldi::BaseFloat acoustic_scale{1.0};
kaldi::BaseFloat blank_threshold{0.98};
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
@@ -80,6 +82,7 @@ struct RecognizerResource {
static RecognizerResource InitFromFlags() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.blank_threshold = FLAGS_blank_threshold;
LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale;
resource.feature_pipeline_opts =

@@ -11,5 +11,5 @@ fsttablecompose
foreach(binary IN LISTS BINS)
add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc)
target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog libgflags_nothreads.so fst dl)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl)
endforeach()

@@ -4,7 +4,7 @@
## U2++ Attention Rescore
> Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz, support `avx512_vnni`
> Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz, support `avx512_vnni`
> RTF is measured with feature extraction and decoding included, so it is closer to an end-to-end number.
### FP32
@@ -23,18 +23,15 @@ Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
#### RTF
```
I1027 10:52:38.662868 51665 u2_recognizer_main.cc:122] total wav duration is: 36108.9 sec
I1027 10:52:38.662858 51665 u2_recognizer_main.cc:121] total cost:11169.1 sec
I1027 10:52:38.662876 51665 u2_recognizer_main.cc:123] RTF is: 0.309318
I1027 10:52:38.662868 51665 recognizer_main.cc:122] total wav duration is: 36108.9 sec
I1027 10:52:38.662858 51665 recognizer_main.cc:121] total cost:9577.31 sec
I1027 10:52:38.662876 51665 recognizer_main.cc:123] RTF is: 0.265234
```
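RTF here is total decode cost divided by total audio duration: 9577.31 / 36108.9 ≈ 0.265 for the renamed `recognizer_main`, versus 11169.1 / 36108.9 ≈ 0.309 before.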
### INT8
`local/recognizer_quant.sh`
> RTF improves by 12.8% relative, counting feature extraction and decoding time.
> Test under Paddle commit c331e2ce2031d68a553bc9469a07c30d718438f3
#### CER
```
@@ -52,16 +49,22 @@ I1110 09:59:52.551717 37249 u2_recognizer_main.cc:123] total decode cost:9737.63
I1110 09:59:52.551723 37249 u2_recognizer_main.cc:124] RTF is: 0.269674
```
### CTC Prefix Beam Search
### TLG decoder without attention rescore
`local/decode.sh`
`local/recognizer_wfst.sh`
#### CER
```
Overall -> 6.74 % N=104765 C=98106 S=6516 D=143 I=401
Mandarin -> 6.74 % N=104762 C=98106 S=6513 D=143 I=401
English -> 0.00 % N=0 C=0 S=0 D=0 I=0
Overall -> 4.73 % N=104765 C=100001 S=4283 D=481 I=187
Mandarin -> 4.72 % N=104762 C=100001 S=4280 D=481 I=187
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
```
#### RTF
```
I0417 08:07:15.300631 75784 recognizer_main.cc:113] total wav duration is: 36108.9 sec
I0417 08:07:15.300642 75784 recognizer_main.cc:114] total decode cost:10247.7 sec
I0417 08:07:15.300648 75784 recognizer_main.cc:115] total rescore cost:908.228 sec
I0417 08:07:15.300653 75784 recognizer_main.cc:116] RTF is: 0.283
```

@@ -0,0 +1,36 @@
#!/bin/bash
set -e
data=data
exp=exp
nj=20
. utils/parse_options.sh
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/onnx_model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.fd.log \
recognizer_main \
--use_fbank=true \
--num_bins=80 \
--model_path=$model_dir \
--word_symbol_table=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--with_onnx_model=true \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--result_wspecifier=ark,t:$data/split${nj}/JOB/recognizer.fd.rsl.ark
cat $data/split${nj}/*/recognizer.fd.rsl.ark > $exp/aishell.recognizer.fd.rsl
utils/compute-wer.py --char=1 --v=1 $text $exp/aishell.recognizer.fd.rsl > $exp/aishell.recognizer.fd.err
echo "recognizer fd test have finished!!!"
echo "please checkout in $exp/aishell.recognizer.fd.err"
tail -n 7 $exp/aishell.recognizer.fd.err

@@ -16,7 +16,7 @@ text=$data/test/text
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.quant.log \
u2_recognizer_main \
recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$model_dir/mean_std.json \

@@ -3,7 +3,7 @@ set -e
data=data
exp=exp
nj=40
nj=20
. utils/parse_options.sh
@@ -19,6 +19,15 @@ lang_dir=./data/lang_test/
graph=$lang_dir/TLG.fst
word_table=$lang_dir/words.txt
if [ ! -f $graph ]; then
# download the prebuilt TLG graph; if you want to build the graph yourself, please refer to local/run_build_tlg.sh
mkdir -p $lang_dir
pushd $lang_dir
wget -c https://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/tlg.zip
unzip tlg.zip
popd
fi
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer_wfst.log \
recognizer_main \
--use_fbank=true \
@@ -31,6 +40,8 @@ recognizer_main \
--receptive_field_length=7 \
--subsampling_rate=4 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--rescoring_weight=0.0 \
--acoustic_scale=2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer_wfst.ark
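The two new flags pin this run to plain TLG decoding: `--rescoring_weight=0.0` effectively disables attention rescoring (consistent with the README's "TLG decoder without attention rescore" results), and `--acoustic_scale=2` scales the acoustic log-likelihoods passed to the decoder via `FLAGS_acoustic_scale` and `Decodable`.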

@@ -0,0 +1,51 @@
#!/bin/bash
set -e
data=data
exp=exp
nj=20
. utils/parse_options.sh
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/onnx_model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
lang_dir=./data/lang_test/
graph=$lang_dir/TLG.fst
word_table=$lang_dir/words.txt
if [ ! -f $graph ]; then
# download the prebuilt TLG graph; if you want to build the graph yourself, please refer to local/run_build_tlg.sh
mkdir -p $lang_dir
pushd $lang_dir
wget -c https://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/tlg.zip
unzip tlg.zip
popd
fi
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer_wfst_fd.log \
recognizer_main \
--use_fbank=true \
--num_bins=80 \
--model_path=$model_dir \
--graph_path=$lang_dir/TLG.fst \
--word_symbol_table=$word_table \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--rescoring_weight=0.0 \
--acoustic_scale=2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer_wfst_fd.ark
cat $data/split${nj}/*/result_recognizer_wfst_fd.ark > $exp/aishell_recognizer_wfst_fd
utils/compute-wer.py --char=1 --v=1 $text $exp/aishell_recognizer_wfst_fd > $exp/aishell.recognizer_wfst_fd.err
echo "recognizer test have finished!!!"
echo "please checkout in $exp/aishell.recognizer_wfst_fd.err"
tail -n 7 $exp/aishell.recognizer_wfst_fd.err

@@ -7,13 +7,12 @@ set -eo pipefail
# different acoustic models have different vocabs
ckpt_dir=data/model/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model
unit=$ckpt_dir/vocab.txt # vocab file, line: char/spm_piece
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
stage=2
stop_stage=100
corpus=aishell
lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt
text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt filtered by data/train/text
. utils/parse_options.sh

@@ -12,7 +12,7 @@ TOOLS_BIN=$ENGINE_TOOLS/valgrind/install/bin
export LC_ALL=C
export PATH=$PATH:$TOOLS_BIN:$ENGINE_BUILD/nnet:$ENGINE_BUILD/decoder:$ENGINE_BUILD/../common/frontend/audio:$ENGINE_BUILD/recognizer
export PATH=$PATH:$TOOLS_BIN:$ENGINE_BUILD/nnet:$ENGINE_BUILD/decoder:$ENGINE_BUILD/../common/frontend/audio:$ENGINE_BUILD/recognizer:../../../fc_patch/openfst/bin:$ENGINE_BUILD/../kaldi/fstbin:$ENGINE_BUILD/../kaldi/lmbin
#PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH

@@ -69,23 +69,17 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# process compute fbank feat
./local/feat.sh
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# decode with fbank feat input
./local/decode.sh
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with wav input
./local/recognizer.sh
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# decode with wav input with quantized model
./local/recognizer_quant.sh
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with wfst
./local/recognizer_wfst.sh
fi
