[speechx] rm ds2 && rm boost (#2786)
* fix openfst download error * add acknowledgments of openfst * refactor directory * clean ctc_decoders dir * add nnet cache && make 2 thread work * do not compile websocket * rm ds2 && rm boost * rm ds2 examplepull/2854/head
parent
5046d8ee94
commit
acf1d27230
@ -1,7 +0,0 @@
|
||||
# Deepspeech2 Streaming ASR
|
||||
|
||||
## Examples
|
||||
|
||||
* `websocket` - Streaming ASR with websocket for deepspeech2_aishell.
|
||||
* `aishell` - Streaming Decoding under aishell dataset, for local WER test.
|
||||
* `onnx` - Example to convert deepspeech2 to onnx format.
|
@ -1,3 +0,0 @@
|
||||
data
|
||||
exp
|
||||
aishell_*
|
@ -1,133 +0,0 @@
|
||||
# Aishell - Deepspeech2 Streaming
|
||||
|
||||
> We recommend using U2/U2++ model instead of DS2, please see [here](../../u2pp_ol/wenetspeech/).
|
||||
|
||||
A C++ deployment example for using the deepspeech2 model to recognize `wav` and compute `CER`. We using AISHELL-1 as test data.
|
||||
|
||||
## Source path.sh
|
||||
|
||||
```bash
|
||||
. path.sh
|
||||
```
|
||||
|
||||
SpeechX bins is under `echo $SPEECHX_BUILD`, more info please see `path.sh`.
|
||||
|
||||
## Recognize with linear feature
|
||||
|
||||
```bash
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
`run.sh` has multi stage, for details please see `run.sh`:
|
||||
|
||||
1. donwload dataset, model and lm
|
||||
2. convert cmvn format and compute feature
|
||||
3. decode w/o lm by feature
|
||||
4. decode w/ ngram lm by feature
|
||||
5. decode w/ TLG graph by feature
|
||||
6. recognize w/ TLG graph by wav input
|
||||
|
||||
### Recognize with `.scp` file for wav
|
||||
|
||||
This sciprt using `recognizer_main` to recognize wav file.
|
||||
|
||||
The input is `scp` file which look like this:
|
||||
```text
|
||||
# head data/split1/1/aishell_test.scp
|
||||
BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
|
||||
BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
|
||||
...
|
||||
BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
|
||||
```
|
||||
|
||||
If you want to recognize one wav, you can make `scp` file like this:
|
||||
```text
|
||||
key path/to/wav/file
|
||||
```
|
||||
|
||||
Then specify `--wav_rspecifier=` param for `recognizer_main` bin. For other flags meaning, please see `help`:
|
||||
```bash
|
||||
recognizer_main --help
|
||||
```
|
||||
|
||||
For the exmaple to using `recognizer_main` please see `run.sh`.
|
||||
|
||||
|
||||
### CTC Prefix Beam Search w/o LM
|
||||
|
||||
```
|
||||
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
|
||||
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
|
||||
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
|
||||
```
|
||||
|
||||
### CTC Prefix Beam Search w/ LM
|
||||
|
||||
LM: zh_giga.no_cna_cmn.prune01244.klm
|
||||
```
|
||||
Overall -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
|
||||
Mandarin -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
|
||||
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
|
||||
```
|
||||
|
||||
### CTC TLG WFST
|
||||
|
||||
LM: [aishell train](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/aishell_graph.zip)
|
||||
--acoustic_scale=1.2
|
||||
```
|
||||
Overall -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1819
|
||||
Mandarin -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1818
|
||||
Other -> 0.00 % N=0 C=0 S=0 D=0 I=1
|
||||
```
|
||||
|
||||
LM: [wenetspeech](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/wenetspeech_graph.zip)
|
||||
--acoustic_scale=1.5
|
||||
```
|
||||
Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95
|
||||
Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95
|
||||
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
|
||||
```
|
||||
|
||||
## Recognize with fbank feature
|
||||
|
||||
This script is same to `run.sh`, but using fbank feature.
|
||||
|
||||
```bash
|
||||
bash run_fbank.sh
|
||||
```
|
||||
|
||||
### CTC Prefix Beam Search w/o LM
|
||||
|
||||
```
|
||||
Overall -> 10.44 % N=104765 C=94194 S=10174 D=397 I=369
|
||||
Mandarin -> 10.44 % N=104762 C=94194 S=10171 D=397 I=369
|
||||
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
|
||||
```
|
||||
|
||||
### CTC Prefix Beam Search w/ LM
|
||||
|
||||
LM: zh_giga.no_cna_cmn.prune01244.klm
|
||||
|
||||
```
|
||||
Overall -> 5.82 % N=104765 C=99386 S=4944 D=435 I=720
|
||||
Mandarin -> 5.82 % N=104762 C=99386 S=4941 D=435 I=720
|
||||
English -> 0.00 % N=0 C=0 S=0 D=0 I=0
|
||||
```
|
||||
|
||||
### CTC TLG WFST
|
||||
|
||||
LM: [aishell train](https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip)
|
||||
```
|
||||
Overall -> 9.58 % N=104765 C=94817 S=4326 D=5622 I=84
|
||||
Mandarin -> 9.57 % N=104762 C=94817 S=4325 D=5620 I=84
|
||||
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
|
||||
```
|
||||
|
||||
## Build TLG WFST graph
|
||||
|
||||
The script is for building TLG wfst graph, depending on `srilm`, please make sure it is installed.
|
||||
For more information please see the script below.
|
||||
|
||||
```bash
|
||||
bash ./local/run_build_tlg.sh
|
||||
```
|
@ -1,71 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# To be run from one directory above this script.
|
||||
. ./path.sh
|
||||
|
||||
nj=40
|
||||
text=data/local/lm/text
|
||||
lexicon=data/local/dict/lexicon.txt
|
||||
|
||||
for f in "$text" "$lexicon"; do
|
||||
[ ! -f $x ] && echo "$0: No such file $f" && exit 1;
|
||||
done
|
||||
|
||||
# Check SRILM tools
|
||||
if ! which ngram-count > /dev/null; then
|
||||
echo "srilm tools are not found, please download it and install it from: "
|
||||
echo "http://www.speech.sri.com/projects/srilm/download.html"
|
||||
echo "Then add the tools to your PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# This script takes no arguments. It assumes you have already run
|
||||
# aishell_data_prep.sh.
|
||||
# It takes as input the files
|
||||
# data/local/lm/text
|
||||
# data/local/dict/lexicon.txt
|
||||
dir=data/local/lm
|
||||
mkdir -p $dir
|
||||
|
||||
cleantext=$dir/text.no_oov
|
||||
|
||||
# oov to <SPOKEN_NOISE>
|
||||
# lexicon line: word char0 ... charn
|
||||
# text line: utt word0 ... wordn -> line: <SPOKEN_NOISE> word0 ... wordn
|
||||
text_dir=$(dirname $text)
|
||||
split_name=$(basename $text)
|
||||
./local/split_data.sh $text_dir $text $split_name $nj
|
||||
|
||||
utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
|
||||
cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
|
||||
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
|
||||
\> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
|
||||
cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext
|
||||
|
||||
# compute word counts, sort in descending order
|
||||
# line: count word
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
|
||||
sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;
|
||||
|
||||
# Get counts from acoustic training transcripts, and add one-count
|
||||
# for each word in the lexicon (but not silence, we don't want it
|
||||
# in the LM-- we'll add it optionally later).
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
|
||||
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
|
||||
sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;
|
||||
|
||||
# word with <s> </s>
|
||||
cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
|
||||
|
||||
# hold out to compute ppl
|
||||
heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results
|
||||
|
||||
mkdir -p $dir
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
|
||||
head -$heldout_sent > $dir/heldout
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
|
||||
tail -n +$heldout_sent > $dir/train
|
||||
|
||||
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
|
||||
-map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
|
||||
ngram -lm $dir/lm.arpa -ppl $dir/heldout
|
@ -1,145 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -eo pipefail
|
||||
|
||||
. path.sh
|
||||
|
||||
# attention, please replace the vocab is only for this script.
|
||||
# different acustic model has different vocab
|
||||
ckpt_dir=data/fbank_model
|
||||
unit=$ckpt_dir/data/lang_char/vocab.txt # vocab file, line: char/spm_pice
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
corpus=aishell
|
||||
lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt
|
||||
text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
data=$PWD/data
|
||||
mkdir -p $data
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
if [ ! -f $data/speech.ngram.zh.tar.gz ];then
|
||||
# download ngram
|
||||
pushd $data
|
||||
wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz
|
||||
tar xvzf speech.ngram.zh.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
|
||||
# download model
|
||||
mkdir -p $ckpt_dir
|
||||
pushd $ckpt_dir
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
|
||||
tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -f $unit ]; then
|
||||
echo "$0: No such file $unit"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if ! which ngram-count; then
|
||||
# need srilm install
|
||||
pushd $MAIN_ROOT/tools
|
||||
make srilm.done
|
||||
popd
|
||||
fi
|
||||
|
||||
mkdir -p data/local/dict
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
# Prepare dict
|
||||
# line: char/spm_pices
|
||||
cp $unit data/local/dict/units.txt
|
||||
|
||||
if [ ! -f $lexicon ];then
|
||||
utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon
|
||||
echo "Generate $lexicon from $text"
|
||||
fi
|
||||
|
||||
# filter by vocab
|
||||
# line: word ph0 ... phn -> line: word char0 ... charn
|
||||
utils/fst/prepare_dict.py \
|
||||
--unit_file $unit \
|
||||
--in_lexicon ${lexicon} \
|
||||
--out_lexicon data/local/dict/lexicon.txt
|
||||
fi
|
||||
|
||||
lm=data/local/lm
|
||||
mkdir -p $lm
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# Train ngram lm
|
||||
cp $text $lm/text
|
||||
local/aishell_train_lms.sh
|
||||
echo "build LM done."
|
||||
fi
|
||||
|
||||
# build TLG
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# build T & L
|
||||
utils/fst/compile_lexicon_token_fst.sh \
|
||||
data/local/dict data/local/tmp data/local/lang
|
||||
|
||||
# build G & TLG
|
||||
utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
|
||||
|
||||
fi
|
||||
|
||||
aishell_wav_scp=aishell_test.scp
|
||||
nj=40
|
||||
cmvn=$data/cmvn_fbank.ark
|
||||
wfst=$data/lang_test
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
if [ ! -d $data/test ]; then
|
||||
# download test dataset
|
||||
pushd $data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
|
||||
unzip aishell_test.zip
|
||||
popd
|
||||
|
||||
realpath $data/test/*/*.wav > $data/wavlist
|
||||
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
|
||||
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
|
||||
fi
|
||||
|
||||
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
|
||||
|
||||
# convert cmvn format
|
||||
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
|
||||
fi
|
||||
|
||||
wer=aishell_wer
|
||||
label_file=aishell_result
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# recognize w/ TLG graph
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/check_tlg.log \
|
||||
recognizer_main \
|
||||
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
|
||||
--cmvn_file=$cmvn \
|
||||
--model_path=$model_dir/avg_5.jit.pdmodel \
|
||||
--streaming_chunk=30 \
|
||||
--use_fbank=true \
|
||||
--param_path=$model_dir/avg_5.jit.pdiparams \
|
||||
--word_symbol_table=$wfst/words.txt \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--model_cache_shapes="5-1-2048,5-1-2048" \
|
||||
--graph_path=$wfst/TLG.fst --max_active=7500 \
|
||||
--acoustic_scale=1.2 \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_check_tlg
|
||||
|
||||
cat $data/split${nj}/*/result_check_tlg > $exp/${label_file}_check_tlg
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_check_tlg > $exp/${wer}.check_tlg
|
||||
echo "recognizer test have finished!!!"
|
||||
echo "please checkout in ${exp}/${wer}.check_tlg"
|
||||
fi
|
||||
|
||||
exit 0
|
@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eo pipefail
|
||||
|
||||
data=$1
|
||||
scp=$2
|
||||
split_name=$3
|
||||
numsplit=$4
|
||||
|
||||
# save in $data/split{n}
|
||||
# $scp to split
|
||||
#
|
||||
|
||||
if [[ ! $numsplit -gt 0 ]]; then
|
||||
echo "Invalid num-split argument";
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
|
||||
scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
|
||||
|
||||
# if this mkdir fails due to argument-list being too long, iterate.
|
||||
if ! mkdir -p $directories >&/dev/null; then
|
||||
for n in `seq $numsplit`; do
|
||||
mkdir -p $data/split${numsplit}/$n
|
||||
done
|
||||
fi
|
||||
|
||||
echo "utils/split_scp.pl $scp $scp_splits"
|
||||
utils/split_scp.pl $scp $scp_splits
|
@ -1,24 +0,0 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
MAIN_ROOT=`realpath $PWD/../../../../`
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
# openfst bin & kaldi bin
|
||||
KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
|
||||
OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src
|
||||
|
||||
# srilm
|
||||
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
|
||||
export SRILM=${MAIN_ROOT}/tools/srilm
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin
|
@ -1,180 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
nj=40
|
||||
stage=0
|
||||
stop_stage=100
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_BUILD} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# input
|
||||
mkdir -p data
|
||||
data=$PWD/data
|
||||
|
||||
ckpt_dir=$data/model
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
vocb_dir=$ckpt_dir/data/lang_char/
|
||||
|
||||
# output
|
||||
mkdir -p exp
|
||||
exp=$PWD/exp
|
||||
|
||||
aishell_wav_scp=aishell_test.scp
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
|
||||
if [ ! -d $data/test ]; then
|
||||
# donwload dataset
|
||||
pushd $data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
|
||||
unzip aishell_test.zip
|
||||
popd
|
||||
|
||||
realpath $data/test/*/*.wav > $data/wavlist
|
||||
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
|
||||
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
|
||||
fi
|
||||
|
||||
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
|
||||
# download model
|
||||
mkdir -p $ckpt_dir
|
||||
pushd $ckpt_dir
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
|
||||
if [ ! -f $lm ]; then
|
||||
# download kenlm bin
|
||||
pushd $data
|
||||
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
# 3. make feature
|
||||
text=$data/test/text
|
||||
label_file=./aishell_result
|
||||
wer=./aishell_wer
|
||||
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
|
||||
cmvn=$data/cmvn.ark
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# 3. convert cmvn format and compute linear feat
|
||||
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
|
||||
|
||||
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
|
||||
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
|
||||
compute_linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
|
||||
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
|
||||
--cmvn_file=$cmvn \
|
||||
echo "feature make have finished!!!"
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# decode w/o lm
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
|
||||
ctc_beam_search_decoder_main \
|
||||
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--dict_file=$vocb_dir/vocab.txt \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result
|
||||
|
||||
cat $data/split${nj}/*/result > $exp/${label_file}
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
|
||||
echo "ctc-prefix-beam-search-decoder-ol without lm has finished!!!"
|
||||
echo "please checkout in ${exp}/${wer}"
|
||||
tail -n 7 $exp/${wer}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# decode w/ ngram lm with feature input
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
|
||||
ctc_beam_search_decoder_main \
|
||||
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--dict_file=$vocb_dir/vocab.txt \
|
||||
--lm_path=$lm \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
|
||||
|
||||
cat $data/split${nj}/*/result_lm > $exp/${label_file}_lm
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm
|
||||
echo "ctc-prefix-beam-search-decoder-ol with lm test has finished!!!"
|
||||
echo "please checkout in ${exp}/${wer}.lm"
|
||||
tail -n 7 $exp/${wer}.lm
|
||||
fi
|
||||
|
||||
wfst=$data/wfst/
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
mkdir -p $wfst
|
||||
if [ ! -f $wfst/aishell_graph.zip ]; then
|
||||
# download TLG graph
|
||||
pushd $wfst
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
|
||||
unzip aishell_graph.zip
|
||||
mv aishell_graph/* $wfst
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# decoder w/ TLG graph with feature input
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
|
||||
ctc_tlg_decoder_main \
|
||||
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams \
|
||||
--word_symbol_table=$wfst/words.txt \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--graph_path=$wfst/TLG.fst --max_active=7500 \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--acoustic_scale=1.2 \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
|
||||
|
||||
cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg
|
||||
echo "wfst-decoder-ol have finished!!!"
|
||||
echo "please checkout in ${exp}/${wer}.tlg"
|
||||
tail -n 7 $exp/${wer}.tlg
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
# recognize from wav file w/ TLG graph
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
|
||||
recognizer_main \
|
||||
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
|
||||
--cmvn_file=$cmvn \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams \
|
||||
--word_symbol_table=$wfst/words.txt \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--graph_path=$wfst/TLG.fst --max_active=7500 \
|
||||
--acoustic_scale=1.2 \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer
|
||||
|
||||
cat $data/split${nj}/*/result_recognizer > $exp/${label_file}_recognizer
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer
|
||||
echo "recognizer test have finished!!!"
|
||||
echo "please checkout in ${exp}/${wer}.recognizer"
|
||||
tail -n 7 $exp/${wer}.recognizer
|
||||
fi
|
@ -1,177 +0,0 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
nj=40
|
||||
stage=0
|
||||
stop_stage=5
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# input
|
||||
mkdir -p data
|
||||
data=$PWD/data
|
||||
|
||||
ckpt_dir=$data/fbank_model
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
vocb_dir=$ckpt_dir/data/lang_char/
|
||||
|
||||
# output
|
||||
mkdir -p exp
|
||||
exp=$PWD/exp
|
||||
|
||||
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
|
||||
aishell_wav_scp=aishell_test.scp
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
|
||||
if [ ! -d $data/test ]; then
|
||||
pushd $data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
|
||||
unzip aishell_test.zip
|
||||
popd
|
||||
|
||||
realpath $data/test/*/*.wav > $data/wavlist
|
||||
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
|
||||
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
|
||||
fi
|
||||
|
||||
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
|
||||
mkdir -p $ckpt_dir
|
||||
pushd $ckpt_dir
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
|
||||
tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
if [ ! -f $lm ]; then
|
||||
pushd $data
|
||||
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
# 3. make feature
|
||||
text=$data/test/text
|
||||
label_file=./aishell_result_fbank
|
||||
wer=./aishell_wer_fbank
|
||||
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
|
||||
cmvn=$data/cmvn_fbank.ark
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# 3. convert cmvn format and compute fbank feat
|
||||
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false
|
||||
|
||||
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
|
||||
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
|
||||
compute_fbank_main \
|
||||
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
|
||||
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank_feat.ark,$data/split${nj}/JOB/fbank_feat.scp \
|
||||
--cmvn_file=$cmvn \
|
||||
--streaming_chunk=36
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# decode w/ lm by feature
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \
|
||||
ctc_beam_search_decoder_main \
|
||||
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
|
||||
--model_path=$model_dir/avg_5.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_5.jit.pdiparams \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--model_cache_shapes="5-1-2048,5-1-2048" \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--dict_file=$vocb_dir/vocab.txt \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank
|
||||
|
||||
cat $data/split${nj}/*/result_fbank > $exp/${label_file}
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
|
||||
tail -n 7 $exp/${wer}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# decode with ngram lm by feature
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \
|
||||
ctc_beam_search_decoder_main \
|
||||
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
|
||||
--model_path=$model_dir/avg_5.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_5.jit.pdiparams \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--model_cache_shapes="5-1-2048,5-1-2048" \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--dict_file=$vocb_dir/vocab.txt \
|
||||
--lm_path=$lm \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm
|
||||
|
||||
cat $data/split${nj}/*/fbank_result_lm > $exp/${label_file}_lm
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm
|
||||
tail -n 7 $exp/${wer}.lm
|
||||
fi
|
||||
|
||||
wfst=$data/wfst_fbank/
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
mkdir -p $wfst
|
||||
if [ ! -f $wfst/aishell_graph2.zip ]; then
|
||||
pushd $wfst
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip
|
||||
unzip aishell_graph2.zip
|
||||
mv aishell_graph2/* $wfst
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# decode w/ TLG graph by feature
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
|
||||
ctc_tlg_decoder_main \
|
||||
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
|
||||
--model_path=$model_dir/avg_5.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_5.jit.pdiparams \
|
||||
--word_symbol_table=$wfst/words.txt \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--model_cache_shapes="5-1-2048,5-1-2048" \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--graph_path=$wfst/TLG.fst --max_active=7500 \
|
||||
--acoustic_scale=1.2 \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
|
||||
|
||||
cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg
|
||||
echo "wfst-decoder-ol have finished!!!"
|
||||
echo "please checkout in ${exp}/${wer}.tlg"
|
||||
tail -n 7 $exp/${wer}.tlg
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
# recgonize w/ TLG graph by wav
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \
|
||||
recognizer_main \
|
||||
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
|
||||
--cmvn_file=$cmvn \
|
||||
--model_path=$model_dir/avg_5.jit.pdmodel \
|
||||
--use_fbank=true \
|
||||
--param_path=$model_dir/avg_5.jit.pdiparams \
|
||||
--word_symbol_table=$wfst/words.txt \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--model_cache_shapes="5-1-2048,5-1-2048" \
|
||||
--nnet_decoder_chunk=8 \
|
||||
--graph_path=$wfst/TLG.fst --max_active=7500 \
|
||||
--acoustic_scale=1.2 \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank_recognizer
|
||||
|
||||
cat $data/split${nj}/*/result_fbank_recognizer > $exp/${label_file}_recognizer
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer
|
||||
echo "recognizer test have finished!!!"
|
||||
echo "please checkout in ${exp}/${wer}.recognizer"
|
||||
tail -n 7 $exp/${wer}.recognizer
|
||||
fi
|
@ -1 +0,0 @@
|
||||
../../../../utils/
|
@ -1,3 +0,0 @@
|
||||
data
|
||||
log
|
||||
exp
|
@ -1,100 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import paddle
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
'--input_file',
|
||||
type=str,
|
||||
default="static_ds2online_inputs.pickle",
|
||||
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model_type',
|
||||
type=str,
|
||||
default="aishell",
|
||||
help="aishell(1024) or wenetspeech(2048)", )
|
||||
parser.add_argument(
|
||||
'--model_dir', type=str, default=".", help="paddle model dir.")
|
||||
parser.add_argument(
|
||||
'--model_prefix',
|
||||
type=str,
|
||||
default="avg_1.jit",
|
||||
help="paddle model prefix.")
|
||||
parser.add_argument(
|
||||
'--onnx_model',
|
||||
type=str,
|
||||
default='./model.old.onnx',
|
||||
help="onnx model.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
FLAGS = parse_args()
|
||||
|
||||
# input and output
|
||||
with open(FLAGS.input_file, 'rb') as f:
|
||||
iodict = pickle.load(f)
|
||||
print(iodict.keys())
|
||||
|
||||
audio_chunk = iodict['audio_chunk']
|
||||
audio_chunk_lens = iodict['audio_chunk_lens']
|
||||
chunk_state_h_box = iodict['chunk_state_h_box']
|
||||
chunk_state_c_box = iodict['chunk_state_c_bos']
|
||||
print("raw state shape: ", chunk_state_c_box.shape)
|
||||
|
||||
if FLAGS.model_type == 'wenetspeech':
|
||||
chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
|
||||
chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
|
||||
print("state shape: ", chunk_state_c_box.shape)
|
||||
|
||||
# paddle
|
||||
model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
|
||||
res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
|
||||
paddle.to_tensor(audio_chunk),
|
||||
paddle.to_tensor(audio_chunk_lens),
|
||||
paddle.to_tensor(chunk_state_h_box),
|
||||
paddle.to_tensor(chunk_state_c_box), )
|
||||
|
||||
# onnxruntime
|
||||
options = onnxruntime.SessionOptions()
|
||||
options.enable_profiling = True
|
||||
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
|
||||
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
|
||||
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
|
||||
"audio_chunk": audio_chunk,
|
||||
"audio_chunk_lens": audio_chunk_lens,
|
||||
"chunk_state_h_box": chunk_state_h_box,
|
||||
"chunk_state_c_box": chunk_state_c_box
|
||||
})
|
||||
|
||||
print(sess.end_profiling())
|
||||
|
||||
# assert paddle equal ort
|
||||
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
|
||||
print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
|
||||
|
||||
if FLAGS.model_type == 'aishell':
|
||||
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
|
||||
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
|
@ -1,14 +0,0 @@
|
||||
#!/bin/bash

# show model
# Serve an interactive visualization of a model file (ONNX/Paddle) with
# netron on port 8082, bound to this host's primary address so it is
# reachable from outside a container.
# Usage: ./show_model.sh <model_path>

if [ $# != 1 ];then
    echo "usage: $0 model_path"
    exit 1
fi

file=$1

pip install netron
netron -p 8082 --host $(hostname -i) $file
|
@ -1,7 +0,0 @@
|
||||
|
||||
#!/bin/bash

# clone onnx repos
# Fetch the upstream source trees consulted while working on the ONNX
# export pipeline: onnx itself, the onnxruntime engine, and Paddle2ONNX.
git clone https://github.com/onnx/onnx.git
git clone https://github.com/microsoft/onnxruntime.git
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
|
@ -1,37 +0,0 @@
|
||||
#!/usr/bin/env python3
"""Convert an ONNX model to a different (usually newer) opset.

Models must be opset 10 or higher to be quantized by onnxruntime, so this
tool runs before dynamic quantization in the ds2 ONNX pipeline.
"""
import argparse

import onnx
from onnx import version_converter

if __name__ == '__main__':
    # BUG FIX: was ArgumentParser(prog=__doc__); prog is the program name,
    # description is the correct slot for the module doc text.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model-file", type=str, required=True, help='path/to/the/model.onnx.')
    parser.add_argument(
        "--save-model",
        type=str,
        required=True,
        help='path/to/saved/model.onnx.')
    # Models must be opset10 or higher to be quantized.
    # BUG FIX: help text was a copy-pasted model-path description.
    parser.add_argument(
        "--target-opset",
        type=int,
        default=11,
        help='opset version to convert the model to (default: 11).')

    args = parser.parse_args()

    print(f"to opset: {args.target_opset}")

    # Preprocessing: load the model to be converted.
    model_path = args.model_file
    original_model = onnx.load(model_path)

    # print('The model before conversion:\n{}'.format(original_model))

    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
    # Apply the version conversion on the original model
    converted_model = version_converter.convert_version(original_model,
                                                        args.target_opset)

    # print('The model after conversion:\n{}'.format(converted_model))
    onnx.save(converted_model, args.save_model)
|
File diff suppressed because it is too large
Load Diff
@ -1,20 +0,0 @@
|
||||
#!/bin/bash

# Simplify/optimize an ONNX model with onnx-simplifier (onnxsim).
# Usage: ./local/onnx_opt.sh <model.in.onnx> <model.out.onnx> "<input_shape>"

set -e

if [ $# != 3 ];then
    # ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
    echo "usage: $0 onnx.model.in onnx.model.out input_shape "
    exit 1
fi

# onnx optimizer
pip install onnx-simplifier

in=$1
out=$2
input_shape=$3

# number of random-input check rounds onnxsim runs to verify equivalence
check_n=3

onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape
|
@ -1,128 +0,0 @@
|
||||
#!/usr/bin/env python3 -W ignore::DeprecationWarning
|
||||
# prune model by output names
|
||||
import argparse
|
||||
import copy
|
||||
import sys
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def parse_arguments():
    """Define and parse the command line of the model-pruning tool."""
    cli = argparse.ArgumentParser()
    spec = (
        ('--model', dict(
            required=True, help='Path of directory saved the input model.')),
        ('--output_names', dict(
            required=True, nargs='+', help='The outputs of pruned model.')),
        ('--save_file', dict(
            required=True, help='Path to save the new onnx model.')),
    )
    for flag, opts in spec:
        cli.add_argument(flag, **opts)
    return cli.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Prune an ONNX graph down to the subgraph needed to produce the
    # user-requested --output_names, then re-point the graph outputs at them.
    args = parse_arguments()

    if len(set(args.output_names)) < len(args.output_names):
        print(
            "[ERROR] There's dumplicate name in --output_names, which is not allowed."
        )
        sys.exit(-1)

    model = onnx.load(args.model)

    # collect all node outputs and graph output
    output_tensor_names = set()
    for node in model.graph.node:
        for out in node.output:
            # may contain model output
            output_tensor_names.add(out)

    # for out in model.graph.output:
    #     output_tensor_names.add(out.name)

    # every requested output must be produced by some node
    for output_name in args.output_names:
        if output_name not in output_tensor_names:
            print(
                "[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
                format(output_name))
            sys.exit(-1)

    output_node_indices = set()  # has output names
    output_to_node = dict()  # all node outputs
    for i, node in enumerate(model.graph.node):
        for out in node.output:
            output_to_node[out] = i
            if out in args.output_names:
                output_node_indices.add(i)

    # from outputs find all the ancestors
    # Breadth-first walk backwards from the output-producing nodes; any input
    # tensor not produced by a node is treated as a graph input to keep.
    reserved_node_indices = copy.deepcopy(
        output_node_indices)  # nodes need to keep
    reserved_inputs = set()  # model input to keep
    new_output_node_indices = copy.deepcopy(output_node_indices)

    while True and len(new_output_node_indices) > 0:
        output_node_indices = copy.deepcopy(new_output_node_indices)

        new_output_node_indices = set()

        for out_node_idx in output_node_indices:
            # backtrace to parenet
            for ipt in model.graph.node[out_node_idx].input:
                if ipt in output_to_node:
                    reserved_node_indices.add(output_to_node[ipt])
                    new_output_node_indices.add(output_to_node[ipt])
                else:
                    reserved_inputs.add(ipt)

    num_inputs = len(model.graph.input)
    num_outputs = len(model.graph.output)
    num_nodes = len(model.graph.node)
    print(
        f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes"
    )
    print(f"{len(reserved_node_indices)} node to keep.")

    # del node not to keep
    # Iterate backwards so deletions don't shift the indices still to visit.
    for idx in range(num_nodes - 1, -1, -1):
        if idx not in reserved_node_indices:
            del model.graph.node[idx]

    # del graph input not to keep
    for idx in range(num_inputs - 1, -1, -1):
        if model.graph.input[idx].name not in reserved_inputs:
            del model.graph.input[idx]

    # del old graph outputs
    for i in range(num_outputs):
        del model.graph.output[0]

    # new graph output as user input
    for out in args.output_names:
        model.graph.output.extend([onnx.ValueInfoProto(name=out)])

    # infer shape
    # Best-effort: onnx_infer_shape is an optional local helper; any failure
    # is logged and the un-annotated model is kept.
    try:
        from onnx_infer_shape import SymbolicShapeInference
        model = SymbolicShapeInference.infer_shapes(
            model,
            int_max=2**31 - 1,
            auto_merge=True,
            guess_output_rank=False,
            verbose=1)
    except Exception as e:
        print(f"skip infer shape step: {e}")

    # check onnx model
    onnx.checker.check_model(model)
    # save onnx model
    onnx.save(model, args.save_file)
    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
|
@ -1,111 +0,0 @@
|
||||
#!/usr/bin/env python3 -W ignore::DeprecationWarning
|
||||
# rename node to new names
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def parse_arguments():
    """Define and parse the command line of the tensor-renaming tool."""
    cli = argparse.ArgumentParser()
    spec = (
        ('--model', dict(
            required=True, help='Path of directory saved the input model.')),
        ('--origin_names', dict(
            required=True,
            nargs='+',
            help='The original name you want to modify.')),
        ('--new_names', dict(
            required=True,
            nargs='+',
            help='The new name you want change to, the number of new_names should be same with the number of origin_names'
        )),
        ('--save_file', dict(
            required=True, help='Path to save the new onnx model.')),
    )
    for flag, opts in spec:
        cli.add_argument(flag, **opts)
    return cli.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Rename tensors of an ONNX model: graph inputs, node inputs/outputs and
    # graph outputs listed in --origin_names become the matching --new_names.
    args = parse_arguments()

    if len(set(args.origin_names)) < len(args.origin_names):
        print(
            "[ERROR] There's dumplicate name in --origin_names, which is not allowed."
        )
        sys.exit(-1)

    if len(set(args.new_names)) < len(args.new_names):
        print(
            "[ERROR] There's dumplicate name in --new_names, which is not allowed."
        )
        sys.exit(-1)

    if len(args.new_names) != len(args.origin_names):
        print(
            "[ERROR] Number of --new_names must be same with the number of --origin_names."
        )
        sys.exit(-1)

    model = onnx.load(args.model)

    # collect input and all node output
    output_tensor_names = set()
    for ipt in model.graph.input:
        output_tensor_names.add(ipt.name)

    for node in model.graph.node:
        for out in node.output:
            output_tensor_names.add(out)

    # every origin name must already exist somewhere in the graph
    for origin_name in args.origin_names:
        if origin_name not in output_tensor_names:
            print(
                f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
            )
            sys.exit(-1)

    # a new name must not collide with an existing tensor name
    for new_name in args.new_names:
        if new_name in output_tensor_names:
            # BUG FIX: .format(new_name) was missing, so the message printed
            # a literal '{}' instead of the offending name.
            print(
                "[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed.".
                format(new_name))
            sys.exit(-1)

    # rename graph input
    for i, ipt in enumerate(model.graph.input):
        if ipt.name in args.origin_names:
            idx = args.origin_names.index(ipt.name)
            model.graph.input[i].name = args.new_names[idx]

    # rename node input and output
    for i, node in enumerate(model.graph.node):
        for j, ipt in enumerate(node.input):
            if ipt in args.origin_names:
                idx = args.origin_names.index(ipt)
                model.graph.node[i].input[j] = args.new_names[idx]

        for j, out in enumerate(node.output):
            if out in args.origin_names:
                idx = args.origin_names.index(out)
                model.graph.node[i].output[j] = args.new_names[idx]

    # rename graph output
    for i, out in enumerate(model.graph.output):
        if out.name in args.origin_names:
            idx = args.origin_names.index(out.name)
            model.graph.output[i].name = args.new_names[idx]

    # check onnx model
    onnx.checker.check_model(model)

    # save model
    onnx.save(model, args.save_file)

    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
|
@ -1,48 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
|
||||
from onnxruntime.quantization import quantize_dynamic
|
||||
from onnxruntime.quantization import QuantType
|
||||
|
||||
|
||||
def quantize_onnx_model(onnx_model_path,
                        quantized_model_path,
                        nodes_to_exclude=[]):
    """Dynamically quantize an ONNX model's weights to int8.

    Args:
        onnx_model_path: path of the float model to quantize.
        quantized_model_path: path where the quantized model is written.
        nodes_to_exclude: node names to leave unquantized.
    """
    print("Starting quantization...")

    quant_kwargs = {
        'weight_type': QuantType.QInt8,
        'nodes_to_exclude': nodes_to_exclude,
    }
    quantize_dynamic(onnx_model_path, quantized_model_path, **quant_kwargs)

    print(f"Quantized model saved to: {quantized_model_path}")
|
||||
|
||||
|
||||
def main():
    # Parse CLI options and run dynamic int8 quantization on the model.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-in",
        type=str,
        required=True,
        help="ONNX model", )
    parser.add_argument(
        "--model-out",
        type=str,
        required=True,
        # NOTE(review): this default is dead -- the argument is required,
        # so the default can never apply.
        default='model.quant.onnx',
        help="ONNX model", )
    parser.add_argument(
        "--nodes-to-exclude",
        type=str,
        required=True,
        help="nodes to exclude. e.g. conv,linear.", )

    args = parser.parse_args()

    # comma-separated node names are excluded from quantization
    nodes_to_exclude = args.nodes_to_exclude.split(',')
    quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)


if __name__ == "__main__":
    main()
|
@ -1,45 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
# onnxruntime optimizer.
|
||||
# https://onnxruntime.ai/docs/performance/graph-optimizations.html
|
||||
# https://onnxruntime.ai/docs/api/python/api_summary.html#api
|
||||
|
||||
|
||||
def parse_arguments():
    """Parse CLI options for the onnxruntime graph-optimization tool.

    Returns:
        argparse.Namespace with model_in, opt_level, model_out, debug.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_in', required=True, type=str, help='Path to onnx model.')
    # BUG FIX: help text was a copy-pasted 'Path to onnx model.'
    parser.add_argument(
        '--opt_level',
        required=True,
        type=int,
        default=0,
        choices=[0, 1, 2],
        help='graph optimization level: 0=basic, 1=extended, 2=all.')
    parser.add_argument(
        '--model_out', required=True, help='path to save the optimized model.')
    # NOTE(review): any non-empty string makes --debug truthy (even
    # '--debug False'); switch to action='store_true' if callers only ever
    # use it as a flag -- confirm before changing the CLI.
    parser.add_argument('--debug', default=False, help='output debug info.')
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = parse_arguments()

    sess_options = ort.SessionOptions()

    # Set graph optimization level
    # 0 -> ORT_ENABLE_BASIC, 1 -> ORT_ENABLE_EXTENDED, anything else -> ALL
    print(f"opt level: {args.opt_level}")
    if args.opt_level == 0:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    elif args.opt_level == 1:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    else:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # To enable model serialization after graph optimization set this
    sess_options.optimized_model_filepath = args.model_out

    # Constructing the session runs the optimization pass and, because
    # optimized_model_filepath is set, writes the optimized graph to disk.
    session = ort.InferenceSession(args.model_in, sess_options)
|
@ -1,26 +0,0 @@
|
||||
#!/bin/bash

# Export a Paddle inference model (pdmodel + pdiparams) to ONNX.
# Usage: ./local/tonnx.sh <model_dir> <model_name> <param_name> <onnx_output>

if [ $# != 4 ];then
    # local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
    echo "usage: $0 model_dir model_name param_name onnx_output_name"
    exit 1
fi

dir=$1
model=$2
param=$3
output=$4

pip install paddle2onnx
pip install onnx

# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
# opset10 support quantize
paddle2onnx --model_dir $dir \
            --model_filename $model \
            --params_filename $param \
            --save_file $output \
            --enable_dev_version True \
            --opset_version 11 \
            --enable_onnx_checker True
|
||||
|
@ -1,14 +0,0 @@
|
||||
# This contains the locations of binarys build required for running the examples.

MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }

# BUG FIX: was 'export LC_AL=C' -- LC_AL is not a recognized locale
# variable; LC_ALL is the one that forces the C locale for all categories.
export LC_ALL=C

export PATH=$PATH:$TOOLS_BIN
|
@ -1,91 +0,0 @@
|
||||
#!/bin/bash

# End-to-end DS2 -> ONNX pipeline: download checkpoint, export to onnx,
# graph-optimize, opset-convert + dynamic-quantize, and (optionally)
# simplify -- validating each artifact against the Paddle model.

set -e

. path.sh

stage=0
stop_stage=50
tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_10.jit
#model_prefix=avg_1.jit
model=${model_prefix}.pdmodel
param=${model_prefix}.pdiparams

. utils/parse_options.sh

data=data
exp=exp

mkdir -p $data $exp

dir=$data/exp/deepspeech2_online/checkpoints

# wenetspeech or aishell
# model type is parsed from the 4th '_'-separated field of the tarball name
model_type=$(echo $tarfile | cut -d '_' -f 4)

# stage 0: download checkpoint tarball and the recorded demo inputs
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
    test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile

    # wenetspeech ds2 model
    pushd $data
    tar zxvf $tarfile
    popd

    # ds2 model demo inputs
    pushd $exp
    wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
    popd
fi

input_file=$exp/static_ds2online_inputs.pickle
test -e $input_file

# stage 1: export paddle model to onnx and check outputs match
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
    # to onnx
    ./local/tonnx.sh $dir $model $param $exp/model.onnx

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx
fi

# stage 2: onnxruntime graph optimization + check
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ] ;then
    # ort graph optmize
    ./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx
fi

# stage 3: convert to opset 11, dynamic-quantize (keeping the first convs
# full precision), and check
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
    # convert opset_num to 11
    ./local/onnx_convert_opset.py --target-opset 11 --model-file $exp/model.ort.opt.onnx --save-model $exp/model.optset11.onnx

    # quant model
    nodes_to_exclude='p2o.Conv.0,p2o.Conv.2'
    ./local/ort_dyanmic_quant.py --model-in $exp/model.optset11.onnx --model-out $exp/model.optset11.quant.onnx --nodes-to-exclude "${nodes_to_exclude}"

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.optset11.quant.onnx
fi

# aishell rnn hidden is 1024
# wenetspeech rnn hiddn is 2048
if [ $model_type == 'aishell' ];then
    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
elif [ $model_type == 'wenetspeech' ];then
    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048"
else
    echo "not support: $model_type"
    exit -1
fi

# stage 51 (opt-in, beyond default stop_stage=50): onnxsim simplification
if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then
    # wenetspeech ds2 model execed 2GB limit, will error.
    # simplifying onnx model
    ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx
fi
|
@ -1 +0,0 @@
|
||||
../../../../utils/
|
@ -1,2 +0,0 @@
|
||||
data
|
||||
exp
|
@ -1,78 +0,0 @@
|
||||
# Streaming DeepSpeech2 Server with WebSocket
|
||||
|
||||
This example is about using `websocket` as streaming deepspeech2 server. For deepspeech2 model training please see [here](../../../../examples/aishell/asr0/).
|
||||
|
||||
The websocket protocol is the same as that of the [PaddleSpeech Server](../../../../demos/streaming_asr_server/);
|
||||
for detail of implementation please see [here](../../../speechx/protocol/websocket/).
|
||||
|
||||
|
||||
## Source path.sh
|
||||
|
||||
```bash
|
||||
. path.sh
|
||||
```
|
||||
|
||||
SpeechX bins is under `echo $SPEECHX_BUILD`, more info please see `path.sh`.
|
||||
|
||||
|
||||
## Start WebSocket Server
|
||||
|
||||
```bash
|
||||
bash websocket_server.sh
|
||||
```
|
||||
|
||||
The output is like below:
|
||||
|
||||
```text
|
||||
I1130 02:19:32.029882 12856 cmvn_json2kaldi_main.cc:39] cmvn josn path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/data/mean_std.json
|
||||
I1130 02:19:32.032230 12856 cmvn_json2kaldi_main.cc:73] nframe: 907497
|
||||
I1130 02:19:32.032564 12856 cmvn_json2kaldi_main.cc:85] cmvn stats have write into: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
|
||||
I1130 02:19:32.032579 12856 cmvn_json2kaldi_main.cc:86] Binary: 1
|
||||
I1130 02:19:32.798342 12937 feature_pipeline.h:53] cmvn file: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
|
||||
I1130 02:19:32.798542 12937 feature_pipeline.h:58] dither: 0
|
||||
I1130 02:19:32.798583 12937 feature_pipeline.h:60] frame shift ms: 10
|
||||
I1130 02:19:32.798588 12937 feature_pipeline.h:62] feature type: linear
|
||||
I1130 02:19:32.798596 12937 feature_pipeline.h:80] frame length ms: 20
|
||||
I1130 02:19:32.798601 12937 feature_pipeline.h:88] subsampling rate: 4
|
||||
I1130 02:19:32.798606 12937 feature_pipeline.h:90] nnet receptive filed length: 7
|
||||
I1130 02:19:32.798611 12937 feature_pipeline.h:92] nnet chunk size: 1
|
||||
I1130 02:19:32.798615 12937 feature_pipeline.h:94] frontend fill zeros: 0
|
||||
I1130 02:19:32.798630 12937 nnet_itf.h:52] subsampling rate: 4
|
||||
I1130 02:19:32.798635 12937 nnet_itf.h:54] model path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdmodel
|
||||
I1130 02:19:32.798640 12937 nnet_itf.h:57] param path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdiparams
|
||||
I1130 02:19:32.798643 12937 nnet_itf.h:59] DS2 param:
|
||||
I1130 02:19:32.798647 12937 nnet_itf.h:61] cache names: chunk_state_h_box,chunk_state_c_box
|
||||
I1130 02:19:32.798652 12937 nnet_itf.h:63] cache shape: 5-1-1024,5-1-1024
|
||||
I1130 02:19:32.798656 12937 nnet_itf.h:65] input names: audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box
|
||||
I1130 02:19:32.798660 12937 nnet_itf.h:67] output names: softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
|
||||
I1130 02:19:32.798664 12937 ctc_tlg_decoder.h:41] fst path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//TLG.fst
|
||||
I1130 02:19:32.798669 12937 ctc_tlg_decoder.h:42] fst symbole table: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//words.txt
|
||||
I1130 02:19:32.798673 12937 ctc_tlg_decoder.h:47] LatticeFasterDecoder max active: 7500
|
||||
I1130 02:19:32.798677 12937 ctc_tlg_decoder.h:49] LatticeFasterDecoder beam: 15
|
||||
I1130 02:19:32.798681 12937 ctc_tlg_decoder.h:50] LatticeFasterDecoder lattice_beam: 7.5
|
||||
I1130 02:19:32.798708 12937 websocket_server_main.cc:37] Listening at port 8082
|
||||
```
|
||||
|
||||
## Start WebSocket Client
|
||||
|
||||
```bash
|
||||
bash websocket_client.sh
|
||||
```
|
||||
|
||||
This script uses AISHELL-1 test data to call the websocket server.
|
||||
|
||||
The input is specified by `--wav_rspecifier=scp:$data/$aishell_wav_scp`.
|
||||
|
||||
The `scp` file looks like this:
|
||||
```text
|
||||
# head data/split1/1/aishell_test.scp
|
||||
BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
|
||||
BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
|
||||
...
|
||||
BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
|
||||
```
|
||||
|
||||
If you want to recognize one wav, you can make `scp` file like this:
|
||||
```text
|
||||
key path/to/wav/file
|
||||
```
|
@ -1,14 +0,0 @@
|
||||
# This contains the locations of binarys build required for running the examples.

SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }

# BUG FIX: was 'export LC_AL=C' -- LC_AL is not a recognized locale
# variable; LC_ALL is the one that forces the C locale for all categories.
export LC_ALL=C

SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket:$SPEECHX_BUILD/frontend/audio
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
@ -1,35 +0,0 @@
|
||||
#!/bin/bash

# Run the streaming ASR websocket client against AISHELL-1 test wavs.
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data

# output
# Download the AISHELL test set once and build a "<utt_id> <wav_path>" scp.
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
    pushd $data
    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
    unzip aishell_test.zip
    popd

    realpath $data/test/*/*.wav > $data/wavlist
    # utt id = wav file basename without extension
    awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
    paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi

export GLOG_logtostderr=1

# websocket client
# streams each wav in 0.5s chunks to the server started by the server script
websocket_client_main \
    --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.5
|
@ -1,55 +0,0 @@
|
||||
#!/bin/bash

# Start the streaming DS2 websocket ASR server: fetch the aishell checkpoint
# and TLG decoding graph, convert cmvn stats, then launch the server binary.
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/


# 2. download the DS2 online aishell checkpoint if not present
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
    mkdir -p $ckpt_dir
    pushd $ckpt_dir
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

export GLOG_logtostderr=1

# 3. gen cmvn
# convert the json cmvn stats shipped with the checkpoint to kaldi format
cmvn=$data/cmvn.ark
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn


# 4. download the prebuilt aishell TLG decoding graph (words.txt + TLG.fst)
wfst=$data/wfst/
mkdir -p $wfst
if [ ! -f $wfst/aishell_graph.zip ]; then
    pushd $wfst
    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
    unzip aishell_graph.zip
    mv aishell_graph/* $wfst
    popd
fi

# 5. test websocket server
websocket_server_main \
    --cmvn_file=$cmvn \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdiparams \
    --word_symbol_table=$wfst/words.txt \
    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
    --graph_path=$wfst/TLG.fst --max_active=7500 \
    --acoustic_scale=1.2
|
@ -1,313 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
|
||||
#include "base/common.h"
|
||||
#include "decoder/ctc_decoders/decoder_utils.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using std::vector;
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
// Build a CTC prefix beam search decoder: load the vocabulary from
// opts.dict_file, optionally attach an external (n-gram LM) scorer, and
// locate the space symbol used for word boundaries.
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
    : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) {
    LOG(INFO) << "dict path: " << opts_.dict_file;
    if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
        // NOTE(review): load failure is only logged; decoding would proceed
        // with an empty vocabulary.
        LOG(INFO) << "load the dict failed";
    }
    LOG(INFO) << "read the vocabulary success, dict size: "
              << vocabulary_.size();

    LOG(INFO) << "language model path: " << opts_.lm_path;
    // An empty lm_path means decode without an external scorer.
    if (opts_.lm_path != "") {
        init_ext_scorer_ = std::make_shared<Scorer>(
            opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
    }

    // This decoder assumes the blank symbol has id 0.
    CHECK_EQ(opts_.blank, 0);

    auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
    space_id_ = it - vocabulary_.begin();
    // if no space in vocabulary
    if (static_cast<size_t>(space_id_) >= vocabulary_.size()) {
        space_id_ = -2;
    }
}
|
||||
|
||||
// Reset decoding state between utterances; all work is delegated to
// InitDecoder() (which also zeroes num_frame_decoded_ and the prefixes).
void CTCBeamSearch::Reset() {
    // num_frame_decoded_ = 0;
    // ResetPrefixes();
    InitDecoder();
}
|
||||
|
||||
// (Re)initialize decoding state: clear the prefix list, rebuild the trie
// root, and -- for a word-based external scorer -- attach a copy of the
// scorer's dictionary FST plus a matcher so word completions can be scored.
void CTCBeamSearch::InitDecoder() {
    num_frame_decoded_ = 0;
    // ResetPrefixes();
    prefixes_.clear();

    // Root prefix: empty sequence with log-prob 0 (probability 1).
    root_ = std::make_shared<PathTrie>();
    root_->score = root_->log_prob_b_prev = 0.0;
    prefixes_.push_back(root_.get());
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        auto fst_dict =
            static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
        // Copy(true) makes a fresh FST the trie can own independently of
        // the scorer -- presumably PathTrie takes ownership; confirm.
        fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
        root_->set_dictionary(dict_ptr);

        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
        root_->set_matcher(matcher);
    }
}
|
||||
|
||||
// Interface slot only: batch Decode is intentionally a no-op for the CTC
// beam search decoder -- use AdvanceDecode() for streaming decoding.
void CTCBeamSearch::Decode(
    std::shared_ptr<kaldi::DecodableInterface> decodable) {
    return;
}
|
||||
|
||||
// todo rename, refactor
// Pull frame posteriors from the decodable one frame at a time and advance
// the beam search until no more frames are available.
void CTCBeamSearch::AdvanceDecode(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
    while (1) {
        vector<vector<BaseFloat>> likelihood;
        vector<BaseFloat> frame_prob;
        // FrameLikelihood() returning false means frame num_frame_decoded_
        // is not (yet) available, which ends this advance pass.
        bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
        if (flag == false) break;
        likelihood.push_back(frame_prob);
        // AdvanceDecoding increments num_frame_decoded_ per consumed frame.
        AdvanceDecoding(likelihood);
    }
}
|
||||
|
||||
// Delete every prefix node in the list and clear it.
// NOTE(review): currently unused (the call sites in Reset()/InitDecoder()
// are commented out). prefixes_ also holds root_.get(), which is owned by a
// shared_ptr -- deleting it here would double-free; confirm ownership
// before re-enabling.
void CTCBeamSearch::ResetPrefixes() {
    for (size_t i = 0; i < prefixes_.size(); i++) {
        if (prefixes_[i] != nullptr) {
            delete prefixes_[i];
            prefixes_[i] = nullptr;
        }
    }
    prefixes_.clear();
}
|
||||
|
||||
// Advance decoding over a batch of frame posteriors and log elapsed time.
// nbest_words is unused here; always returns 0.
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
                                     const vector<string>& nbest_words) {
    kaldi::Timer timer;
    AdvanceDecoding(probs);
    LOG(INFO) << "ctc decoding elapsed time(s) "
              << static_cast<float>(timer.Elapsed()) / 1000.0f;
    return 0;
}
|
||||
|
||||
// Return up to n (score, transcript) hypotheses from the current prefixes.
// n == -1 means "use the configured beam size"; otherwise the result is
// capped at min(n, beam_size).
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath(int n) {
    int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size);
    return get_beam_search_result(prefixes_, vocabulary_, beam_size);
}
|
||||
|
||||
// Convenience overload: full beam-size n-best list.
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
    return GetNBestPath(-1);
}
|
||||
|
||||
// Best-scoring transcript so far (no final LM rescoring).
// NOTE(review): result[0] is unchecked -- assumes get_beam_search_result
// never returns an empty list; confirm.
string CTCBeamSearch::GetBestPath() {
    std::vector<std::pair<double, std::string>> result;
    result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
    return result[0].second;
}
|
||||
|
||||
// Finalize decoding: fold in approximate scores and the external LM
// rescoring, then return the best transcript. Call at end of utterance.
string CTCBeamSearch::GetFinalBestPath() {
    CalculateApproxScore();
    LMRescore();
    return GetBestPath();
}
|
||||
|
||||
// Core prefix beam search over a chunk of frame posteriors.
// Per time step: optionally derive a score cutoff from the external scorer,
// prune the frame's log-probs, extend every surviving prefix with each
// candidate symbol (SearchOneChar), then keep only the top beam_size
// prefixes collected back from the trie.
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
    size_t num_time_steps = probs.size();
    size_t beam_size = opts_.beam_size;
    double cutoff_prob = opts_.cutoff_prob;
    size_t cutoff_top_n = opts_.cutoff_top_n;

    // Promote posteriors to double before accumulating scores.
    vector<vector<double>> probs_seq(probs.size(),
                                     vector<double>(probs[0].size(), 0));

    int row = probs.size();
    int col = probs[0].size();
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < col; j++) {
            probs_seq[i][j] = static_cast<double>(probs[i][j]);
        }
    }

    for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
        const auto& prob = probs_seq[time_step];

        float min_cutoff = -NUM_FLT_INF;
        bool full_beam = false;
        if (init_ext_scorer_ != nullptr) {
            size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
            std::sort(prefixes_.begin(),
                      prefixes_.begin() + num_prefixes_,
                      prefix_compare);

            if (num_prefixes_ == 0) {
                continue;
            }
            // With a full beam, a candidate must beat the worst in-beam
            // prefix extended by blank (minus the LM word-insertion bonus).
            min_cutoff = prefixes_[num_prefixes_ - 1]->score +
                         std::log(prob[opts_.blank]) -
                         std::max(0.0, init_ext_scorer_->beta);

            full_beam = (num_prefixes_ == beam_size);
        }

        vector<std::pair<size_t, float>> log_prob_idx =
            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);

        // loop over chars
        size_t log_prob_idx_len = log_prob_idx.size();
        for (size_t index = 0; index < log_prob_idx_len; index++) {
            SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
        }

        prefixes_.clear();

        // update log probs
        // Re-collect the live prefixes from the trie after this step's
        // extensions.
        root_->iterate_to_vec(prefixes_);
        // only preserve top beam_size prefixes_
        if (prefixes_.size() >= beam_size) {
            std::nth_element(prefixes_.begin(),
                             prefixes_.begin() + beam_size,
                             prefixes_.end(),
                             prefix_compare);
            for (size_t i = beam_size; i < prefixes_.size(); ++i) {
                prefixes_[i]->remove();
            }
        }  // end if
        num_frame_decoded_++;
    }  // end for probs_seq
}
|
||||
|
||||
// Extends every live prefix with one candidate symbol for the current frame.
// full_beam:    true when the beam is full, enabling the min_cutoff prune.
// log_prob_idx: (symbol id, log probability) pair for the symbol.
// min_cutoff:   prune threshold computed by AdvanceDecoding.
// Updates the *_cur log-probability fields on existing and newly created
// trie nodes; always returns 0.
int32 CTCBeamSearch::SearchOneChar(
    const bool& full_beam,
    const std::pair<size_t, BaseFloat>& log_prob_idx,
    const BaseFloat& min_cutoff) {
    size_t beam_size = opts_.beam_size;
    const auto& c = log_prob_idx.first;
    const auto& log_prob_c = log_prob_idx.second;
    size_t prefixes_len = std::min(prefixes_.size(), beam_size);

    for (size_t i = 0; i < prefixes_len; ++i) {
        auto prefix = prefixes_[i];
        // prefixes_ is sorted when a scorer is active, so once one prefix
        // falls below the cutoff the rest will too.
        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
            break;
        }

        // Blank extends the prefix without emitting a symbol.
        if (c == opts_.blank) {
            prefix->log_prob_b_cur =
                log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
            continue;
        }

        // repeated character
        if (c == prefix->character) {
            // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
            prefix->log_prob_nb_cur = log_sum_exp(
                prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
        }

        // get new prefix (trie child for symbol c; may be created).
        auto prefix_new = prefix->get_path_trie(c);
        if (prefix_new != nullptr) {
            float log_p = -NUM_FLT_INF;
            if (c == prefix->character &&
                prefix->log_prob_b_prev > -NUM_FLT_INF) {
                // Repeat after a blank starts a new occurrence of c:
                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
                log_p = log_prob_c + prefix->log_prob_b_prev;
            } else if (c != prefix->character) {
                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
                log_p = log_prob_c + prefix->score;
            }

            // language model scoring: at word boundaries (space) for
            // word-level LMs, or on every symbol for character-based LMs.
            if (init_ext_scorer_ != nullptr &&
                (c == space_id_ || init_ext_scorer_->is_character_based())) {
                PathTrie* prefix_to_score = nullptr;
                // skip scoring the space itself for word-level LMs.
                if (init_ext_scorer_->is_character_based()) {
                    prefix_to_score = prefix_new;
                } else {
                    prefix_to_score = prefix;
                }

                float score = 0.0;
                vector<string> ngram;
                ngram = init_ext_scorer_->make_ngram(prefix_to_score);
                // lm score: p_{lm}(W)^{\alpha} + \beta
                score = init_ext_scorer_->get_log_cond_prob(ngram) *
                        init_ext_scorer_->alpha;
                log_p += score;
                log_p += init_ext_scorer_->beta;
            }
            // p_{nb}(l;x_{1:t})
            prefix_new->log_prob_nb_cur =
                log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
        }
    }  // end of loop over prefix
    return 0;
}
|
||||
|
||||
// Computes an approximate pure-CTC score for each kept prefix by removing
// the language-model weight (alpha) and word-insertion bonus (beta) from the
// combined score. Sorts the top prefixes as a side effect.
void CTCBeamSearch::CalculateApproxScore() {
    size_t beam_size = opts_.beam_size;
    size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
    std::sort(
        prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare);

    // compute aproximate ctc score as the return score, without affecting the
    // return order of decoding result. To delete when decoder gets stable.
    for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
        double approx_ctc = prefixes_[i]->score;
        if (init_ext_scorer_ != nullptr) {
            vector<int> output;
            prefixes_[i]->get_path_vec(output);
            auto prefix_length = output.size();
            auto words = init_ext_scorer_->split_labels(output);
            // remove word insert bonus: beta was added once per emitted label.
            approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
            // remove language model weight:
            approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
                          init_ext_scorer_->alpha;
        }
        prefixes_[i]->approx_ctc = approx_ctc;
    }
}
|
||||
|
||||
// Applies the word-level language model to the last (unterminated) word of
// each top prefix. Word-level LMs only score at spaces during the search, so
// a prefix that does not end with a space still owes a score for its final
// word; character-based LMs were already applied per symbol and are skipped.
void CTCBeamSearch::LMRescore() {
    size_t beam_size = opts_.beam_size;
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
            auto prefix = prefixes_[i];
            if (!prefix->is_empty() && prefix->character != space_id_) {
                float score = 0.0;
                vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
                score = init_ext_scorer_->get_log_cond_prob(ngram) *
                        init_ext_scorer_->alpha;
                score += init_ext_scorer_->beta;
                prefix->score += score;
            }
        }
    }
}
|
||||
|
||||
} // namespace ppspeech
|
@ -1,73 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// used by deepspeech2
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "decoder/ctc_beam_search_opt.h"
|
||||
#include "decoder/ctc_decoders/path_trie.h"
|
||||
#include "decoder/ctc_decoders/scorer.h"
|
||||
#include "decoder/decoder_itf.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
// CTC prefix beam search decoder with optional external (n-gram) language
// model scoring, used by the deepspeech2 pipeline.
class CTCBeamSearch : public DecoderBase {
  public:
    explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
    ~CTCBeamSearch() {}

    // Loads the vocabulary/scorer and prepares the prefix trie.
    void InitDecoder();

    // Clears decoding state so a new utterance can be decoded.
    void Reset();

    // Consumes all frames currently available from the decodable.
    void AdvanceDecode(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable);

    void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);

    // Best transcript of the current (not yet LM-rescored) beam.
    std::string GetBestPath();
    std::vector<std::pair<double, std::string>> GetNBestPath();
    std::vector<std::pair<double, std::string>> GetNBestPath(int n);
    // Finalizes scoring (approx CTC score + LM rescore), then best path.
    std::string GetFinalBestPath();

    // Partial results are not supported by this decoder; aborts if called.
    std::string GetPartialResult() {
        CHECK(false) << "Not implement.";
        return {};
    }

    // Decodes probs and logs timing; nbest_words is currently unused.
    int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                          const std::vector<std::string>& nbest_words);

  private:
    void ResetPrefixes();

    // Extends all live prefixes with one symbol candidate for a frame.
    int32 SearchOneChar(const bool& full_beam,
                        const std::pair<size_t, BaseFloat>& log_prob_idx,
                        const BaseFloat& min_cutoff);
    void CalculateApproxScore();
    void LMRescore();
    void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);

    CTCBeamSearchOptions opts_;
    std::shared_ptr<Scorer> init_ext_scorer_;  // todo separate later
    std::vector<std::string> vocabulary_;      // todo remove later
    int space_id_;                    // vocabulary index of the space symbol
    std::shared_ptr<PathTrie> root_;  // root of the prefix trie
    std::vector<PathTrie*> prefixes_;  // live beam, owned by the trie

    DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};
|
||||
|
||||
} // namespace ppspeech
|
@ -1,167 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// used by deepspeech2
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/ds2_nnet.h"
|
||||
|
||||
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
|
||||
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
|
||||
DEFINE_string(lm_path, "", "language model");
|
||||
DEFINE_int32(receptive_field_length,
|
||||
7,
|
||||
"receptive field of two CNN(kernel=3) downsampling module.");
|
||||
DEFINE_int32(subsampling_rate,
|
||||
4,
|
||||
"two CNN(kernel=3) module downsampling rate.");
|
||||
DEFINE_string(
|
||||
model_input_names,
|
||||
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
|
||||
"model input names");
|
||||
DEFINE_string(model_output_names,
|
||||
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
|
||||
"model output names");
|
||||
DEFINE_string(model_cache_names,
|
||||
"chunk_state_h_box,chunk_state_c_box",
|
||||
"model cache names");
|
||||
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
|
||||
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test ds2 online decoder by feeding speech feature
|
||||
// test ds2 online decoder by feeding speech feature
// Reads pre-computed features from a kaldi rspecifier, streams them to the
// deepspeech2 nnet in fixed-size chunks, and writes the recognized text per
// utterance to a kaldi wspecifier. Exits 0 iff at least one utterance
// decoded successfully.
int main(int argc, char* argv[]) {
    gflags::SetUsageMessage("Usage:");
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);
    google::InstallFailureSignalHandler();
    FLAGS_logtostderr = 1;

    CHECK_NE(FLAGS_result_wspecifier, "");
    CHECK_NE(FLAGS_feature_rspecifier, "");

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_rspecifier);
    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
    std::string model_path = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    std::string dict_file = FLAGS_dict_file;
    std::string lm_path = FLAGS_lm_path;
    LOG(INFO) << "model path: " << model_path;
    LOG(INFO) << "model param: " << model_params;
    LOG(INFO) << "dict path: " << dict_file;
    LOG(INFO) << "lm path: " << lm_path;

    int32 num_done = 0, num_err = 0;

    // Beam-search decoder with optional n-gram LM.
    ppspeech::CTCBeamSearchOptions opts;
    opts.dict_file = dict_file;
    opts.lm_path = lm_path;
    ppspeech::CTCBeamSearch decoder(opts);

    ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();

    // nnet consumes features via the DataCache frontend.
    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet, raw_data));

    // Chunk geometry derived from the model's receptive field and the
    // CNN downsampling rate.
    int32 chunk_size = FLAGS_receptive_field_length +
                       (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
    int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
    int32 receptive_field_length = FLAGS_receptive_field_length;
    LOG(INFO) << "chunk size (frame): " << chunk_size;
    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
    decoder.InitDecoder();

    kaldi::Timer timer;
    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
        kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
        raw_data->SetDim(feature.NumCols());
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << feature.NumRows();
        LOG(INFO) << "cols: " << feature.NumCols();

        int32 row_idx = 0;
        int32 padding_len = 0;
        int32 ori_feature_len = feature.NumRows();
        // Pad so the frame count lands exactly on a chunk boundary.
        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
            padding_len =
                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
            feature.Resize(feature.NumRows() + padding_len,
                           feature.NumCols(),
                           kaldi::kCopyData);
        }
        int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
            // Flattened [chunk_size x num_cols] buffer for one chunk.
            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
                                                          feature.NumCols());
            int32 feature_chunk_size = 0;
            if (ori_feature_len > chunk_idx * chunk_stride) {
                feature_chunk_size = std::min(
                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
            }
            // Remaining real (unpadded) frames are too few for the model's
            // receptive field; stop feeding this utterance.
            if (feature_chunk_size < receptive_field_length) break;

            int32 start = chunk_idx * chunk_stride;

            // Copy chunk_size rows starting at `start` into the flat buffer.
            for (int row_id = 0; row_id < chunk_size; ++row_id) {
                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
                    feature_chunk.Data() + row_id * feature.NumCols(),
                    feature.NumCols());
                f_chunk_tmp.CopyFromVec(tmp);
                ++start;
            }
            raw_data->Accept(feature_chunk);
            if (chunk_idx == num_chunks - 1) {
                raw_data->SetFinished();
            }
            decoder.AdvanceDecode(decodable);
        }
        std::string result;
        result = decoder.GetFinalBestPath();
        decodable->Reset();
        decoder.Reset();
        if (result.empty()) {
            // the TokenWriter can not write empty string.
            ++num_err;
            KALDI_LOG << " the result of " << utt << " is empty";
            continue;
        }
        KALDI_LOG << " the result of " << utt << " is " << result;
        result_writer.Write(utt, result);
        ++num_done;
    }

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    double elapsed = timer.Elapsed();
    KALDI_LOG << " cost:" << elapsed << " s";
    return (num_done != 0 ? 0 : 1);
}
|
@ -1,9 +0,0 @@
|
||||
ThreadPool/
|
||||
build/
|
||||
dist/
|
||||
kenlm/
|
||||
openfst-1.6.3/
|
||||
openfst-1.6.3.tar.gz
|
||||
swig_decoders.egg-info/
|
||||
decoders_wrap.cxx
|
||||
swig_decoders.py
|
@ -1,607 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ctc_beam_search_decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
|
||||
#include "ThreadPool.h"
|
||||
#include "fst/fstlib.h"
|
||||
|
||||
#include "decoder_utils.h"
|
||||
#include "path_trie.h"
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
|
||||
// Runs CTC prefix beam search over one utterance.
// probs_seq:  [num_frames][vocab_size] probabilities.
// beam_size:  number of prefixes kept per frame.
// cutoff_prob/cutoff_top_n: per-frame symbol pruning parameters.
// ext_scorer: optional external LM scorer (word- or character-based).
// blank_id:   index of the CTC blank symbol.
// Returns up to beam_size (approx-CTC score, transcript) pairs, best first.
std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer,
    size_t blank_id) {
    // dimension check
    size_t num_time_steps = probs_seq.size();
    for (size_t i = 0; i < num_time_steps; ++i) {
        VALID_CHECK_EQ(probs_seq[i].size(),
                       // vocabulary.size() + 1,
                       vocabulary.size(),
                       "The shape of probs_seq does not match with "
                       "the shape of the vocabulary");
    }


    // assign space id
    auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
    int space_id = it - vocabulary.begin();
    // if no space in vocabulary, use a sentinel that matches no symbol.
    if ((size_t)space_id >= vocabulary.size()) {
        space_id = -2;
    }
    // init prefixes' root: empty prefix with probability 1 (log prob 0).
    PathTrie root;
    root.score = root.log_prob_b_prev = 0.0;
    std::vector<PathTrie *> prefixes;
    prefixes.push_back(&root);

    // Word-level LM: attach a copy of the lexicon FST and a matcher so the
    // trie can constrain extensions to dictionary words.
    if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
        auto fst_dict =
            static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
        fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
        root.set_dictionary(dict_ptr);
        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
        root.set_matcher(matcher);
    }

    // prefix search over time
    for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
        auto &prob = probs_seq[time_step];

        float min_cutoff = -NUM_FLT_INF;
        bool full_beam = false;
        if (ext_scorer != nullptr) {
            // Sorted beam enables early pruning against min_cutoff below.
            size_t num_prefixes = std::min(prefixes.size(), beam_size);
            std::sort(prefixes.begin(),
                      prefixes.begin() + num_prefixes,
                      prefix_compare);
            min_cutoff = prefixes[num_prefixes - 1]->score +
                         std::log(prob[blank_id]) -
                         std::max(0.0, ext_scorer->beta);
            full_beam = (num_prefixes == beam_size);
        }

        std::vector<std::pair<size_t, float>> log_prob_idx =
            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
        // loop over chars
        for (size_t index = 0; index < log_prob_idx.size(); index++) {
            auto c = log_prob_idx[index].first;
            auto log_prob_c = log_prob_idx[index].second;

            for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
                auto prefix = prefixes[i];
                if (full_beam && log_prob_c + prefix->score < min_cutoff) {
                    break;
                }
                // blank: extends the prefix without emitting a symbol.
                if (c == blank_id) {
                    prefix->log_prob_b_cur = log_sum_exp(
                        prefix->log_prob_b_cur, log_prob_c + prefix->score);
                    continue;
                }
                // repeated character: same prefix, non-blank path.
                if (c == prefix->character) {
                    prefix->log_prob_nb_cur =
                        log_sum_exp(prefix->log_prob_nb_cur,
                                    log_prob_c + prefix->log_prob_nb_prev);
                }
                // get new prefix (trie child for symbol c).
                auto prefix_new = prefix->get_path_trie(c);

                if (prefix_new != nullptr) {
                    float log_p = -NUM_FLT_INF;

                    // Repeat after a blank starts a new occurrence of c.
                    if (c == prefix->character &&
                        prefix->log_prob_b_prev > -NUM_FLT_INF) {
                        log_p = log_prob_c + prefix->log_prob_b_prev;
                    } else if (c != prefix->character) {
                        log_p = log_prob_c + prefix->score;
                    }

                    // language model scoring: word-level LMs score at
                    // spaces, character-based LMs on every symbol.
                    if (ext_scorer != nullptr &&
                        (c == space_id || ext_scorer->is_character_based())) {
                        PathTrie *prefix_to_score = nullptr;
                        // skip scoring the space
                        if (ext_scorer->is_character_based()) {
                            prefix_to_score = prefix_new;
                        } else {
                            prefix_to_score = prefix;
                        }

                        float score = 0.0;
                        std::vector<std::string> ngram;
                        ngram = ext_scorer->make_ngram(prefix_to_score);
                        score = ext_scorer->get_log_cond_prob(ngram) *
                                ext_scorer->alpha;
                        log_p += score;
                        log_p += ext_scorer->beta;
                    }
                    prefix_new->log_prob_nb_cur =
                        log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
                }
            }  // end of loop over prefix
        }      // end of loop over vocabulary


        prefixes.clear();
        // update log probs: re-collect live prefixes from the trie.
        root.iterate_to_vec(prefixes);

        // only preserve top beam_size prefixes
        if (prefixes.size() >= beam_size) {
            std::nth_element(prefixes.begin(),
                             prefixes.begin() + beam_size,
                             prefixes.end(),
                             prefix_compare);
            for (size_t i = beam_size; i < prefixes.size(); ++i) {
                prefixes[i]->remove();
            }
        }
    }  // end of loop over time

    // score the last word of each prefix that doesn't end with space
    if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
            auto prefix = prefixes[i];
            if (!prefix->is_empty() && prefix->character != space_id) {
                float score = 0.0;
                std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
                score =
                    ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
                score += ext_scorer->beta;
                prefix->score += score;
            }
        }
    }

    size_t num_prefixes = std::min(prefixes.size(), beam_size);
    std::sort(
        prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);

    // compute approximate ctc score as the return score, without affecting the
    // return order of decoding result. To delete when decoder gets stable.
    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
        double approx_ctc = prefixes[i]->score;
        if (ext_scorer != nullptr) {
            std::vector<int> output;
            prefixes[i]->get_path_vec(output);
            auto prefix_length = output.size();
            auto words = ext_scorer->split_labels(output);
            // remove word insert
            approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
            // remove language model weight:
            approx_ctc -=
                (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
        }
        prefixes[i]->approx_ctc = approx_ctc;
    }

    return get_beam_search_result(prefixes, vocabulary, beam_size);
}
|
||||
|
||||
|
||||
std::vector<std::vector<std::pair<double, std::string>>>
|
||||
ctc_beam_search_decoding_batch(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer,
|
||||
size_t blank_id) {
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
// thread pool
|
||||
ThreadPool pool(num_processes);
|
||||
// number of samples
|
||||
size_t batch_size = probs_split.size();
|
||||
|
||||
// enqueue the tasks of decoding
|
||||
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res.emplace_back(pool.enqueue(ctc_beam_search_decoding,
|
||||
probs_split[i],
|
||||
vocabulary,
|
||||
beam_size,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
ext_scorer,
|
||||
blank_id));
|
||||
}
|
||||
|
||||
// get decoding results
|
||||
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
batch_results.emplace_back(res[i].get());
|
||||
}
|
||||
return batch_results;
|
||||
}
|
||||
|
||||
// Prepares a fresh prefix-trie root for chunked decoding. For a word-level
// external scorer, attaches a copy of the lexicon FST and an input matcher
// to the root; character-based scorers need no per-root setup.
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) {
    if (ext_scorer == nullptr || ext_scorer->is_character_based()) {
        return;
    }
    auto *dictionary =
        static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
    fst::StdVectorFst *dict_copy = dictionary->Copy(true);
    root->set_dictionary(dict_copy);
    root->set_matcher(
        std::make_shared<FSTMATCH>(*dict_copy, fst::MATCH_INPUT));
}
|
||||
|
||||
// Advances chunked CTC prefix beam search by one chunk of frames.
// Unlike ctc_beam_search_decoding, the root trie and live prefixes are
// owned by the caller and persist across chunks; no final LM rescoring or
// result extraction happens here (see get_decode_result).
void ctc_beam_search_decode_chunk(
    PathTrie *root,
    std::vector<PathTrie *> &prefixes,
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer,
    size_t blank_id) {
    // dimension check
    size_t num_time_steps = probs_seq.size();
    for (size_t i = 0; i < num_time_steps; ++i) {
        VALID_CHECK_EQ(probs_seq[i].size(),
                       // vocabulary.size() + 1,
                       vocabulary.size(),
                       "The shape of probs_seq does not match with "
                       "the shape of the vocabulary");
    }

    // assign space id
    auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
    int space_id = it - vocabulary.begin();
    // if no space in vocabulary, use a sentinel that matches no symbol.
    if ((size_t)space_id >= vocabulary.size()) {
        space_id = -2;
    }
    // init prefixes' root: done once by ctc_beam_search_decode_chunk_begin.
    //
    // prefix search over time
    for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
        auto &prob = probs_seq[time_step];

        float min_cutoff = -NUM_FLT_INF;
        bool full_beam = false;
        if (ext_scorer != nullptr) {
            // Sorted beam enables early pruning against min_cutoff below.
            size_t num_prefixes = std::min(prefixes.size(), beam_size);
            std::sort(prefixes.begin(),
                      prefixes.begin() + num_prefixes,
                      prefix_compare);
            min_cutoff = prefixes[num_prefixes - 1]->score +
                         std::log(prob[blank_id]) -
                         std::max(0.0, ext_scorer->beta);
            full_beam = (num_prefixes == beam_size);
        }

        std::vector<std::pair<size_t, float>> log_prob_idx =
            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
        // loop over chars
        for (size_t index = 0; index < log_prob_idx.size(); index++) {
            auto c = log_prob_idx[index].first;
            auto log_prob_c = log_prob_idx[index].second;

            for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
                auto prefix = prefixes[i];
                if (full_beam && log_prob_c + prefix->score < min_cutoff) {
                    break;
                }
                // blank: extends the prefix without emitting a symbol.
                if (c == blank_id) {
                    prefix->log_prob_b_cur = log_sum_exp(
                        prefix->log_prob_b_cur, log_prob_c + prefix->score);
                    continue;
                }
                // repeated character: same prefix, non-blank path.
                if (c == prefix->character) {
                    prefix->log_prob_nb_cur =
                        log_sum_exp(prefix->log_prob_nb_cur,
                                    log_prob_c + prefix->log_prob_nb_prev);
                }
                // get new prefix (trie child for symbol c).
                auto prefix_new = prefix->get_path_trie(c);

                if (prefix_new != nullptr) {
                    float log_p = -NUM_FLT_INF;

                    // Repeat after a blank starts a new occurrence of c.
                    if (c == prefix->character &&
                        prefix->log_prob_b_prev > -NUM_FLT_INF) {
                        log_p = log_prob_c + prefix->log_prob_b_prev;
                    } else if (c != prefix->character) {
                        log_p = log_prob_c + prefix->score;
                    }

                    // language model scoring: word-level LMs score at
                    // spaces, character-based LMs on every symbol.
                    if (ext_scorer != nullptr &&
                        (c == space_id || ext_scorer->is_character_based())) {
                        PathTrie *prefix_to_score = nullptr;
                        // skip scoring the space
                        if (ext_scorer->is_character_based()) {
                            prefix_to_score = prefix_new;
                        } else {
                            prefix_to_score = prefix;
                        }

                        float score = 0.0;
                        std::vector<std::string> ngram;
                        ngram = ext_scorer->make_ngram(prefix_to_score);
                        score = ext_scorer->get_log_cond_prob(ngram) *
                                ext_scorer->alpha;
                        log_p += score;
                        log_p += ext_scorer->beta;
                    }
                    prefix_new->log_prob_nb_cur =
                        log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
                }
            }  // end of loop over prefix
        }      // end of loop over vocabulary

        prefixes.clear();
        // update log probs: re-collect live prefixes from the trie.

        root->iterate_to_vec(prefixes);

        // only preserve top beam_size prefixes
        if (prefixes.size() >= beam_size) {
            std::nth_element(prefixes.begin(),
                             prefixes.begin() + beam_size,
                             prefixes.end(),
                             prefix_compare);
            for (size_t i = beam_size; i < prefixes.size(); ++i) {
                prefixes[i]->remove();
            }
        }
    }  // end of loop over time

    return;
}
|
||||
|
||||
|
||||
// Extracts the current n-best list during chunked decoding WITHOUT
// destroying the decoding state: the last-word LM score is temporarily added
// to unterminated prefixes, the result is materialized, and the score is
// then subtracted again ("paid back") so decoding can continue on the next
// chunk.
std::vector<std::pair<double, std::string>> get_decode_result(
    std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    Scorer *ext_scorer) {
    auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
    int space_id = it - vocabulary.begin();
    // if no space in vocabulary, use a sentinel that matches no symbol.
    if ((size_t)space_id >= vocabulary.size()) {
        space_id = -2;
    }
    // score the last word of each prefix that doesn't end with space
    if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
            auto prefix = prefixes[i];
            if (!prefix->is_empty() && prefix->character != space_id) {
                float score = 0.0;
                std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
                score =
                    ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
                score += ext_scorer->beta;
                prefix->score += score;
            }
        }
    }

    size_t num_prefixes = std::min(prefixes.size(), beam_size);
    std::sort(
        prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);

    // compute aproximate ctc score as the return score, without affecting the
    // return order of decoding result. To delete when decoder gets stable.
    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
        double approx_ctc = prefixes[i]->score;
        if (ext_scorer != nullptr) {
            std::vector<int> output;
            prefixes[i]->get_path_vec(output);
            auto prefix_length = output.size();
            auto words = ext_scorer->split_labels(output);
            // remove word insert
            approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
            // remove language model weight:
            approx_ctc -=
                (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
        }
        prefixes[i]->approx_ctc = approx_ctc;
    }

    std::vector<std::pair<double, std::string>> res =
        get_beam_search_result(prefixes, vocabulary, beam_size);

    // pay back the last word of each prefix that doesn't end with space (for
    // decoding by chunk)
    if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
            auto prefix = prefixes[i];
            if (!prefix->is_empty() && prefix->character != space_id) {
                float score = 0.0;
                std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
                score =
                    ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
                score += ext_scorer->beta;
                prefix->score -= score;
            }
        }
    }
    return res;
}
|
||||
|
||||
|
||||
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage) {
|
||||
storage = nullptr;
|
||||
}
|
||||
|
||||
|
||||
// Out-of-line empty destructor; members (unique_ptr storage vector, etc.)
// clean themselves up.
CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {}
|
||||
|
||||
// Builds a batched chunked decoder: one CtcBeamSearchDecoderStorage (prefix
// trie + live prefixes) per utterance, each initialized for the external
// scorer via ctc_beam_search_decode_chunk_begin.
// beam_size and num_processes must both be > 0.
CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch(
    const std::vector<std::string> &vocabulary,
    size_t batch_size,
    size_t beam_size,
    size_t num_processes,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer,
    size_t blank_id)
    : batch_size(batch_size),
      beam_size(beam_size),
      num_processes(num_processes),
      cutoff_prob(cutoff_prob),
      cutoff_top_n(cutoff_top_n),
      ext_scorer(ext_scorer),
      blank_id(blank_id) {
    VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
    // The check is strictly greater-than-zero; the old message incorrectly
    // said "nonnegative".
    VALID_CHECK_GT(
        this->num_processes, 0, "num_processes must be greater than 0!");
    this->vocabulary = vocabulary;
    for (size_t i = 0; i < batch_size; i++) {
        this->decoder_storage_vector.push_back(
            std::unique_ptr<CtcBeamSearchDecoderStorage>(
                new CtcBeamSearchDecoderStorage()));
        ctc_beam_search_decode_chunk_begin(
            this->decoder_storage_vector[i]->root, ext_scorer);
    }
}  // removed stray ';' after the constructor body
|
||||
|
||||
/**
|
||||
* Input
|
||||
* probs_split: shape [B, T, D]
|
||||
*/
|
||||
void CtcBeamSearchDecoderBatch::next(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const std::vector<std::string> &has_value) {
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
// thread pool
|
||||
size_t num_has_value = 0;
|
||||
for (int i = 0; i < has_value.size(); i++)
|
||||
if (has_value[i] == "true") num_has_value += 1;
|
||||
ThreadPool pool(std::min(num_processes, num_has_value));
|
||||
// number of samples
|
||||
size_t probs_num = probs_split.size();
|
||||
VALID_CHECK_EQ(this->batch_size,
|
||||
probs_num,
|
||||
"The batch size of the current input data should be same "
|
||||
"with the input data before");
|
||||
|
||||
// enqueue the tasks of decoding
|
||||
std::vector<std::future<void>> res;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
if (has_value[i] == "true") {
|
||||
res.emplace_back(pool.enqueue(
|
||||
ctc_beam_search_decode_chunk,
|
||||
std::ref(this->decoder_storage_vector[i]->root),
|
||||
std::ref(this->decoder_storage_vector[i]->prefixes),
|
||||
probs_split[i],
|
||||
this->vocabulary,
|
||||
this->beam_size,
|
||||
this->cutoff_prob,
|
||||
this->cutoff_top_n,
|
||||
this->ext_scorer,
|
||||
this->blank_id));
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res[i].get();
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
/**
|
||||
* Return
|
||||
* batch_result: shape[B, beam_size,(-approx_ctc score, string)]
|
||||
*/
|
||||
std::vector<std::vector<std::pair<double, std::string>>>
|
||||
CtcBeamSearchDecoderBatch::decode() {
|
||||
VALID_CHECK_GT(
|
||||
this->num_processes, 0, "num_processes must be nonnegative!");
|
||||
// thread pool
|
||||
ThreadPool pool(this->num_processes);
|
||||
// number of samples
|
||||
// enqueue the tasks of decoding
|
||||
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
||||
for (size_t i = 0; i < this->batch_size; ++i) {
|
||||
res.emplace_back(
|
||||
pool.enqueue(get_decode_result,
|
||||
std::ref(this->decoder_storage_vector[i]->prefixes),
|
||||
this->vocabulary,
|
||||
this->beam_size,
|
||||
this->ext_scorer));
|
||||
}
|
||||
// get decoding results
|
||||
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
|
||||
for (size_t i = 0; i < this->batch_size; ++i) {
|
||||
batch_results.emplace_back(res[i].get());
|
||||
}
|
||||
return batch_results;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* reset the state of ctcBeamSearchDecoderBatch
|
||||
*/
|
||||
void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n) {
|
||||
this->batch_size = batch_size;
|
||||
this->beam_size = beam_size;
|
||||
this->num_processes = num_processes;
|
||||
this->cutoff_prob = cutoff_prob;
|
||||
this->cutoff_top_n = cutoff_top_n;
|
||||
|
||||
VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
|
||||
VALID_CHECK_GT(
|
||||
this->num_processes, 0, "num_processes must be nonnegative!");
|
||||
// thread pool
|
||||
ThreadPool pool(this->num_processes);
|
||||
// number of samples
|
||||
// enqueue the tasks of decoding
|
||||
std::vector<std::future<void>> res;
|
||||
size_t storage_size = decoder_storage_vector.size();
|
||||
for (size_t i = 0; i < storage_size; i++) {
|
||||
res.emplace_back(pool.enqueue(
|
||||
free_storage, std::ref(this->decoder_storage_vector[i])));
|
||||
}
|
||||
for (size_t i = 0; i < storage_size; ++i) {
|
||||
res[i].get();
|
||||
}
|
||||
std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>().swap(
|
||||
decoder_storage_vector);
|
||||
for (size_t i = 0; i < this->batch_size; i++) {
|
||||
this->decoder_storage_vector.push_back(
|
||||
std::unique_ptr<CtcBeamSearchDecoderStorage>(
|
||||
new CtcBeamSearchDecoderStorage()));
|
||||
ctc_beam_search_decode_chunk_begin(
|
||||
this->decoder_storage_vector[i]->root, this->ext_scorer);
|
||||
}
|
||||
}
|
@ -1,175 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef CTC_BEAM_SEARCH_DECODER_H_
|
||||
#define CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "scorer.h"
|
||||
|
||||
/* CTC Beam Search Decoder
|
||||
|
||||
* Parameters:
|
||||
* probs_seq: 2-D vector that each element is a vector of probabilities
|
||||
* over vocabulary of one time step.
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* beam_size: The width of beam search.
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
* cutoff_top_n: Cutoff number for pruning.
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* Return:
|
||||
* A vector that each element is a pair of score and decoding result,
|
||||
* in desending order.
|
||||
*/
|
||||
std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
double cutoff_prob = 1.0,
|
||||
size_t cutoff_top_n = 40,
|
||||
Scorer *ext_scorer = nullptr,
|
||||
size_t blank_id = 0);
|
||||
|
||||
|
||||
/* CTC Beam Search Decoder for batch data
|
||||
|
||||
* Parameters:
|
||||
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
|
||||
* by ctc_beam_search_decoder().
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* beam_size: The width of beam search.
|
||||
* num_processes: Number of threads for beam search.
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
* cutoff_top_n: Cutoff number for pruning.
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* Return:
|
||||
* A 2-D vector that each element is a vector of beam search decoding
|
||||
* result for one audio sample.
|
||||
*/
|
||||
std::vector<std::vector<std::pair<double, std::string>>>
|
||||
ctc_beam_search_decoding_batch(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
double cutoff_prob = 1.0,
|
||||
size_t cutoff_top_n = 40,
|
||||
Scorer *ext_scorer = nullptr,
|
||||
size_t blank_id = 0);
|
||||
|
||||
/**
|
||||
* Store the root and prefixes for decoder
|
||||
*/
|
||||
|
||||
class CtcBeamSearchDecoderStorage {
|
||||
public:
|
||||
PathTrie *root = nullptr;
|
||||
std::vector<PathTrie *> prefixes;
|
||||
|
||||
CtcBeamSearchDecoderStorage() {
|
||||
// init prefixes' root
|
||||
this->root = new PathTrie();
|
||||
this->root->log_prob_b_prev = 0.0;
|
||||
// The score of root is in log scale.Since the prob=1.0, the prob score
|
||||
// in log scale is 0.0
|
||||
this->root->score = root->log_prob_b_prev;
|
||||
// std::vector<PathTrie *> prefixes;
|
||||
this->prefixes.push_back(root);
|
||||
};
|
||||
|
||||
~CtcBeamSearchDecoderStorage() {
|
||||
if (root != nullptr) {
|
||||
delete root;
|
||||
root = nullptr;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/**
 * The ctc beam search decoder, support batchsize >= 1
 */
class CtcBeamSearchDecoderBatch {
  public:
    CtcBeamSearchDecoderBatch(const std::vector<std::string> &vocabulary,
                              size_t batch_size,
                              size_t beam_size,
                              size_t num_processes,
                              double cutoff_prob,
                              size_t cutoff_top_n,
                              Scorer *ext_scorer,
                              size_t blank_id);

    ~CtcBeamSearchDecoderBatch();

    // Feed one chunk of per-frame posteriors [B, T, D]; has_value[i] is
    // "true" for utterances that carry data in this chunk.
    void next(const std::vector<std::vector<std::vector<double>>> &probs_split,
              const std::vector<std::string> &has_value);

    // Extract the current-best beams for every utterance:
    // [B, beam_size, (-approx_ctc score, transcript)].
    std::vector<std::vector<std::pair<double, std::string>>> decode();

    // Drop all per-utterance state and re-initialize with the given sizes.
    void reset_state(size_t batch_size,
                     size_t beam_size,
                     size_t num_processes,
                     double cutoff_prob,
                     size_t cutoff_top_n);

  private:
    std::vector<std::string> vocabulary;
    size_t batch_size;     // number of parallel utterances
    size_t beam_size;      // beam width per utterance
    size_t num_processes;  // worker threads used by next()/decode()
    double cutoff_prob;    // cumulative-probability pruning threshold
    size_t cutoff_top_n;   // top-n vocabulary pruning threshold
    Scorer *ext_scorer;    // optional external LM scorer (not owned)
    size_t blank_id;       // id of the CTC blank label
    // One prefix-trie storage per utterance in the batch.
    std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>
        decoder_storage_vector;
};
|
||||
|
||||
/**
|
||||
* function for chunk decoding
|
||||
*/
|
||||
void ctc_beam_search_decode_chunk(
|
||||
PathTrie *root,
|
||||
std::vector<PathTrie *> &prefixes,
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer,
|
||||
size_t blank_id);
|
||||
|
||||
std::vector<std::pair<double, std::string>> get_decode_result(
|
||||
std::vector<PathTrie *> &prefixes,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
Scorer *ext_scorer);
|
||||
|
||||
/**
|
||||
* free the CtcBeamSearchDecoderStorage
|
||||
*/
|
||||
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage);
|
||||
|
||||
/**
|
||||
* initialize the root
|
||||
*/
|
||||
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer);
|
||||
|
||||
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
@ -1,61 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ctc_greedy_decoder.h"
|
||||
#include "decoder_utils.h"
|
||||
|
||||
std::string ctc_greedy_decoding(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t blank_id) {
|
||||
// dimension check
|
||||
size_t num_time_steps = probs_seq.size();
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
VALID_CHECK_EQ(probs_seq[i].size(),
|
||||
vocabulary.size(),
|
||||
"The shape of probs_seq does not match with "
|
||||
"the shape of the vocabulary");
|
||||
}
|
||||
|
||||
// size_t blank_id = vocabulary.size();
|
||||
|
||||
std::vector<size_t> max_idx_vec(num_time_steps, 0);
|
||||
std::vector<size_t> idx_vec;
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
double max_prob = 0.0;
|
||||
size_t max_idx = 0;
|
||||
const std::vector<double> &probs_step = probs_seq[i];
|
||||
for (size_t j = 0; j < probs_step.size(); ++j) {
|
||||
if (max_prob < probs_step[j]) {
|
||||
max_idx = j;
|
||||
max_prob = probs_step[j];
|
||||
}
|
||||
}
|
||||
// id with maximum probability in current time step
|
||||
max_idx_vec[i] = max_idx;
|
||||
// deduplicate
|
||||
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
|
||||
idx_vec.push_back(max_idx_vec[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::string best_path_result;
|
||||
for (size_t i = 0; i < idx_vec.size(); ++i) {
|
||||
if (idx_vec[i] != blank_id) {
|
||||
std::string ch = vocabulary[idx_vec[i]];
|
||||
best_path_result += (ch == kSPACE) ? tSPACE : ch;
|
||||
}
|
||||
}
|
||||
return best_path_result;
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef CTC_GREEDY_DECODER_H
|
||||
#define CTC_GREEDY_DECODER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
/* CTC Greedy (Best Path) Decoder
|
||||
*
|
||||
* Parameters:
|
||||
* probs_seq: 2-D vector that each element is a vector of probabilities
|
||||
* over vocabulary of one time step.
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* Return:
|
||||
* The decoding result in string
|
||||
*/
|
||||
std::string ctc_greedy_decoding(
|
||||
const std::vector<std::vector<double>>& probs_seq,
|
||||
const std::vector<std::string>& vocabulary,
|
||||
size_t blank_id);
|
||||
|
||||
#endif // CTC_GREEDY_DECODER_H
|
@ -1,193 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
|
||||
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
const std::vector<double> &prob_step,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n) {
|
||||
std::vector<std::pair<int, double>> prob_idx;
|
||||
for (size_t i = 0; i < prob_step.size(); ++i) {
|
||||
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
|
||||
}
|
||||
// pruning of vocabulary
|
||||
size_t cutoff_len = prob_step.size();
|
||||
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
|
||||
std::sort(prob_idx.begin(),
|
||||
prob_idx.end(),
|
||||
pair_comp_second_rev<int, double>);
|
||||
if (cutoff_prob < 1.0) {
|
||||
double cum_prob = 0.0;
|
||||
cutoff_len = 0;
|
||||
for (size_t i = 0; i < prob_idx.size(); ++i) {
|
||||
cum_prob += prob_idx[i].second;
|
||||
cutoff_len += 1;
|
||||
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
|
||||
break;
|
||||
}
|
||||
}
|
||||
prob_idx = std::vector<std::pair<int, double>>(
|
||||
prob_idx.begin(), prob_idx.begin() + cutoff_len);
|
||||
}
|
||||
std::vector<std::pair<size_t, float>> log_prob_idx;
|
||||
for (size_t i = 0; i < cutoff_len; ++i) {
|
||||
log_prob_idx.push_back(std::pair<int, float>(
|
||||
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
|
||||
}
|
||||
return log_prob_idx;
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::pair<double, std::string>> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size) {
|
||||
// allow for the post processing
|
||||
std::vector<PathTrie *> space_prefixes;
|
||||
if (space_prefixes.empty()) {
|
||||
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||
space_prefixes.push_back(prefixes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
|
||||
std::vector<std::pair<double, std::string>> output_vecs;
|
||||
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
|
||||
std::vector<int> output;
|
||||
space_prefixes[i]->get_path_vec(output);
|
||||
// convert index to string
|
||||
std::string output_str;
|
||||
for (size_t j = 0; j < output.size(); j++) {
|
||||
std::string ch = vocabulary[output[j]];
|
||||
output_str += (ch == kSPACE) ? tSPACE : ch;
|
||||
}
|
||||
std::pair<double, std::string> output_pair(
|
||||
-space_prefixes[i]->approx_ctc, output_str);
|
||||
output_vecs.emplace_back(output_pair);
|
||||
}
|
||||
|
||||
return output_vecs;
|
||||
}
|
||||
|
||||
// Number of UTF-8 code points in |str|: count the bytes that are NOT
// continuation bytes (10xxxxxx).  See http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string &str) {
    return std::count_if(str.begin(), str.end(), [](char byte) {
        return (byte & 0xc0) != 0x80;
    });
}
|
||||
|
||||
// Split |str| into its UTF-8 characters (each possibly multi-byte).
// A byte of the form 10xxxxxx continues the current character; any other
// byte starts a new one.
std::vector<std::string> split_utf8_str(const std::string &str) {
    std::vector<std::string> chars;
    std::string current;
    for (char byte : str) {
        const bool starts_new_char = (byte & 0xc0) != 0x80;
        if (starts_new_char && !current.empty()) {
            chars.push_back(current);
            current.clear();
        }
        current.append(1, byte);
    }
    // The final accumulated character is always appended -- including the
    // empty one produced by an empty input, matching the original behavior.
    chars.push_back(current);
    return chars;
}
|
||||
|
||||
// Split |s| on |delim|, discarding empty segments (so delimiters at the
// beginning/end of the string are trimmed, e.g. "FooBarFoo" on "Foo"
// yields {"Bar"}).
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
    std::vector<std::string> pieces;
    const std::size_t delim_len = delim.size();
    std::size_t pos = 0;
    for (;;) {
        const std::size_t next = s.find(delim, pos);
        if (next == std::string::npos) {
            // Tail after the last delimiter (skipped when empty).
            if (pos < s.size()) {
                pieces.push_back(s.substr(pos));
            }
            break;
        }
        if (next > pos) {
            pieces.push_back(s.substr(pos, next - pos));
        }
        pos = next + delim_len;
    }
    return pieces;
}
|
||||
|
||||
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
|
||||
if (x->score == y->score) {
|
||||
if (x->character == y->character) {
|
||||
return false;
|
||||
} else {
|
||||
return (x->character < y->character);
|
||||
}
|
||||
} else {
|
||||
return x->score > y->score;
|
||||
}
|
||||
}
|
||||
|
||||
void add_word_to_fst(const std::vector<int> &word,
|
||||
fst::StdVectorFst *dictionary) {
|
||||
if (dictionary->NumStates() == 0) {
|
||||
fst::StdVectorFst::StateId start = dictionary->AddState();
|
||||
assert(start == 0);
|
||||
dictionary->SetStart(start);
|
||||
}
|
||||
fst::StdVectorFst::StateId src = dictionary->Start();
|
||||
fst::StdVectorFst::StateId dst;
|
||||
for (auto c : word) {
|
||||
dst = dictionary->AddState();
|
||||
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
|
||||
src = dst;
|
||||
}
|
||||
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
|
||||
}
|
||||
|
||||
bool add_word_to_dictionary(
|
||||
const std::string &word,
|
||||
const std::unordered_map<std::string, int> &char_map,
|
||||
bool add_space,
|
||||
int SPACE_ID,
|
||||
fst::StdVectorFst *dictionary) {
|
||||
auto characters = split_utf8_str(word);
|
||||
|
||||
std::vector<int> int_word;
|
||||
|
||||
for (auto &c : characters) {
|
||||
if (c == " ") {
|
||||
int_word.push_back(SPACE_ID);
|
||||
} else {
|
||||
auto int_c = char_map.find(c);
|
||||
if (int_c != char_map.end()) {
|
||||
int_word.push_back(int_c->second);
|
||||
} else {
|
||||
return false; // return without adding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (add_space) {
|
||||
int_word.push_back(SPACE_ID);
|
||||
}
|
||||
|
||||
add_word_to_fst(int_word, dictionary);
|
||||
return true; // return with successful adding
|
||||
}
|
@ -1,111 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef DECODER_UTILS_H_
|
||||
#define DECODER_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "fst/log.h"
|
||||
#include "path_trie.h"
|
||||
|
||||
const std::string kSPACE = "<space>";
|
||||
const std::string tSPACE = " ";
|
||||
const float NUM_FLT_INF = std::numeric_limits<float>::max();
|
||||
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
|
||||
|
||||
// Backend of the VALID_CHECK_* macros: if the checked condition |x| is
// false, print the failing source location and abort via LOG(FATAL).
inline void check(
    bool x, const char *expr, const char *file, int line, const char *err) {
    if (!x) {
        std::cout << "[" << file << ":" << line << "] ";
        LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
    }
}
||||
|
||||
#define VALID_CHECK(x, info) \
|
||||
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
|
||||
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
|
||||
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
|
||||
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
|
||||
|
||||
|
||||
// Comparator for std::sort: orders pairs by their first element, descending.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
                         const std::pair<T1, T2> &b) {
    return a.first > b.first;
}
|
||||
|
||||
// Comparator for std::sort: orders pairs by their second element, descending.
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
                          const std::pair<T1, T2> &b) {
    return a.second > b.second;
}
|
||||
|
||||
// Return the sum of two probabilities in log scale, i.e. log(exp(a)+exp(b)),
// computed stably by factoring out the larger operand.
template <typename T>
T log_sum_exp(const T &a, const T &b) {
    // Sentinel for log(0): anything at or below it behaves as -infinity.
    static T num_min = -std::numeric_limits<T>::max();
    if (a <= num_min) return b;
    if (b <= num_min) return a;
    const T hi = std::max(a, b);
    const T lo = std::min(a, b);
    // exp(hi - hi) == 1 exactly, so this equals exp(a-hi) + exp(b-hi).
    return std::log(static_cast<T>(1) + std::exp(lo - hi)) + hi;
}
|
||||
|
||||
// Get pruned probability vector for each time step's beam search
|
||||
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
const std::vector<double> &prob_step,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n);
|
||||
|
||||
// Get beam search result from prefixes in trie tree
|
||||
std::vector<std::pair<double, std::string>> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size);
|
||||
|
||||
// Functor for prefix comparsion
|
||||
bool prefix_compare(const PathTrie *x, const PathTrie *y);
|
||||
|
||||
/* Get length of utf8 encoding string
|
||||
* See: http://stackoverflow.com/a/4063229
|
||||
*/
|
||||
size_t get_utf8_str_len(const std::string &str);
|
||||
|
||||
/* Split a string into a list of strings on a given string
|
||||
* delimiter. NB: delimiters on beginning / end of string are
|
||||
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
|
||||
*/
|
||||
std::vector<std::string> split_str(const std::string &s,
|
||||
const std::string &delim);
|
||||
|
||||
/* Splits string into vector of strings representing
|
||||
* UTF-8 characters (not same as chars)
|
||||
*/
|
||||
std::vector<std::string> split_utf8_str(const std::string &str);
|
||||
|
||||
// Add a word in index to the dicionary of fst
|
||||
void add_word_to_fst(const std::vector<int> &word,
|
||||
fst::StdVectorFst *dictionary);
|
||||
|
||||
// Add a word in string to dictionary
|
||||
bool add_word_to_dictionary(
|
||||
const std::string &word,
|
||||
const std::unordered_map<std::string, int> &char_map,
|
||||
bool add_space,
|
||||
int SPACE_ID,
|
||||
fst::StdVectorFst *dictionary);
|
||||
#endif // DECODER_UTILS_H
|
@ -1,164 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "path_trie.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
// Construct a root node: no parent, sentinel character ROOT_ (-1), and all
// log-probabilities initialized to the -infinity surrogate (log of 0).
PathTrie::PathTrie() {
    // "b"/"nb" = prefix ending in blank / non-blank (per naming convention
    // of CTC prefix beam search); "prev"/"cur" = previous / current frame.
    log_prob_b_prev = -NUM_FLT_INF;
    log_prob_nb_prev = -NUM_FLT_INF;
    log_prob_b_cur = -NUM_FLT_INF;
    log_prob_nb_cur = -NUM_FLT_INF;
    score = -NUM_FLT_INF;

    ROOT_ = -1;
    character = ROOT_;  // the sentinel marks this node as the root
    exists_ = true;
    parent = nullptr;

    // No spelling-correction FST attached until set_dictionary() is called.
    dictionary_ = nullptr;
    dictionary_state_ = 0;
    has_dictionary_ = false;

    matcher_ = nullptr;
}
|
||||
|
||||
// Recursively release every child subtree owned by this node.
PathTrie::~PathTrie() {
    for (auto &entry : children_) {
        delete entry.second;
        entry.second = nullptr;
    }
}
|
||||
|
||||
// Return the child prefix obtained by appending |new_char| to this prefix,
// creating it if necessary.  When a dictionary FST is attached, returns
// nullptr if the extension would leave the dictionary (and, with |reset|,
// rewinds the FST state at word boundaries).
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
    // Linear scan of this node's children for an existing edge |new_char|.
    auto child = children_.begin();
    for (child = children_.begin(); child != children_.end(); ++child) {
        if (child->first == new_char) {
            break;
        }
    }
    if (child != children_.end()) {
        // Revive a previously removed child before reusing it.
        if (!child->second->exists_) {
            child->second->exists_ = true;
            child->second->log_prob_b_prev = -NUM_FLT_INF;
            child->second->log_prob_nb_prev = -NUM_FLT_INF;
            child->second->log_prob_b_cur = -NUM_FLT_INF;
            child->second->log_prob_nb_cur = -NUM_FLT_INF;
        }
        return (child->second);
    } else {
        if (has_dictionary_) {
            // Look up the transition from the current FST state.
            // NOTE(review): labels appear shifted by one (new_char + 1) --
            // presumably because FST label 0 is reserved; confirm against
            // the dictionary-building code.
            matcher_->SetState(dictionary_state_);
            bool found = matcher_->Find(new_char + 1);
            if (!found) {
                // Adding this character causes word outside dictionary
                auto FSTZERO = fst::TropicalWeight::Zero();
                auto final_weight = dictionary_->Final(dictionary_state_);
                bool is_final = (final_weight != FSTZERO);
                if (is_final && reset) {
                    // At a word boundary: restart matching from the FST
                    // start state for the next word.
                    dictionary_state_ = dictionary_->Start();
                }
                return nullptr;
            } else {
                // Create the child and advance its FST state along the arc.
                PathTrie* new_path = new PathTrie;
                new_path->character = new_char;
                new_path->parent = this;
                new_path->dictionary_ = dictionary_;
                new_path->dictionary_state_ = matcher_->Value().nextstate;
                new_path->has_dictionary_ = true;
                new_path->matcher_ = matcher_;
                children_.push_back(std::make_pair(new_char, new_path));
                return new_path;
            }
        } else {
            // No dictionary: always create the new child node.
            PathTrie* new_path = new PathTrie;
            new_path->character = new_char;
            new_path->parent = this;
            children_.push_back(std::make_pair(new_char, new_path));
            return new_path;
        }
    }
}
|
||||
|
||||
// Collect the full prefix (root to this node) into |output|; returns the
// node where the walk stopped (the root for this overload).
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
    return get_path_vec(output, ROOT_);
}
|
||||
|
||||
// Walk upward from this node, appending characters until the |stop|
// character, the root, or |max_steps| collected entries is reached; the
// result is reversed into root-first order.  Returns the node where the
// walk stopped.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                 int stop,
                                 size_t max_steps) {
    PathTrie* node = this;
    while (node->character != stop && node->character != node->ROOT_ &&
           output.size() != max_steps) {
        output.push_back(node->character);
        node = node->parent;
    }
    std::reverse(output.begin(), output.end());
    return node;
}
|
||||
|
||||
// End-of-frame bookkeeping for the whole subtree: move the per-frame
// accumulators into the "previous" slots, recompute each live node's score,
// and collect the live nodes into |output| (the prefix list for the next
// time step).
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
    if (exists_) {
        log_prob_b_prev = log_prob_b_cur;
        log_prob_nb_prev = log_prob_nb_cur;

        // Reset the accumulators for the next frame.
        log_prob_b_cur = -NUM_FLT_INF;
        log_prob_nb_cur = -NUM_FLT_INF;

        // Total prefix probability: blank-ending + non-blank-ending paths.
        score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
        output.push_back(this);
    }
    for (auto child : children_) {
        child.second->iterate_to_vec(output);
    }
}
|
||||
|
||||
// Mark this prefix as dead; if it has no children, physically delete the
// node, unlink it from its parent, and recursively trim dead, childless
// ancestors.  NOTE: ends in `delete this` -- the node must be heap
// allocated and must not be touched after this call.
void PathTrie::remove() {
    exists_ = false;
    if (children_.size() == 0) {
        if (parent != nullptr) {
            // Unlink this node from the parent's child list.
            auto child = parent->children_.begin();
            for (child = parent->children_.begin();
                 child != parent->children_.end();
                 ++child) {
                if (child->first == character) {
                    parent->children_.erase(child);
                    break;
                }
            }
            // A parent that is itself dead and now childless can go too.
            if (parent->children_.size() == 0 && !parent->exists_) {
                parent->remove();
            }
        }
        delete this;
    }
}
|
||||
|
||||
|
||||
// Attach the spelling-correction FST; matching starts at its start state.
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
    has_dictionary_ = true;
    dictionary_ = dictionary;
    dictionary_state_ = dictionary->Start();
}
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
|
||||
matcher_ = matcher;
|
||||
}
|
@ -1,82 +0,0 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef PATH_TRIE_H
|
||||
#define PATH_TRIE_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
|
||||
/* Trie tree for prefix storing and manipulating, with a dictionary in
 * finite-state transducer for spelling correction.
 */
class PathTrie {
  public:
    PathTrie();
    ~PathTrie();

    // get new prefix after appending new char
    PathTrie* get_path_trie(int new_char, bool reset = true);

    // get the prefix in index from root to current node
    PathTrie* get_path_vec(std::vector<int>& output);

    // get the prefix in index from some stop node to current nodel
    PathTrie* get_path_vec(
        std::vector<int>& output,
        int stop,
        size_t max_steps = std::numeric_limits<size_t>::max());

    // update log probs
    void iterate_to_vec(std::vector<PathTrie*>& output);

    // set dictionary for FST
    void set_dictionary(fst::StdVectorFst* dictionary);

    void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);

    // True only for the root node (sentinel character).
    bool is_empty() { return ROOT_ == character; }

    // remove current path from root
    void remove();

    // "b"/"nb" = prefix ending in blank / non-blank; "prev"/"cur" =
    // previous / current frame (see iterate_to_vec in path_trie.cpp).
    float log_prob_b_prev;
    float log_prob_nb_prev;
    float log_prob_b_cur;
    float log_prob_nb_cur;
    float score;       // log_sum_exp of the two previous-frame probs
    float approx_ctc;  // approximate ctc score assigned by the decoder
    int character;     // label on the incoming edge; ROOT_ for the root
    PathTrie* parent;  // nullptr for the root

  private:
    int ROOT_;             // sentinel character id marking the root
    bool exists_;          // false once the prefix has been remove()d
    bool has_dictionary_;  // true after set_dictionary()

    std::vector<std::pair<int, PathTrie*>> children_;  // owned subtrees

    // pointer to dictionary of FST
    fst::StdVectorFst* dictionary_;
    fst::StdVectorFst::StateId dictionary_state_;
    // true if finding ars in FST
    std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};
|
||||
|
||||
#endif // PATH_TRIE_H
|
@ -1,232 +0,0 @@
|
||||
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
|
||||
// "COPYING.LESSER.3");
|
||||
|
||||
#include "scorer.h"
|
||||
|
||||
#include <unistd.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "lm/config.hh"
|
||||
#include "lm/model.hh"
|
||||
#include "lm/state.hh"
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
using namespace lm::ngram;
|
||||
// if your platform is windows ,you need add the define
|
||||
#define F_OK 0
|
||||
// Construct an external LM scorer.
// @param alpha       language-model weight
// @param beta        word-insertion weight
// @param lm_path     path to a KenLM binary/ARPA model
// @param vocab_list  decoder vocabulary (one token per entry)
Scorer::Scorer(double alpha,
               double beta,
               const std::string& lm_path,
               const std::vector<std::string>& vocab_list)
    : alpha(alpha),
      beta(beta),
      dictionary(nullptr),
      language_model_(nullptr),
      is_character_based_(true),  // refined by load_lm() once vocab is known
      max_order_(0),
      dict_size_(0),
      SPACE_ID_(-1) {
    setup(lm_path, vocab_list);
}
|
||||
|
||||
// Release the LM and the dictionary FST. Both are stored behind void*, so
// they must be cast back to their concrete types for the destructors to run.
Scorer::~Scorer() {
    auto* lm = static_cast<lm::base::Model*>(language_model_);
    delete lm;  // deleting a null pointer is a no-op
    auto* dict = static_cast<fst::StdVectorFst*>(dictionary);
    delete dict;
}
|
||||
|
||||
void Scorer::setup(const std::string& lm_path,
|
||||
const std::vector<std::string>& vocab_list) {
|
||||
// load language model
|
||||
load_lm(lm_path);
|
||||
// set char map for scorer
|
||||
set_char_map(vocab_list);
|
||||
// fill the dictionary for FST
|
||||
if (!is_character_based()) {
|
||||
fill_dictionary(true);
|
||||
}
|
||||
}
|
||||
|
||||
void Scorer::load_lm(const std::string& lm_path) {
|
||||
const char* filename = lm_path.c_str();
|
||||
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
|
||||
|
||||
RetriveStrEnumerateVocab enumerate;
|
||||
lm::ngram::Config config;
|
||||
config.enumerate_vocab = &enumerate;
|
||||
language_model_ = lm::ngram::LoadVirtual(filename, config);
|
||||
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
|
||||
vocabulary_ = enumerate.vocabulary;
|
||||
for (size_t i = 0; i < vocabulary_.size(); ++i) {
|
||||
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
|
||||
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
|
||||
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
|
||||
is_character_based_ = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
|
||||
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
|
||||
double cond_prob;
|
||||
lm::ngram::State state, tmp_state, out_state;
|
||||
// avoid to inserting <s> in begin
|
||||
model->NullContextWrite(&state);
|
||||
for (size_t i = 0; i < words.size(); ++i) {
|
||||
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
|
||||
// encounter OOV
|
||||
if (word_index == 0) {
|
||||
return OOV_SCORE;
|
||||
}
|
||||
cond_prob = model->BaseScore(&state, word_index, &out_state);
|
||||
tmp_state = state;
|
||||
state = out_state;
|
||||
out_state = tmp_state;
|
||||
}
|
||||
// return log10 prob
|
||||
return cond_prob;
|
||||
}
|
||||
|
||||
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
|
||||
std::vector<std::string> sentence;
|
||||
if (words.size() == 0) {
|
||||
for (size_t i = 0; i < max_order_; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < max_order_ - 1; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
}
|
||||
sentence.insert(sentence.end(), words.begin(), words.end());
|
||||
}
|
||||
sentence.push_back(END_TOKEN);
|
||||
return get_log_prob(sentence);
|
||||
}
|
||||
|
||||
double Scorer::get_log_prob(const std::vector<std::string>& words) {
|
||||
assert(words.size() > max_order_);
|
||||
double score = 0.0;
|
||||
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
|
||||
std::vector<std::string> ngram(words.begin() + i,
|
||||
words.begin() + i + max_order_);
|
||||
score += get_log_cond_prob(ngram);
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
||||
// Update the decoding hyperparameters: LM weight (alpha) and word-insertion
// weight (beta). Safe to call between utterances; no model state is touched.
void Scorer::reset_params(float alpha, float beta) {
    this->alpha = alpha;
    this->beta = beta;
}
|
||||
|
||||
std::string Scorer::vec2str(const std::vector<int>& input) {
|
||||
std::string word;
|
||||
for (auto ind : input) {
|
||||
word += char_list_[ind];
|
||||
}
|
||||
return word;
|
||||
}
|
||||
|
||||
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
|
||||
if (labels.empty()) return {};
|
||||
|
||||
std::string s = vec2str(labels);
|
||||
std::vector<std::string> words;
|
||||
if (is_character_based_) {
|
||||
words = split_utf8_str(s);
|
||||
} else {
|
||||
words = split_str(s, " ");
|
||||
}
|
||||
return words;
|
||||
}
|
||||
|
||||
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
|
||||
char_list_ = char_list;
|
||||
char_map_.clear();
|
||||
|
||||
// Set the char map for the FST for spelling correction
|
||||
for (size_t i = 0; i < char_list_.size(); i++) {
|
||||
if (char_list_[i] == kSPACE) {
|
||||
SPACE_ID_ = i;
|
||||
}
|
||||
// The initial state of FST is state 0, hence the index of chars in
|
||||
// the FST should start from 1 to avoid the conflict with the initial
|
||||
// state, otherwise wrong decoding results would be given.
|
||||
char_map_[char_list_[i]] = i + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Build the n-gram context ending at `prefix` for LM scoring.
// Walks up the prefix trie collecting max_order_ units: single characters
// for a character-based LM, or space-delimited words for a word-based LM.
// Pads with <s> when the trie root is reached early. Result is returned in
// left-to-right (oldest-first) order.
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
    std::vector<std::string> ngram;
    PathTrie* current_node = prefix;
    PathTrie* new_node = nullptr;

    for (int order = 0; order < max_order_; order++) {
        std::vector<int> prefix_vec;

        if (is_character_based_) {
            // take exactly one character per step (max_steps == 1)
            new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
            current_node = new_node;
        } else {
            // take everything back to the previous space
            new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
            current_node = new_node->parent;  // Skipping spaces
        }

        // reconstruct word from the collected character indices
        std::string word = vec2str(prefix_vec);
        ngram.push_back(word);

        // character == -1 marks the trie root: the prefix is exhausted
        if (new_node->character == -1) {
            // No more spaces, but still need `order` context slots: pad <s>
            for (int i = 0; i < max_order_ - order - 1; i++) {
                ngram.push_back(START_TOKEN);
            }
            break;
        }
    }
    // collected newest-first while walking up; LM wants oldest-first
    std::reverse(ngram.begin(), ngram.end());
    return ngram;
}
|
||||
|
||||
// Build the word-dictionary FST used for spelling correction (word-based
// LMs only). Each vocabulary word is added as a character path; the FST is
// then epsilon-removed, determinized, and minimized. The result is stored
// in `this->dictionary` (owned; freed in ~Scorer).
void Scorer::fill_dictionary(bool add_space) {
    fst::StdVectorFst dictionary;
    // For each unigram convert to ints and put in trie
    int dict_size = 0;
    for (const auto& word : vocabulary_) {
        bool added = add_word_to_dictionary(
            word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
        dict_size += added ? 1 : 0;  // words with unmapped chars are skipped
    }

    dict_size_ = dict_size;

    /* Simplify FST

     * This gets rid of "epsilon" transitions in the FST.
     * These are transitions that don't require a string input to be taken.
     * Getting rid of them is necessary to make the FST deterministic, but
     * can greatly increase the size of the FST
     */
    fst::RmEpsilon(&dictionary);
    fst::StdVectorFst* new_dict = new fst::StdVectorFst;

    /* This makes the FST deterministic, meaning for any string input there's
     * only one possible state the FST could be in. It is assumed our
     * dictionary is deterministic when using it.
     * (lest we'd have to check for multiple transitions at each state)
     */
    fst::Determinize(dictionary, new_dict);

    /* Finds the simplest equivalent fst. This is unnecessary but decreases
     * memory usage of the dictionary
     */
    fst::Minimize(new_dict);
    this->dictionary = new_dict;
}
|
@ -1,114 +0,0 @@
|
||||
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
|
||||
// "COPYING.LESSER.3");
|
||||
|
||||
#ifndef SCORER_H_
|
||||
#define SCORER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "lm/enumerate_vocab.hh"
|
||||
#include "lm/virtual_interface.hh"
|
||||
#include "lm/word_index.hh"
|
||||
|
||||
#include "path_trie.h"
|
||||
|
||||
const double OOV_SCORE = -1000.0;
|
||||
const std::string START_TOKEN = "<s>";
|
||||
const std::string UNK_TOKEN = "<unk>";
|
||||
const std::string END_TOKEN = "</s>";
|
||||
|
||||
// Callback to retrieve the vocabulary of a KenLM language model while it is
// being loaded: KenLM invokes Add() once per vocabulary entry.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
  public:
    RetriveStrEnumerateVocab() {}

    // Called by KenLM for each (index, token) pair during model load.
    void Add(lm::WordIndex index, const StringPiece &str) {
        vocabulary.push_back(std::string(str.data(), str.length()));
    }

    // Tokens in model order; filled as a side effect of loading.
    std::vector<std::string> vocabulary;
};
|
||||
|
||||
/* External scorer to query score for n-gram or sentence, including language
 * model scoring and word insertion.
 *
 * Example:
 *     Scorer scorer(alpha, beta, "path_of_language_model");
 *     scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
 *     scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
 */
class Scorer {
  public:
    Scorer(double alpha,
           double beta,
           const std::string &lm_path,
           const std::vector<std::string> &vocabulary);
    ~Scorer();

    // log10 prob of the last word of `words` given the preceding context
    double get_log_cond_prob(const std::vector<std::string> &words);

    // full-sentence log prob, with <s>/</s> padding applied
    double get_sent_log_prob(const std::vector<std::string> &words);

    // return the max order of the language model
    size_t get_max_order() const { return max_order_; }

    // return the dictionary size of language model
    size_t get_dict_size() const { return dict_size_; }

    // return true if the language model is character based
    bool is_character_based() const { return is_character_based_; }

    // reset params alpha & beta
    void reset_params(float alpha, float beta);

    // make ngram context for a given decoding prefix
    std::vector<std::string> make_ngram(PathTrie *prefix);

    // transform the labels in index to the vector of words (word based lm)
    // or the vector of characters (character based lm)
    std::vector<std::string> split_labels(const std::vector<int> &labels);

    // language model weight
    double alpha;
    // word insertion weight
    double beta;

    // pointer to the dictionary FST (owned; stored as void* to avoid leaking
    // the OpenFst dependency into this header's users)
    void *dictionary;

  protected:
    // necessary setup: load language model, set char map, fill FST's dictionary
    void setup(const std::string &lm_path,
               const std::vector<std::string> &vocab_list);

    // load language model from given path
    void load_lm(const std::string &lm_path);

    // fill dictionary FST from the LM vocabulary
    void fill_dictionary(bool add_space);

    // set char map
    void set_char_map(const std::vector<std::string> &char_list);

    double get_log_prob(const std::vector<std::string> &words);

    // translate the vector of indices to a string
    std::string vec2str(const std::vector<int> &input);

  private:
    void *language_model_;      // lm::base::Model*, opaque here
    bool is_character_based_;
    size_t max_order_;
    size_t dict_size_;

    int SPACE_ID_;
    std::vector<std::string> char_list_;
    std::unordered_map<std::string, int> char_map_;

    std::vector<std::string> vocabulary_;
};
|
||||
|
||||
#endif // SCORER_H_
|
@ -1,77 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
|
||||
DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
|
||||
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
|
||||
DEFINE_string(lm_path, "lm.klm", "language model");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test decoder by feeding nnet posterior probability
|
||||
// Offline test driver: feed pre-computed nnet posterior probabilities
// (a Kaldi matrix archive) through the CTC beam-search decoder and log the
// best path per utterance. Exits 0 iff at least one utterance was decoded.
int main(int argc, char* argv[]) {
    gflags::SetUsageMessage("Usage:");
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);
    google::InstallFailureSignalHandler();
    FLAGS_logtostderr = 1;

    kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
        FLAGS_nnet_prob_respecifier);
    std::string dict_file = FLAGS_dict_file;
    std::string lm_path = FLAGS_lm_path;
    LOG(INFO) << "dict path: " << dict_file;
    LOG(INFO) << "lm path: " << lm_path;

    int32 num_done = 0, num_err = 0;

    ppspeech::CTCBeamSearchOptions opts;
    opts.dict_file = dict_file;
    opts.lm_path = lm_path;
    ppspeech::CTCBeamSearch decoder(opts);

    // nnet/frontend are null: probabilities are injected directly below.
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nullptr, nullptr));

    decoder.InitDecoder();

    for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
        string utt = likelihood_reader.Key();
        const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << likelihood.NumRows();
        LOG(INFO) << "cols: " << likelihood.NumCols();
        decodable->Acceptlikelihood(likelihood);
        decoder.AdvanceDecode(decodable);
        std::string result;
        result = decoder.GetFinalBestPath();
        KALDI_LOG << " the result of " << utt << " is " << result;
        // per-utterance state reset so utterances don't leak into each other
        decodable->Reset();
        decoder.Reset();
        ++num_done;
    }

    // NOTE(review): num_err is never incremented in this tool; it is logged
    // for symmetry with the other test mains.
    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
|
@ -1,218 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "nnet/ds2_nnet.h"
|
||||
|
||||
#include "utils/strings.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::Matrix;
|
||||
using kaldi::Vector;
|
||||
using std::shared_ptr;
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
|
||||
std::vector<std::string> cache_names;
|
||||
cache_names = StrSplit(opts.cache_names, ",");
|
||||
std::vector<std::string> cache_shapes;
|
||||
cache_shapes = StrSplit(opts.cache_shape, ",");
|
||||
assert(cache_shapes.size() == cache_names.size());
|
||||
|
||||
cache_encouts_.clear();
|
||||
cache_names_idx_.clear();
|
||||
for (size_t i = 0; i < cache_shapes.size(); i++) {
|
||||
std::vector<std::string> tmp_shape;
|
||||
tmp_shape = StrSplit(cache_shapes[i], "-");
|
||||
std::vector<int> cur_shape;
|
||||
std::transform(tmp_shape.begin(),
|
||||
tmp_shape.end(),
|
||||
std::back_inserter(cur_shape),
|
||||
[](const std::string& s) { return atoi(s.c_str()); });
|
||||
cache_names_idx_[cache_names[i]] = i;
|
||||
std::shared_ptr<Tensor<BaseFloat>> cache_eout =
|
||||
std::make_shared<Tensor<BaseFloat>>(cur_shape);
|
||||
cache_encouts_.push_back(cache_eout);
|
||||
}
|
||||
}
|
||||
|
||||
// Build the Paddle inference predictor pool from the model options and
// sanity-check that the configured input/output names match the model's.
// Asserts (debug builds) when names mismatch.
PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
    subsampling_rate_ = opts.subsample_rate;
    paddle_infer::Config config;
    config.SetModel(opts.model_path, opts.param_path);
    if (opts.use_gpu) {
        // 500 MB initial GPU workspace on device 0
        config.EnableUseGpu(500, 0);
    }
    config.SwitchIrOptim(opts.switch_ir_optim);
    if (opts.enable_fc_padding == false) {
        config.DisableFCPadding();
    }
    if (opts.enable_profile) {
        config.EnableProfile();
    }
    // one predictor per worker thread
    pool.reset(
        new paddle_infer::services::PredictorPool(config, opts.thread_num));
    if (pool == nullptr) {
        LOG(ERROR) << "create the predictor pool failed";
    }
    pool_usages.resize(opts.thread_num);
    std::fill(pool_usages.begin(), pool_usages.end(), false);
    LOG(INFO) << "load paddle model success";

    LOG(INFO) << "start to check the predictor input and output names";
    LOG(INFO) << "input names: " << opts.input_names;
    LOG(INFO) << "output names: " << opts.output_names;
    std::vector<std::string> input_names_vec = StrSplit(opts.input_names, ",");
    std::vector<std::string> output_names_vec = StrSplit(opts.output_names, ",");

    // borrow one predictor just to interrogate the model's tensor names
    paddle_infer::Predictor* predictor = GetPredictor();

    std::vector<std::string> model_input_names = predictor->GetInputNames();
    assert(input_names_vec.size() == model_input_names.size());
    for (size_t i = 0; i < model_input_names.size(); i++) {
        assert(input_names_vec[i] == model_input_names[i]);
    }

    std::vector<std::string> model_output_names = predictor->GetOutputNames();
    assert(output_names_vec.size() == model_output_names.size());
    for (size_t i = 0; i < output_names_vec.size(); i++) {
        assert(output_names_vec[i] == model_output_names[i]);
    }

    ReleasePredictor(predictor);
    InitCacheEncouts(opts);
}
|
||||
|
||||
// Reset decoder state between utterances by re-creating zeroed cache tensors.
void PaddleNnet::Reset() { InitCacheEncouts(opts_); }
|
||||
|
||||
// Check out a free predictor from the pool (thread-safe).
// Returns nullptr when every predictor is in use; the caller must hand the
// predictor back via ReleasePredictor().
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
    std::lock_guard<std::mutex> guard(pool_mutex);

    // linear scan for the first slot not currently checked out
    paddle_infer::Predictor* predictor = nullptr;
    size_t pred_id = 0;
    for (; pred_id < pool_usages.size(); ++pred_id) {
        if (!pool_usages[pred_id]) {
            predictor = pool->Retrive(pred_id);
            break;
        }
    }

    if (predictor != nullptr) {
        pool_usages[pred_id] = true;
        predictor_to_thread_id[predictor] = pred_id;
    } else {
        LOG(INFO) << "Failed to get predictor from pool !!!";
    }

    return predictor;
}
|
||||
|
||||
// Return a predictor obtained from GetPredictor() to the pool (thread-safe).
// Unknown predictors are ignored with a log line. Always returns 0.
int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
    std::lock_guard<std::mutex> guard(pool_mutex);

    const auto iter = predictor_to_thread_id.find(predictor);
    if (iter == predictor_to_thread_id.end()) {
        LOG(INFO) << "there is no such predictor";
        return 0;
    }

    pool_usages[iter->second] = false;
    predictor_to_thread_id.erase(iter);
    return 0;
}
|
||||
|
||||
// Look up the RNN cache tensor registered under `name` by InitCacheEncouts.
// Returns nullptr for unknown names.
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
    const auto iter = cache_names_idx_.find(name);
    if (iter == cache_names_idx_.end()) {
        return nullptr;
    }
    assert(iter->second < cache_encouts_.size());
    return cache_encouts_[iter->second];
}
|
||||
|
||||
// Run one forward pass of the DS2 model over a flattened feature chunk.
// `features` holds feat_row frames of `feature_dim` values each. The RNN
// state caches are fed as inputs 2/3 and overwritten from outputs 2/3, so
// successive calls carry streaming state. Results go into out->logprobs
// (row-major, out->vocab_dim columns).
// NOTE(review): inputs[2]/[3] look like LSTM h/c states — confirm against
// the exported model's input ordering.
void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
                             const int32& feature_dim,
                             NnetOut* out) {
    paddle_infer::Predictor* predictor = GetPredictor();

    int feat_row = features.Dim() / feature_dim;

    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();

    // feed inputs: input 0 is the (1, frames, dim) feature tensor
    std::unique_ptr<paddle_infer::Tensor> input_tensor =
        predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, feat_row, feature_dim};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(features.Data());

    // input 1: number of valid frames (int64 scalar)
    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(feat_row);
    input_len->CopyFromCpu(audio_len.data());

    // input 2: recurrent state cache, keyed by the model's own input name
    std::unique_ptr<paddle_infer::Tensor> state_h =
        predictor->GetInputHandle(input_names[2]);
    shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
    state_h->Reshape(h_cache->get_shape());
    state_h->CopyFromCpu(h_cache->get_data().data());

    // input 3: second recurrent state cache
    std::unique_ptr<paddle_infer::Tensor> state_c =
        predictor->GetInputHandle(input_names[3]);
    shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
    state_c->Reshape(c_cache->get_shape());
    state_c->CopyFromCpu(c_cache->get_data().data());

    // forward
    bool success = predictor->Run();

    if (success == false) {
        LOG(INFO) << "predictor run occurs error";
    }

    // fetch outputs: write the updated states back into the caches
    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    assert(h_cache->get_shape() == h_out->shape());
    h_out->CopyToCpu(h_cache->get_data().data());

    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    assert(c_cache->get_shape() == c_out->shape());
    c_out->CopyToCpu(c_cache->get_data().data());

    // output 0: per-frame vocabulary probabilities, shape (1, rows, cols)
    std::unique_ptr<paddle_infer::Tensor> output_tensor =
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    int32 row = output_shape[1];
    int32 col = output_shape[2];


    // inferences->Resize(row * col);
    // *inference_dim = col;
    out->logprobs.Resize(row * col);
    out->vocab_dim = col;
    output_tensor->CopyToCpu(out->logprobs.Data());

    ReleasePredictor(predictor);
}
|
||||
|
||||
} // namespace ppspeech
|
@ -1,97 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
|
||||
#include "base/common.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "nnet/nnet_itf.h"
|
||||
#include "paddle_inference_api.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
|
||||
template <typename T>
|
||||
class Tensor {
|
||||
public:
|
||||
Tensor() {}
|
||||
explicit Tensor(const std::vector<int>& shape) : _shape(shape) {
|
||||
int neml = std::accumulate(
|
||||
_shape.begin(), _shape.end(), 1, std::multiplies<int>());
|
||||
LOG(INFO) << "Tensor neml: " << neml;
|
||||
_data.resize(neml, 0);
|
||||
}
|
||||
|
||||
void reshape(const std::vector<int>& shape) {
|
||||
_shape = shape;
|
||||
int neml = std::accumulate(
|
||||
_shape.begin(), _shape.end(), 1, std::multiplies<int>());
|
||||
_data.resize(neml, 0);
|
||||
}
|
||||
|
||||
const std::vector<int>& get_shape() const { return _shape; }
|
||||
std::vector<T>& get_data() { return _data; }
|
||||
|
||||
private:
|
||||
std::vector<int> _shape;
|
||||
std::vector<T> _data;
|
||||
};
|
||||
|
||||
// Deepspeech2 acoustic model backed by the Paddle inference engine.
// Maintains a thread-safe predictor pool and per-model RNN state caches so
// FeedForward() can stream chunk by chunk.
class PaddleNnet : public NnetBase {
  public:
    explicit PaddleNnet(const ModelOptions& opts);

    // Run one forward pass over a flattened feature chunk; updates the
    // internal RNN caches and fills `out` with per-frame probabilities.
    void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
                     const int32& feature_dim,
                     NnetOut* out) override;

    // Not supported by deepspeech2; intentionally a no-op.
    void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                            float reverse_weight,
                            std::vector<float>* rescoring_score) override {
        VLOG(2) << "deepspeech2 not has AttentionRescoring.";
    }

    void Dim();

    // Reset streaming state between utterances (re-zeroes the caches).
    void Reset() override;

    // DS2 outputs probabilities, not log-probabilities.
    bool IsLogProb() override { return false; }


    // Look up a cache tensor by the model's input tensor name.
    std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
        const std::string& name);

    // (Re)build the cache tensors from opts.cache_names / opts.cache_shape.
    void InitCacheEncouts(const ModelOptions& opts);

    // No separate encoder outputs for DS2; intentionally empty.
    void EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out)
        const override {}

  private:
    // Predictor checkout/checkin; both synchronized via pool_mutex.
    paddle_infer::Predictor* GetPredictor();
    int ReleasePredictor(paddle_infer::Predictor* predictor);

    std::unique_ptr<paddle_infer::services::PredictorPool> pool;
    std::vector<bool> pool_usages;   // true = slot currently checked out
    std::mutex pool_mutex;           // guards pool_usages / the map below
    std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
    std::map<std::string, int> cache_names_idx_;
    std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;

    ModelOptions opts_;

  public:
    DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
};
|
||||
|
||||
} // namespace ppspeech
|
@ -1,142 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/common.h"
|
||||
#include "decoder/param.h"
|
||||
#include "frontend/audio/assembler.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/ds2_nnet.h"
|
||||
|
||||
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// Offline driver: read feature matrices from a Kaldi archive, run the DS2
// model chunk by chunk (simulating streaming), and write the per-frame
// probability matrices to a Kaldi archive. Exits 0 iff any utterance succeeded.
int main(int argc, char* argv[]) {
    gflags::SetUsageMessage("Usage:");
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);
    google::InstallFailureSignalHandler();
    FLAGS_logtostderr = 1;

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_rspecifier);
    kaldi::BaseFloatMatrixWriter nnet_writer(FLAGS_nnet_prob_wspecifier);
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    LOG(INFO) << "model path: " << model_graph;
    LOG(INFO) << "model param: " << model_params;

    int32 num_done = 0, num_err = 0;

    ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();

    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));

    // chunk geometry in frames: one chunk covers the receptive field plus
    // (decoder_chunk - 1) subsampled steps
    int32 chunk_size = FLAGS_receptive_field_length +
                       (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
    int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
    int32 receptive_field_length = FLAGS_receptive_field_length;
    LOG(INFO) << "chunk size (frame): " << chunk_size;
    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
    kaldi::Timer timer;
    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
        kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
        raw_data->SetDim(feature.NumCols());
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << feature.NumRows();
        LOG(INFO) << "cols: " << feature.NumCols();

        // NOTE(review): this outer row_idx is unused and shadowed by the
        // write loop's row_idx below.
        int32 row_idx = 0;
        int32 padding_len = 0;
        int32 ori_feature_len = feature.NumRows();
        // pad so the frame count lines up with a whole number of chunks
        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
            padding_len =
                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
            feature.Resize(feature.NumRows() + padding_len,
                           feature.NumCols(),
                           kaldi::kCopyData);
        }
        int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
        int32 frame_idx = 0;
        std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
            // flatten chunk_size rows into one vector for the frontend
            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
                                                          feature.NumCols());
            int32 feature_chunk_size = 0;
            if (ori_feature_len > chunk_idx * chunk_stride) {
                feature_chunk_size = std::min(
                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
            }
            // remaining real (unpadded) frames can no longer fill the
            // receptive field: stop
            if (feature_chunk_size < receptive_field_length) break;

            int32 start = chunk_idx * chunk_stride;
            for (int row_id = 0; row_id < chunk_size; ++row_id) {
                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
                    feature_chunk.Data() + row_id * feature.NumCols(),
                    feature.NumCols());
                f_chunk_tmp.CopyFromVec(tmp);
                ++start;
            }
            raw_data->Accept(feature_chunk);
            if (chunk_idx == num_chunks - 1) {
                raw_data->SetFinished();
            }
            // drain every frame probability the decodable can produce so far
            vector<kaldi::BaseFloat> prob;
            while (decodable->FrameLikelihood(frame_idx, &prob)) {
                kaldi::Vector<kaldi::BaseFloat> vec_tmp(prob.size());
                std::memcpy(vec_tmp.Data(),
                            prob.data(),
                            sizeof(kaldi::BaseFloat) * prob.size());
                prob_vec.push_back(vec_tmp);
                frame_idx++;
            }
        }
        decodable->Reset();
        if (prob_vec.size() == 0) {
            // the TokenWriter can not write empty string.
            ++num_err;
            KALDI_LOG << " the nnet prob of " << utt << " is empty";
            continue;
        }
        // stack the per-frame vectors into one matrix for the archive
        kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),
                                               prob_vec[0].Dim());
        for (int row_idx = 0; row_idx < prob_vec.size(); ++row_idx) {
            for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) {
                result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
            }
        }

        nnet_writer.Write(utt, result);
        ++num_done;
    }

    double elapsed = timer.Elapsed();
    KALDI_LOG << " cost:" << elapsed << " s";

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
|
@ -1,70 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "recognizer/recognizer.h"


namespace ppspeech {

using kaldi::BaseFloat;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;


// Wire up the streaming pipeline: raw audio -> FeaturePipeline (features) ->
// Decodable (nnet scoring, scaled by acoustic_scale) -> TLG WFST decoder.
Recognizer::Recognizer(const RecognizerResource& resource)
    : input_finished_(false) {
    const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
    feature_pipeline_.reset(new FeaturePipeline(feature_opts));

    std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));

    BaseFloat ac_scale = resource.acoustic_scale;
    decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));

    decoder_.reset(new TLGDecoder(resource.tlg_opts));
}

// Feed a chunk of raw audio samples into the feature pipeline.
void Recognizer::Accept(const Vector<BaseFloat>& waves) {
    feature_pipeline_->Accept(waves);
}

// Consume whatever features are ready and advance the decoder.
void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }

// Best path over everything decoded so far (call after SetFinished()).
std::string Recognizer::GetFinalResult() {
    return decoder_->GetFinalBestPath();
}

// Current partial hypothesis, for streaming display.
std::string Recognizer::GetPartialResult() {
    return decoder_->GetPartialResult();
}

// Signal end-of-audio so the feature pipeline can flush buffered frames.
void Recognizer::SetFinished() {
    feature_pipeline_->SetFinished();
    input_finished_ = true;
}

bool Recognizer::IsFinished() { return input_finished_; }

// Clear all per-utterance state so the next utterance starts fresh.
void Recognizer::Reset() {
    feature_pipeline_->Reset();
    decodable_->Reset();
    decoder_->Reset();
    // BUGFIX: input_finished_ was not cleared here, so once one utterance
    // had finished, IsFinished() reported true for every later utterance.
    input_finished_ = false;
}

}  // namespace ppspeech
|
@ -1,70 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor later (SGoat)

#pragma once

#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"

// Defined by the decoder flag module (decoder/param.h); weight applied to
// acoustic (nnet) scores during decoding.
DECLARE_double(acoustic_scale);

namespace ppspeech {

// Aggregated configuration for building a DS2 streaming Recognizer:
// feature-pipeline, nnet model and TLG decoder options in one bundle.
struct RecognizerResource {
    kaldi::BaseFloat acoustic_scale{1.0};  // scale on nnet scores
    FeaturePipelineOptions feature_pipeline_opts{};
    ModelOptions model_opts{};
    TLGDecoderOptions tlg_opts{};
    // CTCBeamSearchOptions beam_search_opts;

    // Build the resource from the process-wide gflags values.
    static RecognizerResource InitFromFlags() {
        RecognizerResource resource;
        resource.acoustic_scale = FLAGS_acoustic_scale;
        resource.feature_pipeline_opts =
            FeaturePipelineOptions::InitFromFlags();
        // DS2 requires the last (short) feature chunk to be zero-padded.
        resource.feature_pipeline_opts.assembler_opts.fill_zero = true;
        LOG(INFO) << "ds2 need fill zero be true: "
                  << resource.feature_pipeline_opts.assembler_opts.fill_zero;
        resource.model_opts = ModelOptions::InitFromFlags();
        resource.tlg_opts = TLGDecoderOptions::InitFromFlags();
        return resource;
    }
};

// Streaming DS2 recognizer: audio in via Accept(), hypotheses out via
// GetPartialResult()/GetFinalResult(). Call SetFinished() at end of audio
// and Reset() between utterances.
class Recognizer {
  public:
    explicit Recognizer(const RecognizerResource& resouce);
    void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
    void Decode();
    std::string GetFinalResult();
    std::string GetPartialResult();
    void SetFinished();
    bool IsFinished();
    void Reset();

  private:
    // std::shared_ptr<RecognizerResource> resource_;
    // RecognizerResource resource_;
    std::shared_ptr<FeaturePipeline> feature_pipeline_;  // audio -> features
    std::shared_ptr<Decodable> decodable_;  // features -> acoustic scores
    std::unique_ptr<TLGDecoder> decoder_;   // TLG WFST decoder
    bool input_finished_;  // set once SetFinished() has been called
};

}  // namespace ppspeech
|
@ -1,105 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decoder/param.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "recognizer/recognizer.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
|
||||
ppspeech::RecognizerResource resource =
|
||||
ppspeech::RecognizerResource::InitFromFlags();
|
||||
ppspeech::Recognizer recognizer(resource);
|
||||
|
||||
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||
FLAGS_wav_rspecifier);
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
|
||||
int sample_rate = FLAGS_sample_rate;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
LOG(INFO) << "sr: " << sample_rate;
|
||||
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
double tot_wav_duration = 0.0;
|
||||
|
||||
kaldi::Timer timer;
|
||||
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
std::string utt = wav_reader.Key();
|
||||
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
tot_wav_duration += tot_samples * 1.0 / sample_rate;
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
std::vector<kaldi::Vector<BaseFloat>> feats;
|
||||
int feature_rows = 0;
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk(i) = waveform(sample_offset + i);
|
||||
}
|
||||
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
|
||||
|
||||
recognizer.Accept(wav_chunk);
|
||||
if (cur_chunk_size < chunk_sample_size) {
|
||||
recognizer.SetFinished();
|
||||
}
|
||||
recognizer.Decode();
|
||||
|
||||
// no overlap
|
||||
sample_offset += cur_chunk_size;
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result = recognizer.GetFinalResult();
|
||||
recognizer.Reset();
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
KALDI_LOG << " the result of " << utt << " is empty";
|
||||
continue;
|
||||
}
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
result_writer.Write(utt, result);
|
||||
++num_done;
|
||||
}
|
||||
double elapsed = timer.Elapsed();
|
||||
KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done);
|
||||
KALDI_LOG << " cost:" << elapsed << " s";
|
||||
KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s";
|
||||
KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration;
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
# Build the DS2 online model smoke-test binary (single forward pass).
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

set(bin_name ds2_model_test_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
# Headers resolve relative to the speechx root and its bundled kaldi copy.
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# NOTE(review): ${DEPS} is provided by the parent build -- presumably the
# paddle inference libraries; verify against the top-level CMakeLists.
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
|
@ -1,207 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// deepspeech2 online model info
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "paddle_inference_api.h"
|
||||
|
||||
// Shorthands used by the test routines below.
using std::cout;
using std::endl;


// Paddle inference model files (graph + parameters).
DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
// Dummy-input geometry: frames per chunk and feature dimension per frame.
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
DEFINE_int32(feat_dim, 161, "feature dim");


// Fills `data` with a chunk_size x feat_dim dummy feature matrix.
void produce_data(std::vector<std::vector<float>>* data);
// Runs a single forward pass of the DS2 online model and prints the output.
void model_forward_test();
|
||||
|
||||
void produce_data(std::vector<std::vector<float>>* data) {
|
||||
int chunk_size = FLAGS_chunk_size; // chunk_size in frame
|
||||
int col_size = FLAGS_feat_dim; // feat dim
|
||||
cout << "chunk size: " << chunk_size << endl;
|
||||
cout << "feat dim: " << col_size << endl;
|
||||
|
||||
data->reserve(chunk_size);
|
||||
data->back().reserve(col_size);
|
||||
for (int row = 0; row < chunk_size; ++row) {
|
||||
data->push_back(std::vector<float>());
|
||||
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
|
||||
data->back().push_back(0.201);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void model_forward_test() {
|
||||
std::cout << "1. read the data" << std::endl;
|
||||
std::vector<std::vector<float>> feats;
|
||||
produce_data(&feats);
|
||||
|
||||
std::cout << "2. load the model" << std::endl;
|
||||
;
|
||||
std::string model_graph = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
CHECK_NE(model_graph, "");
|
||||
CHECK_NE(model_params, "");
|
||||
cout << "model path: " << model_graph << endl;
|
||||
cout << "model param path : " << model_params << endl;
|
||||
|
||||
paddle_infer::Config config;
|
||||
config.SetModel(model_graph, model_params);
|
||||
config.SwitchIrOptim(false);
|
||||
cout << "SwitchIrOptim: " << false << endl;
|
||||
config.DisableFCPadding();
|
||||
cout << "DisableFCPadding: " << endl;
|
||||
auto predictor = paddle_infer::CreatePredictor(config);
|
||||
|
||||
std::cout << "3. feat shape, row=" << feats.size()
|
||||
<< ",col=" << feats[0].size() << std::endl;
|
||||
std::vector<float> pp_input_mat;
|
||||
for (const auto& item : feats) {
|
||||
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
|
||||
}
|
||||
|
||||
std::cout << "4. fead the data to model" << std::endl;
|
||||
int row = feats.size();
|
||||
int col = feats[0].size();
|
||||
std::vector<std::string> input_names = predictor->GetInputNames();
|
||||
std::vector<std::string> output_names = predictor->GetOutputNames();
|
||||
for (auto name : input_names) {
|
||||
cout << "model input names: " << name << endl;
|
||||
}
|
||||
for (auto name : output_names) {
|
||||
cout << "model output names: " << name << endl;
|
||||
}
|
||||
|
||||
// input
|
||||
std::unique_ptr<paddle_infer::Tensor> input_tensor =
|
||||
predictor->GetInputHandle(input_names[0]);
|
||||
std::vector<int> INPUT_SHAPE = {1, row, col};
|
||||
input_tensor->Reshape(INPUT_SHAPE);
|
||||
input_tensor->CopyFromCpu(pp_input_mat.data());
|
||||
|
||||
// input length
|
||||
std::unique_ptr<paddle_infer::Tensor> input_len =
|
||||
predictor->GetInputHandle(input_names[1]);
|
||||
std::vector<int> input_len_size = {1};
|
||||
input_len->Reshape(input_len_size);
|
||||
std::vector<int64_t> audio_len;
|
||||
audio_len.push_back(row);
|
||||
input_len->CopyFromCpu(audio_len.data());
|
||||
|
||||
// state_h
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
|
||||
predictor->GetInputHandle(input_names[2]);
|
||||
std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
|
||||
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
|
||||
int chunk_state_h_box_size =
|
||||
std::accumulate(chunk_state_h_box_shape.begin(),
|
||||
chunk_state_h_box_shape.end(),
|
||||
1,
|
||||
std::multiplies<int>());
|
||||
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
|
||||
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
|
||||
|
||||
// state_c
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
|
||||
predictor->GetInputHandle(input_names[3]);
|
||||
std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
|
||||
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
|
||||
int chunk_state_c_box_size =
|
||||
std::accumulate(chunk_state_c_box_shape.begin(),
|
||||
chunk_state_c_box_shape.end(),
|
||||
1,
|
||||
std::multiplies<int>());
|
||||
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
|
||||
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
|
||||
|
||||
// run
|
||||
bool success = predictor->Run();
|
||||
|
||||
// state_h out
|
||||
std::unique_ptr<paddle_infer::Tensor> h_out =
|
||||
predictor->GetOutputHandle(output_names[2]);
|
||||
std::vector<int> h_out_shape = h_out->shape();
|
||||
int h_out_size = std::accumulate(
|
||||
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
|
||||
std::vector<float> h_out_data(h_out_size);
|
||||
h_out->CopyToCpu(h_out_data.data());
|
||||
|
||||
// stage_c out
|
||||
std::unique_ptr<paddle_infer::Tensor> c_out =
|
||||
predictor->GetOutputHandle(output_names[3]);
|
||||
std::vector<int> c_out_shape = c_out->shape();
|
||||
int c_out_size = std::accumulate(
|
||||
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
|
||||
std::vector<float> c_out_data(c_out_size);
|
||||
c_out->CopyToCpu(c_out_data.data());
|
||||
|
||||
// output tensor
|
||||
std::unique_ptr<paddle_infer::Tensor> output_tensor =
|
||||
predictor->GetOutputHandle(output_names[0]);
|
||||
std::vector<int> output_shape = output_tensor->shape();
|
||||
std::vector<float> output_probs;
|
||||
int output_size = std::accumulate(
|
||||
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
|
||||
output_probs.resize(output_size);
|
||||
output_tensor->CopyToCpu(output_probs.data());
|
||||
row = output_shape[1];
|
||||
col = output_shape[2];
|
||||
|
||||
// probs
|
||||
std::vector<std::vector<float>> probs;
|
||||
probs.reserve(row);
|
||||
for (int i = 0; i < row; i++) {
|
||||
probs.push_back(std::vector<float>());
|
||||
probs.back().reserve(col);
|
||||
|
||||
for (int j = 0; j < col; j++) {
|
||||
probs.back().push_back(output_probs[i * col + j]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> log_feat = probs;
|
||||
std::cout << "probs, row: " << log_feat.size()
|
||||
<< " col: " << log_feat[0].size() << std::endl;
|
||||
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
|
||||
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
|
||||
++col_idx) {
|
||||
std::cout << log_feat[row_idx][col_idx] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point: parse flags, set up glog, then run one model forward pass.
// The init sequence below is order-sensitive (flags are parsed before
// logging is configured), so it is documented rather than restructured.
int main(int argc, char* argv[]) {
    gflags::SetUsageMessage("Usage:");
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);
    google::InstallFailureSignalHandler();
    // Route all glog output to stderr (no log files for a smoke test).
    FLAGS_logtostderr = 1;

    model_forward_test();
    return 0;
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue