parent 1ecc345bd6
commit 2857b24ecc
@ -1,7 +0,0 @@
# Deepspeech2 Streaming ASR

## Examples

* `websocket` - Streaming ASR with websocket for deepspeech2_aishell.
* `aishell` - Streaming decoding on the AISHELL dataset, for local WER testing.
* `onnx` - Example of converting deepspeech2 to ONNX format.
@ -1,3 +0,0 @@
data
exp
aishell_*
@ -1,133 +0,0 @@
# Aishell - Deepspeech2 Streaming

> We recommend using the U2/U2++ model instead of DS2; please see [here](../../u2pp_ol/wenetspeech/).

A C++ deployment example that uses the deepspeech2 model to recognize `wav` files and compute `CER`. We use AISHELL-1 as the test data.

## Source path.sh

```bash
. path.sh
```

The SpeechX binaries are under `echo $SPEECHX_BUILD`; for more info please see `path.sh`.

## Recognize with linear feature

```bash
bash run.sh
```

`run.sh` has multiple stages; for details please see `run.sh` (a single stage can also be run on its own; see the sketch after this list):

1. download dataset, model and lm
2. convert cmvn format and compute feature
3. decode w/o lm by feature
4. decode w/ ngram lm by feature
5. decode w/ TLG graph by feature
6. recognize w/ TLG graph by wav input
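Since `run.sh` sources `utils/parse_options.sh`, a stage range can be selected from the command line. Note that the script's stage numbering is zero-based, so item N in the list above corresponds to `--stage N-1`. A minimal sketch:

```bash
# run only the ngram-lm decode (item 4 above, script stage 3)
bash run.sh --stage 3 --stop_stage 3
```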
### Recognize with `.scp` file for wav

This script uses `recognizer_main` to recognize wav files.

The input is an `scp` file, which looks like this:

```text
# head data/split1/1/aishell_test.scp
BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
...
BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
```

If you want to recognize one wav, you can make an `scp` file like this:

```text
key path/to/wav/file
```

Then specify the `--wav_rspecifier=` param for the `recognizer_main` bin. For the meaning of the other flags, please see the help:

```bash
recognizer_main --help
```
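For instance, a minimal single-utterance run might look like the sketch below; the key and wav path are placeholders, and the omitted flags (cmvn, model, graph) are the ones shown in `run.sh`:

```bash
echo "utt1 /path/to/audio.wav" > one_wav.scp
recognizer_main \
    --wav_rspecifier=scp:one_wav.scp \
    --result_wspecifier=ark,t:result.txt
# plus the --cmvn_file/--model_path/--param_path/--graph_path flags used in run.sh
```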
For an example of using `recognizer_main`, please see `run.sh`.

### CTC Prefix Beam Search w/o LM

```
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```

### CTC Prefix Beam Search w/ LM

LM: zh_giga.no_cna_cmn.prune01244.klm

```
Overall -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
Mandarin -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```

### CTC TLG WFST

LM: [aishell train](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/aishell_graph.zip)
--acoustic_scale=1.2

```
Overall -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1819
Mandarin -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1818
Other -> 0.00 % N=0 C=0 S=0 D=0 I=1
```

LM: [wenetspeech](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/wenetspeech_graph.zip)
--acoustic_scale=1.5

```
Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95
Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
```

## Recognize with fbank feature

This script is the same as `run.sh`, but uses fbank features.

```bash
bash run_fbank.sh
```

### CTC Prefix Beam Search w/o LM

```
Overall -> 10.44 % N=104765 C=94194 S=10174 D=397 I=369
Mandarin -> 10.44 % N=104762 C=94194 S=10171 D=397 I=369
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
```

### CTC Prefix Beam Search w/ LM

LM: zh_giga.no_cna_cmn.prune01244.klm

```
Overall -> 5.82 % N=104765 C=99386 S=4944 D=435 I=720
Mandarin -> 5.82 % N=104762 C=99386 S=4941 D=435 I=720
English -> 0.00 % N=0 C=0 S=0 D=0 I=0
```

### CTC TLG WFST

LM: [aishell train](https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip)
```
Overall -> 9.58 % N=104765 C=94817 S=4326 D=5622 I=84
Mandarin -> 9.57 % N=104762 C=94817 S=4325 D=5620 I=84
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
```

## Build TLG WFST graph

This script builds the TLG WFST graph. It depends on `srilm`, so please make sure it is installed.
For more information please see the script below.

```bash
bash ./local/run_build_tlg.sh
```
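Before running it, a quick way to confirm that the SRILM tools are visible (`path.sh` exports `${SRILM}/bin` onto `PATH`):

```bash
which ngram-count || echo "srilm not found; run 'make srilm.done' under \$MAIN_ROOT/tools first"
```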
@ -1,71 +0,0 @@
#!/bin/bash

# To be run from one directory above this script.
. ./path.sh

nj=40
text=data/local/lm/text
lexicon=data/local/dict/lexicon.txt

for f in "$text" "$lexicon"; do
  [ ! -f "$f" ] && echo "$0: No such file $f" && exit 1;
done

# Check SRILM tools
if ! which ngram-count > /dev/null; then
  echo "SRILM tools not found; please download and install them from: "
  echo "http://www.speech.sri.com/projects/srilm/download.html"
  echo "Then add the tools to your PATH"
  exit 1
fi

# This script takes no arguments. It assumes you have already run
# aishell_data_prep.sh.
# It takes as input the files
# data/local/lm/text
# data/local/dict/lexicon.txt
dir=data/local/lm
mkdir -p $dir

cleantext=$dir/text.no_oov

# map OOV words to <SPOKEN_NOISE>
# lexicon line: word char0 ... charn
# text line: utt word0 ... wordn -> line: <SPOKEN_NOISE> word0 ... wordn
text_dir=$(dirname $text)
split_name=$(basename $text)
./local/split_data.sh $text_dir $text $split_name $nj

utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
  cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
  {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
  \> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext

# compute word counts, sorted in descending order
# line: count word
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
  sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;

# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM -- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
  cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
  sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;

# word list with <s> </s>
cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist

# hold out sentences to compute ppl
heldout_sent=10000 # Don't change this if you want results to be comparable with kaldi_lm results

mkdir -p $dir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
  head -$heldout_sent > $dir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
  tail -n +$heldout_sent > $dir/train

ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
  -map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
ngram -lm $dir/lm.arpa -ppl $dir/heldout
@ -1,145 +0,0 @@
#!/bin/bash
set -eo pipefail

. path.sh

# Attention: the vocab below is only for this script;
# different acoustic models have different vocabs.
ckpt_dir=data/fbank_model
unit=$ckpt_dir/data/lang_char/vocab.txt # vocab file, line: char/spm_piece
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/

stage=-1
stop_stage=100
corpus=aishell
lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt
text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt

. utils/parse_options.sh

data=$PWD/data
mkdir -p $data

# output dir for results (stage 4 writes here)
exp=$PWD/exp
mkdir -p $exp

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
    if [ ! -f $data/speech.ngram.zh.tar.gz ];then
        # download ngram
        pushd $data
        wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz
        tar xvzf speech.ngram.zh.tar.gz
        popd
    fi

    if [ ! -f $ckpt_dir/data/mean_std.json ]; then
        # download model
        mkdir -p $ckpt_dir
        pushd $ckpt_dir
        wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
        tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
        popd
    fi
fi

if [ ! -f $unit ]; then
    echo "$0: No such file $unit"
    exit 1;
fi

if ! which ngram-count; then
    # srilm needs to be installed
    pushd $MAIN_ROOT/tools
    make srilm.done
    popd
fi

mkdir -p data/local/dict
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # Prepare dict
    # line: char/spm_pieces
    cp $unit data/local/dict/units.txt

    if [ ! -f $lexicon ];then
        utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon
        echo "Generated $lexicon from $text"
    fi

    # filter by vocab
    # line: word ph0 ... phn -> line: word char0 ... charn
    utils/fst/prepare_dict.py \
        --unit_file $unit \
        --in_lexicon ${lexicon} \
        --out_lexicon data/local/dict/lexicon.txt
fi

lm=data/local/lm
mkdir -p $lm

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # Train ngram lm
    cp $text $lm/text
    local/aishell_train_lms.sh
    echo "build LM done."
fi

# build TLG
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # build T & L
    utils/fst/compile_lexicon_token_fst.sh \
        data/local/dict data/local/tmp data/local/lang

    # build G & TLG
    utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
fi

aishell_wav_scp=aishell_test.scp
nj=40
cmvn=$data/cmvn_fbank.ark
wfst=$data/lang_test

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    if [ ! -d $data/test ]; then
        # download test dataset
        pushd $data
        wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
        unzip aishell_test.zip
        popd

        realpath $data/test/*/*.wav > $data/wavlist
        awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
        paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
    fi

    ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj

    # convert cmvn format
    cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
fi

wer=aishell_wer
label_file=aishell_result
export GLOG_logtostderr=1

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # recognize w/ TLG graph
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/check_tlg.log \
        recognizer_main \
            --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
            --cmvn_file=$cmvn \
            --model_path=$model_dir/avg_5.jit.pdmodel \
            --streaming_chunk=30 \
            --use_fbank=true \
            --param_path=$model_dir/avg_5.jit.pdiparams \
            --word_symbol_table=$wfst/words.txt \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --model_cache_shapes="5-1-2048,5-1-2048" \
            --graph_path=$wfst/TLG.fst --max_active=7500 \
            --acoustic_scale=1.2 \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result_check_tlg

    cat $data/split${nj}/*/result_check_tlg > $exp/${label_file}_check_tlg
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_check_tlg > $exp/${wer}.check_tlg
    echo "recognizer test has finished!!!"
    echo "please check ${exp}/${wer}.check_tlg"
fi

exit 0
@ -1,30 +0,0 @@
#!/usr/bin/env bash

set -eo pipefail

data=$1
scp=$2
split_name=$3
numsplit=$4

# Split $scp into $numsplit pieces, saved under $data/split${numsplit}/{1..$numsplit}.

if [[ ! $numsplit -gt 0 ]]; then
    echo "Invalid num-split argument";
    exit 1;
fi

directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)

# if this mkdir fails due to argument-list being too long, iterate.
if ! mkdir -p $directories >&/dev/null; then
    for n in `seq $numsplit`; do
        mkdir -p $data/split${numsplit}/$n
    done
fi

echo "utils/split_scp.pl $scp $scp_splits"
utils/split_scp.pl $scp $scp_splits
@ -1,24 +0,0 @@
# This contains the locations of the binaries built for running the examples.

MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project was built successfully"; }

export LC_ALL=C

# openfst bin & kaldi bin
KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src

# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm

SPEECHX_BIN=$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin
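After sourcing, a quick sanity check (assuming the project has been built) is that the SpeechX binaries resolve from the exported `PATH`:

```bash
. path.sh
which recognizer_main   # expected under $SPEECHX_BUILD/decoder
```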
@ -1,180 +0,0 @@
#!/bin/bash
set -x
set -e

. path.sh

nj=40
stage=0
stop_stage=100

. utils/parse_options.sh

# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data

ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/

# output
mkdir -p exp
exp=$PWD/exp

# 2. download dataset, model and lm
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
aishell_wav_scp=aishell_test.scp
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
    if [ ! -d $data/test ]; then
        # download dataset
        pushd $data
        wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
        unzip aishell_test.zip
        popd

        realpath $data/test/*/*.wav > $data/wavlist
        awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
        paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
    fi

    if [ ! -f $ckpt_dir/data/mean_std.json ]; then
        # download model
        mkdir -p $ckpt_dir
        pushd $ckpt_dir
        wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
        tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
        popd
    fi

    if [ ! -f $lm ]; then
        # download kenlm bin
        pushd $data
        wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
        popd
    fi
fi

# 3. make feature
text=$data/test/text
label_file=./aishell_result
wer=./aishell_wer

export GLOG_logtostderr=1

cmvn=$data/cmvn.ark
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # convert cmvn format and compute linear feat
    cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn

    ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj

    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
        compute_linear_spectrogram_main \
            --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
            --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
            --cmvn_file=$cmvn
    echo "feature extraction has finished!!!"
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # decode w/o lm
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
        ctc_beam_search_decoder_main \
            --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
            --model_path=$model_dir/avg_1.jit.pdmodel \
            --param_path=$model_dir/avg_1.jit.pdiparams \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --nnet_decoder_chunk=8 \
            --dict_file=$vocb_dir/vocab.txt \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result

    cat $data/split${nj}/*/result > $exp/${label_file}
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
    echo "ctc-prefix-beam-search-decoder-ol without lm has finished!!!"
    echo "please check ${exp}/${wer}"
    tail -n 7 $exp/${wer}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # decode w/ ngram lm with feature input
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
        ctc_beam_search_decoder_main \
            --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
            --model_path=$model_dir/avg_1.jit.pdmodel \
            --param_path=$model_dir/avg_1.jit.pdiparams \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --nnet_decoder_chunk=8 \
            --dict_file=$vocb_dir/vocab.txt \
            --lm_path=$lm \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm

    cat $data/split${nj}/*/result_lm > $exp/${label_file}_lm
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm
    echo "ctc-prefix-beam-search-decoder-ol with lm test has finished!!!"
    echo "please check ${exp}/${wer}.lm"
    tail -n 7 $exp/${wer}.lm
fi

wfst=$data/wfst/
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    mkdir -p $wfst
    if [ ! -f $wfst/aishell_graph.zip ]; then
        # download TLG graph
        pushd $wfst
        wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
        unzip aishell_graph.zip
        mv aishell_graph/* $wfst
        popd
    fi
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # decode w/ TLG graph with feature input
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
        ctc_tlg_decoder_main \
            --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
            --model_path=$model_dir/avg_1.jit.pdmodel \
            --param_path=$model_dir/avg_1.jit.pdiparams \
            --word_symbol_table=$wfst/words.txt \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --graph_path=$wfst/TLG.fst --max_active=7500 \
            --nnet_decoder_chunk=8 \
            --acoustic_scale=1.2 \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg

    cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg
    echo "wfst-decoder-ol has finished!!!"
    echo "please check ${exp}/${wer}.tlg"
    tail -n 7 $exp/${wer}.tlg
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    # recognize from wav file w/ TLG graph
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
        recognizer_main \
            --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
            --cmvn_file=$cmvn \
            --model_path=$model_dir/avg_1.jit.pdmodel \
            --param_path=$model_dir/avg_1.jit.pdiparams \
            --word_symbol_table=$wfst/words.txt \
            --nnet_decoder_chunk=8 \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --graph_path=$wfst/TLG.fst --max_active=7500 \
            --acoustic_scale=1.2 \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer

    cat $data/split${nj}/*/result_recognizer > $exp/${label_file}_recognizer
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer
    echo "recognizer test has finished!!!"
    echo "please check ${exp}/${wer}.recognizer"
    tail -n 7 $exp/${wer}.recognizer
fi
@ -1,177 +0,0 @@
#!/bin/bash
set +x
set -e

. path.sh

nj=40
stage=0
stop_stage=5

. utils/parse_options.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data

ckpt_dir=$data/fbank_model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/

# output
mkdir -p exp
exp=$PWD/exp

# 2. download dataset, model and lm
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
aishell_wav_scp=aishell_test.scp
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
    if [ ! -d $data/test ]; then
        # download test dataset
        pushd $data
        wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
        unzip aishell_test.zip
        popd

        realpath $data/test/*/*.wav > $data/wavlist
        awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
        paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
    fi

    if [ ! -f $ckpt_dir/data/mean_std.json ]; then
        # download model
        mkdir -p $ckpt_dir
        pushd $ckpt_dir
        wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
        tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
        popd
    fi

    if [ ! -f $lm ]; then
        # download kenlm bin
        pushd $data
        wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
        popd
    fi
fi

# 3. make feature
text=$data/test/text
label_file=./aishell_result_fbank
wer=./aishell_wer_fbank

export GLOG_logtostderr=1

cmvn=$data/cmvn_fbank.ark
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # convert cmvn format and compute fbank feat
    cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false

    ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj

    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
        compute_fbank_main \
            --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
            --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank_feat.ark,$data/split${nj}/JOB/fbank_feat.scp \
            --cmvn_file=$cmvn \
            --streaming_chunk=36
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # decode w/o lm by feature
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \
        ctc_beam_search_decoder_main \
            --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
            --model_path=$model_dir/avg_5.jit.pdmodel \
            --param_path=$model_dir/avg_5.jit.pdiparams \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --model_cache_shapes="5-1-2048,5-1-2048" \
            --nnet_decoder_chunk=8 \
            --dict_file=$vocb_dir/vocab.txt \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank

    cat $data/split${nj}/*/result_fbank > $exp/${label_file}
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
    tail -n 7 $exp/${wer}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # decode with ngram lm by feature
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \
        ctc_beam_search_decoder_main \
            --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
            --model_path=$model_dir/avg_5.jit.pdmodel \
            --param_path=$model_dir/avg_5.jit.pdiparams \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --model_cache_shapes="5-1-2048,5-1-2048" \
            --nnet_decoder_chunk=8 \
            --dict_file=$vocb_dir/vocab.txt \
            --lm_path=$lm \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm

    cat $data/split${nj}/*/fbank_result_lm > $exp/${label_file}_lm
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm
    tail -n 7 $exp/${wer}.lm
fi

wfst=$data/wfst_fbank/
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    mkdir -p $wfst
    if [ ! -f $wfst/aishell_graph2.zip ]; then
        # download TLG graph
        pushd $wfst
        wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip
        unzip aishell_graph2.zip
        mv aishell_graph2/* $wfst
        popd
    fi
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # decode w/ TLG graph by feature
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
        ctc_tlg_decoder_main \
            --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
            --model_path=$model_dir/avg_5.jit.pdmodel \
            --param_path=$model_dir/avg_5.jit.pdiparams \
            --word_symbol_table=$wfst/words.txt \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --model_cache_shapes="5-1-2048,5-1-2048" \
            --nnet_decoder_chunk=8 \
            --graph_path=$wfst/TLG.fst --max_active=7500 \
            --acoustic_scale=1.2 \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg

    cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg
    echo "wfst-decoder-ol has finished!!!"
    echo "please check ${exp}/${wer}.tlg"
    tail -n 7 $exp/${wer}.tlg
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    # recognize w/ TLG graph by wav
    utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \
        recognizer_main \
            --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
            --cmvn_file=$cmvn \
            --model_path=$model_dir/avg_5.jit.pdmodel \
            --use_fbank=true \
            --param_path=$model_dir/avg_5.jit.pdiparams \
            --word_symbol_table=$wfst/words.txt \
            --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
            --model_cache_shapes="5-1-2048,5-1-2048" \
            --nnet_decoder_chunk=8 \
            --graph_path=$wfst/TLG.fst --max_active=7500 \
            --acoustic_scale=1.2 \
            --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank_recognizer

    cat $data/split${nj}/*/result_fbank_recognizer > $exp/${label_file}_recognizer
    utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer
    echo "recognizer test has finished!!!"
    echo "please check ${exp}/${wer}.recognizer"
    tail -n 7 $exp/${wer}.recognizer
fi
@ -1 +0,0 @@
../../../../utils/
@ -1,3 +0,0 @@
data
log
exp
@ -1,100 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle

import numpy as np
import onnxruntime
import paddle


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--input_file',
        type=str,
        default="static_ds2online_inputs.pickle",
        help="aishell ds2 input data file. For wenetspeech, we only feed it to the infer model",
    )
    parser.add_argument(
        '--model_type',
        type=str,
        default="aishell",
        help="aishell (hidden size 1024) or wenetspeech (hidden size 2048)", )
    parser.add_argument(
        '--model_dir', type=str, default=".", help="paddle model dir.")
    parser.add_argument(
        '--model_prefix',
        type=str,
        default="avg_1.jit",
        help="paddle model prefix.")
    parser.add_argument(
        '--onnx_model',
        type=str,
        default='./model.old.onnx',
        help="onnx model.")

    return parser.parse_args()


if __name__ == '__main__':
    FLAGS = parse_args()

    # input and output
    with open(FLAGS.input_file, 'rb') as f:
        iodict = pickle.load(f)
        print(iodict.keys())

    audio_chunk = iodict['audio_chunk']
    audio_chunk_lens = iodict['audio_chunk_lens']
    chunk_state_h_box = iodict['chunk_state_h_box']
    chunk_state_c_box = iodict['chunk_state_c_bos']  # key name as stored in the pickle file
    print("raw state shape: ", chunk_state_c_box.shape)

    if FLAGS.model_type == 'wenetspeech':
        # the wenetspeech model uses a 2x larger rnn hidden size
        chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
        chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
    print("state shape: ", chunk_state_c_box.shape)

    # paddle
    model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
    res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
        paddle.to_tensor(audio_chunk),
        paddle.to_tensor(audio_chunk_lens),
        paddle.to_tensor(chunk_state_h_box),
        paddle.to_tensor(chunk_state_c_box), )

    # onnxruntime
    options = onnxruntime.SessionOptions()
    options.enable_profiling = True
    sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
    ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
        ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
            "audio_chunk": audio_chunk,
            "audio_chunk_lens": audio_chunk_lens,
            "chunk_state_h_box": chunk_state_h_box,
            "chunk_state_c_box": chunk_state_c_box
        })

    print(sess.end_profiling())

    # check that paddle and ort outputs match
    print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
    print(np.allclose(ort_res_lens, res_lens, atol=1e-6))

    if FLAGS.model_type == 'aishell':
        print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
        print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
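A typical invocation, mirroring how the onnx example's `run.sh` calls this check (the paths here are illustrative):

```bash
./local/infer_check.py \
    --input_file exp/static_ds2online_inputs.pickle \
    --model_type aishell \
    --model_dir data/exp/deepspeech2_online/checkpoints \
    --model_prefix avg_1.jit \
    --onnx_model exp/model.onnx
```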
@ -1,14 +0,0 @@
#!/bin/bash

# show model

if [ $# != 1 ];then
    echo "usage: $0 model_path"
    exit 1
fi

file=$1

pip install netron
netron -p 8082 --host $(hostname -i) $file
@ -1,7 +0,0 @@
#!/bin/bash

# clone onnx repos
git clone https://github.com/onnx/onnx.git
git clone https://github.com/microsoft/onnxruntime.git
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
@ -1,37 +0,0 @@
#!/usr/bin/env python3
import argparse

import onnx
from onnx import version_converter

if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog=__doc__)
    parser.add_argument(
        "--model-file", type=str, required=True, help='path/to/the/model.onnx.')
    parser.add_argument(
        "--save-model",
        type=str,
        required=True,
        help='path/to/saved/model.onnx.')
    # Models must be opset10 or higher to be quantized.
    parser.add_argument(
        "--target-opset", type=int, default=11, help='target opset version.')

    args = parser.parse_args()

    print(f"to opset: {args.target_opset}")

    # Preprocessing: load the model to be converted.
    model_path = args.model_file
    original_model = onnx.load(model_path)

    # print('The model before conversion:\n{}'.format(original_model))

    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
    # Apply the version conversion on the original model
    converted_model = version_converter.convert_version(original_model,
                                                        args.target_opset)

    # print('The model after conversion:\n{}'.format(converted_model))
    onnx.save(converted_model, args.save_model)
@ -1,20 +0,0 @@
#!/bin/bash

set -e

if [ $# != 3 ];then
    # ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
    echo "usage: $0 onnx.model.in onnx.model.out input_shape"
    exit 1
fi

# onnx optimizer
pip install onnx-simplifier

in=$1
out=$2
input_shape=$3

check_n=3

onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape
@ -1,128 +0,0 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names
import argparse
import copy
import sys

import onnx


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--output_names',
        required=True,
        nargs='+',
        help='The outputs of the pruned model.')
    parser.add_argument(
        '--save_file', required=True, help='Path to save the new onnx model.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    if len(set(args.output_names)) < len(args.output_names):
        print(
            "[ERROR] There's a duplicate name in --output_names, which is not allowed."
        )
        sys.exit(-1)

    model = onnx.load(args.model)

    # collect all node outputs and graph outputs
    output_tensor_names = set()
    for node in model.graph.node:
        for out in node.output:
            # may contain model output
            output_tensor_names.add(out)

    # for out in model.graph.output:
    #     output_tensor_names.add(out.name)

    for output_name in args.output_names:
        if output_name not in output_tensor_names:
            print(
                "[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
                format(output_name))
            sys.exit(-1)

    output_node_indices = set()  # nodes producing the requested outputs
    output_to_node = dict()  # all node outputs
    for i, node in enumerate(model.graph.node):
        for out in node.output:
            output_to_node[out] = i
            if out in args.output_names:
                output_node_indices.add(i)

    # from the outputs, find all ancestor nodes
    reserved_node_indices = copy.deepcopy(
        output_node_indices)  # nodes to keep
    reserved_inputs = set()  # model inputs to keep
    new_output_node_indices = copy.deepcopy(output_node_indices)

    while len(new_output_node_indices) > 0:
        output_node_indices = copy.deepcopy(new_output_node_indices)

        new_output_node_indices = set()

        for out_node_idx in output_node_indices:
            # backtrace to parents
            for ipt in model.graph.node[out_node_idx].input:
                if ipt in output_to_node:
                    reserved_node_indices.add(output_to_node[ipt])
                    new_output_node_indices.add(output_to_node[ipt])
                else:
                    reserved_inputs.add(ipt)

    num_inputs = len(model.graph.input)
    num_outputs = len(model.graph.output)
    num_nodes = len(model.graph.node)
    print(
        f"old graph has {num_inputs} inputs, {num_outputs} outputs, {num_nodes} nodes"
    )
    print(f"{len(reserved_node_indices)} nodes to keep.")

    # delete nodes not kept
    for idx in range(num_nodes - 1, -1, -1):
        if idx not in reserved_node_indices:
            del model.graph.node[idx]

    # delete graph inputs not kept
    for idx in range(num_inputs - 1, -1, -1):
        if model.graph.input[idx].name not in reserved_inputs:
            del model.graph.input[idx]

    # delete old graph outputs
    for i in range(num_outputs):
        del model.graph.output[0]

    # new graph outputs, as requested by the user
    for out in args.output_names:
        model.graph.output.extend([onnx.ValueInfoProto(name=out)])

    # infer shape
    try:
        from onnx_infer_shape import SymbolicShapeInference
        model = SymbolicShapeInference.infer_shapes(
            model,
            int_max=2**31 - 1,
            auto_merge=True,
            guess_output_rank=False,
            verbose=1)
    except Exception as e:
        print(f"skip infer shape step: {e}")

    # check onnx model
    onnx.checker.check_model(model)
    # save onnx model
    onnx.save(model, args.save_file)
    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
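A sketch of how this pruning script might be invoked. The script path is an assumption (file names are not shown in this diff), the output tensor names are the DS2 output names used elsewhere in this example, and the model file names are placeholders:

```bash
# hypothetical path; adjust to wherever this script is saved
./local/prune_model.py \
    --model model.onnx \
    --output_names softmax_0.tmp_0 tmp_5 concat_0.tmp_0 concat_1.tmp_0 \
    --save_file model.pruned.onnx
```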
@ -1,111 +0,0 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename nodes to new names
import argparse
import sys

import onnx


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--origin_names',
        required=True,
        nargs='+',
        help='The original names you want to modify.')
    parser.add_argument(
        '--new_names',
        required=True,
        nargs='+',
        help='The new names you want to change to; the number of new_names should be the same as the number of origin_names'
    )
    parser.add_argument(
        '--save_file', required=True, help='Path to save the new onnx model.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    if len(set(args.origin_names)) < len(args.origin_names):
        print(
            "[ERROR] There's a duplicate name in --origin_names, which is not allowed."
        )
        sys.exit(-1)

    if len(set(args.new_names)) < len(args.new_names):
        print(
            "[ERROR] There's a duplicate name in --new_names, which is not allowed."
        )
        sys.exit(-1)

    if len(args.new_names) != len(args.origin_names):
        print(
            "[ERROR] Number of --new_names must be the same as the number of --origin_names."
        )
        sys.exit(-1)

    model = onnx.load(args.model)

    # collect inputs and all node outputs
    output_tensor_names = set()
    for ipt in model.graph.input:
        output_tensor_names.add(ipt.name)

    for node in model.graph.node:
        for out in node.output:
            output_tensor_names.add(out)

    for origin_name in args.origin_names:
        if origin_name not in output_tensor_names:
            print(
                f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
            )
            sys.exit(-1)

    for new_name in args.new_names:
        if new_name in output_tensor_names:
            print(
                "[ERROR] The defined new_name '{}' already exists in the onnx model, which is not allowed.".
                format(new_name))
            sys.exit(-1)

    # rename graph inputs
    for i, ipt in enumerate(model.graph.input):
        if ipt.name in args.origin_names:
            idx = args.origin_names.index(ipt.name)
            model.graph.input[i].name = args.new_names[idx]

    # rename node inputs and outputs
    for i, node in enumerate(model.graph.node):
        for j, ipt in enumerate(node.input):
            if ipt in args.origin_names:
                idx = args.origin_names.index(ipt)
                model.graph.node[i].input[j] = args.new_names[idx]

        for j, out in enumerate(node.output):
            if out in args.origin_names:
                idx = args.origin_names.index(out)
                model.graph.node[i].output[j] = args.new_names[idx]

    # rename graph outputs
    for i, out in enumerate(model.graph.output):
        if out.name in args.origin_names:
            idx = args.origin_names.index(out.name)
            model.graph.output[i].name = args.new_names[idx]

    # check onnx model
    onnx.checker.check_model(model)

    # save model
    onnx.save(model, args.save_file)

    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
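Similarly, a hypothetical invocation of this renaming script (the script path and all tensor/file names below are placeholders):

```bash
# hypothetical path; adjust to wherever this script is saved
./local/rename_model.py \
    --model model.onnx \
    --origin_names audio_chunk \
    --new_names speech_chunk \
    --save_file model.renamed.onnx
```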
@ -1,48 +0,0 @@
#!/usr/bin/env python3
import argparse

from onnxruntime.quantization import quantize_dynamic
from onnxruntime.quantization import QuantType


def quantize_onnx_model(onnx_model_path,
                        quantized_model_path,
                        nodes_to_exclude=[]):
    print("Starting quantization...")

    quantize_dynamic(
        onnx_model_path,
        quantized_model_path,
        weight_type=QuantType.QInt8,
        nodes_to_exclude=nodes_to_exclude)

    print(f"Quantized model saved to: {quantized_model_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-in",
        type=str,
        required=True,
        help="input ONNX model", )
    parser.add_argument(
        "--model-out",
        type=str,
        required=True,
        default='model.quant.onnx',
        help="output ONNX model", )
    parser.add_argument(
        "--nodes-to-exclude",
        type=str,
        required=True,
        help="nodes to exclude from quantization, e.g. conv,linear.", )

    args = parser.parse_args()

    nodes_to_exclude = args.nodes_to_exclude.split(',')
    quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)


if __name__ == "__main__":
    main()
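Stage 3 of this example's `run.sh` invokes the script like this (node names taken from that script):

```bash
./local/ort_dyanmic_quant.py \
    --model-in exp/model.optset11.onnx \
    --model-out exp/model.optset11.quant.onnx \
    --nodes-to-exclude 'p2o.Conv.0,p2o.Conv.2'
```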
@ -1,45 +0,0 @@
#!/usr/bin/env python3
import argparse

import onnxruntime as ort

# onnxruntime optimizer.
# https://onnxruntime.ai/docs/performance/graph-optimizations.html
# https://onnxruntime.ai/docs/api/python/api_summary.html#api


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_in', required=True, type=str, help='Path to onnx model.')
    parser.add_argument(
        '--opt_level',
        required=True,
        type=int,
        default=0,
        choices=[0, 1, 2],
        help='graph optimization level.')
    parser.add_argument(
        '--model_out', required=True, help='path to save the optimized model.')
    parser.add_argument('--debug', default=False, help='output debug info.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    sess_options = ort.SessionOptions()

    # Set graph optimization level
    print(f"opt level: {args.opt_level}")
    if args.opt_level == 0:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    elif args.opt_level == 1:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    else:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # To enable model serialization after graph optimization, set this
    sess_options.optimized_model_filepath = args.model_out

    session = ort.InferenceSession(args.model_in, sess_options)
@ -1,26 +0,0 @@
#!/bin/bash

if [ $# != 4 ];then
    # local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
    echo "usage: $0 model_dir model_name param_name onnx_output_name"
    exit 1
fi

dir=$1
model=$2
param=$3
output=$4

pip install paddle2onnx
pip install onnx

# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
# opset >= 10 is required for quantization
paddle2onnx --model_dir $dir \
    --model_filename $model \
    --params_filename $param \
    --save_file $output \
    --enable_dev_version True \
    --opset_version 11 \
    --enable_onnx_checker True
@ -1,14 +0,0 @@
# This contains the locations of the binaries built for running the examples.

MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project was built successfully"; }

export LC_ALL=C

export PATH=$PATH:$TOOLS_BIN
@ -1,91 +0,0 @@
#!/bin/bash

set -e

. path.sh

stage=0
stop_stage=50
tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_10.jit
#model_prefix=avg_1.jit
model=${model_prefix}.pdmodel
param=${model_prefix}.pdiparams

. utils/parse_options.sh

data=data
exp=exp

mkdir -p $data $exp

dir=$data/exp/deepspeech2_online/checkpoints

# wenetspeech or aishell
model_type=$(echo $tarfile | cut -d '_' -f 4)

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
    test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile

    # wenetspeech ds2 model
    pushd $data
    tar zxvf $tarfile
    popd

    # ds2 model demo inputs
    pushd $exp
    wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
    popd
fi

input_file=$exp/static_ds2online_inputs.pickle
test -e $input_file

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
    # convert to onnx
    ./local/tonnx.sh $dir $model $param $exp/model.onnx

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ] ;then
    # ort graph optimization
    ./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
    # convert opset to 11
    ./local/onnx_convert_opset.py --target-opset 11 --model-file $exp/model.ort.opt.onnx --save-model $exp/model.optset11.onnx

    # quantize model
    nodes_to_exclude='p2o.Conv.0,p2o.Conv.2'
    ./local/ort_dyanmic_quant.py --model-in $exp/model.optset11.onnx --model-out $exp/model.optset11.quant.onnx --nodes-to-exclude "${nodes_to_exclude}"

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.optset11.quant.onnx
fi

# aishell rnn hidden size is 1024
# wenetspeech rnn hidden size is 2048
if [ $model_type == 'aishell' ];then
    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
elif [ $model_type == 'wenetspeech' ];then
    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048"
else
    echo "not supported: $model_type"
    exit -1
fi

if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then
    # the wenetspeech ds2 model exceeds the 2 GB protobuf limit, which will error
    # simplify the onnx model
    ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx
fi
@ -1 +0,0 @@
../../../../utils/
@ -1,2 +0,0 @@
|
|||||||
data
|
|
||||||
exp
@ -1,78 +0,0 @@
# Streaming DeepSpeech2 Server with WebSocket

This example runs a streaming DeepSpeech2 server over `websocket`. For DeepSpeech2 model training, please see [here](../../../../examples/aishell/asr0/).

The websocket protocol is the same as that of [PaddleSpeech Server](../../../../demos/streaming_asr_server/);
for details of the implementation, please see [here](../../../speechx/protocol/websocket/).

## Source path.sh

```bash
. path.sh
```

The SpeechX binaries live under `$SPEECHX_BUILD`; for more info, please see `path.sh`.
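
For example, after sourcing it you can check where the binaries landed (a sketch; the exact layout depends on your build):

```bash
. path.sh
echo $SPEECHX_BUILD
ls $SPEECHX_BUILD/protocol/websocket   # websocket_server_main, websocket_client_main
```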

## Start WebSocket Server

```bash
bash websocket_server.sh
```

The output looks like this:

```text
I1130 02:19:32.029882 12856 cmvn_json2kaldi_main.cc:39] cmvn josn path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/data/mean_std.json
I1130 02:19:32.032230 12856 cmvn_json2kaldi_main.cc:73] nframe: 907497
I1130 02:19:32.032564 12856 cmvn_json2kaldi_main.cc:85] cmvn stats have write into: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
I1130 02:19:32.032579 12856 cmvn_json2kaldi_main.cc:86] Binary: 1
I1130 02:19:32.798342 12937 feature_pipeline.h:53] cmvn file: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
I1130 02:19:32.798542 12937 feature_pipeline.h:58] dither: 0
I1130 02:19:32.798583 12937 feature_pipeline.h:60] frame shift ms: 10
I1130 02:19:32.798588 12937 feature_pipeline.h:62] feature type: linear
I1130 02:19:32.798596 12937 feature_pipeline.h:80] frame length ms: 20
I1130 02:19:32.798601 12937 feature_pipeline.h:88] subsampling rate: 4
I1130 02:19:32.798606 12937 feature_pipeline.h:90] nnet receptive filed length: 7
I1130 02:19:32.798611 12937 feature_pipeline.h:92] nnet chunk size: 1
I1130 02:19:32.798615 12937 feature_pipeline.h:94] frontend fill zeros: 0
I1130 02:19:32.798630 12937 nnet_itf.h:52] subsampling rate: 4
I1130 02:19:32.798635 12937 nnet_itf.h:54] model path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdmodel
I1130 02:19:32.798640 12937 nnet_itf.h:57] param path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdiparams
I1130 02:19:32.798643 12937 nnet_itf.h:59] DS2 param:
I1130 02:19:32.798647 12937 nnet_itf.h:61] cache names: chunk_state_h_box,chunk_state_c_box
I1130 02:19:32.798652 12937 nnet_itf.h:63] cache shape: 5-1-1024,5-1-1024
I1130 02:19:32.798656 12937 nnet_itf.h:65] input names: audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box
I1130 02:19:32.798660 12937 nnet_itf.h:67] output names: softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
I1130 02:19:32.798664 12937 ctc_tlg_decoder.h:41] fst path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//TLG.fst
I1130 02:19:32.798669 12937 ctc_tlg_decoder.h:42] fst symbole table: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//words.txt
I1130 02:19:32.798673 12937 ctc_tlg_decoder.h:47] LatticeFasterDecoder max active: 7500
I1130 02:19:32.798677 12937 ctc_tlg_decoder.h:49] LatticeFasterDecoder beam: 15
I1130 02:19:32.798681 12937 ctc_tlg_decoder.h:50] LatticeFasterDecoder lattice_beam: 7.5
I1130 02:19:32.798708 12937 websocket_server_main.cc:37] Listening at port 8082
```

## Start WebSocket Client

```bash
bash websocket_client.sh
```

This script uses AISHELL-1 test data to call the websocket server.

The input is specified by `--wav_rspecifier=scp:$data/$aishell_wav_scp`.

The `scp` file looks like this:
```text
# head data/split1/1/aishell_test.scp
BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
...
BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
```

If you want to recognize a single wav, you can make an `scp` file like this:
```text
key path/to/wav/file
```
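
For instance, a minimal sketch for recognizing one file (assuming the server started by `websocket_server.sh` is already listening, and using a hypothetical `demo.wav`):

```bash
# single-entry scp: "<utt_id> <wav_path>"
echo "demo_utt $PWD/demo.wav" > one.scp

# stream it to the running server
websocket_client_main --wav_rspecifier=scp:$PWD/one.scp --streaming_chunk=0.5
```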
@ -1,14 +0,0 @@
# This file contains the locations of the binaries built for running the examples.

SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project builds successfully"; }

export LC_ALL=C

SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket:$SPEECHX_BUILD/frontend/audio
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
@ -1,35 +0,0 @@
#!/bin/bash
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data

# output: the test-set scp
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
    pushd $data
    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
    unzip aishell_test.zip
    popd

    realpath $data/test/*/*.wav > $data/wavlist
    # utt ids are the wav basenames without the extension
    awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
    paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
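
# After this block, $data/$aishell_wav_scp contains lines like (cf. the README
# above):
#   BAC009S0764W0121	/path/to/data/test/S0764/BAC009S0764W0121.wav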

export GLOG_logtostderr=1

# websocket client
websocket_client_main \
    --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.5
@ -1,55 +0,0 @@
#!/bin/bash
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/

# 2. download the aishell ds2 online checkpoint
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
    mkdir -p $ckpt_dir
    pushd $ckpt_dir
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

export GLOG_logtostderr=1

# 3. gen cmvn
cmvn=$data/cmvn.ark
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
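# cmvn_json2kaldi_main converts the mean/std statistics saved at training time
# (mean_std.json) into a Kaldi-format cmvn archive, which the server's feature
# pipeline then reads (see the "cmvn file" line in the log above).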

# 4. download the TLG graph
wfst=$data/wfst/
mkdir -p $wfst
if [ ! -f $wfst/aishell_graph.zip ]; then
    pushd $wfst
    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
    unzip aishell_graph.zip
    mv aishell_graph/* $wfst
    popd
fi

# 5. test websocket server
websocket_server_main \
    --cmvn_file=$cmvn \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdiparams \
    --word_symbol_table=$wfst/words.txt \
    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
    --graph_path=$wfst/TLG.fst --max_active=7500 \
    --acoustic_scale=1.2
@ -1,6 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

set(bin_name ds2_model_test_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
@ -1,207 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// deepspeech2 online model info

#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>

#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h"

using std::cout;
using std::endl;

DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
DEFINE_int32(feat_dim, 161, "feature dim");

void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();

void produce_data(std::vector<std::vector<float>>* data) {
    int chunk_size = FLAGS_chunk_size;  // chunk_size in frame
    int col_size = FLAGS_feat_dim;      // feat dim
    cout << "chunk size: " << chunk_size << endl;
    cout << "feat dim: " << col_size << endl;

    // fill one chunk of dummy features with a constant value
    data->reserve(chunk_size);
    for (int row = 0; row < chunk_size; ++row) {
        data->push_back(std::vector<float>());
        data->back().reserve(col_size);
        for (int col_idx = 0; col_idx < col_size; ++col_idx) {
            data->back().push_back(0.201);
        }
    }
}

void model_forward_test() {
    std::cout << "1. read the data" << std::endl;
    std::vector<std::vector<float>> feats;
    produce_data(&feats);

    std::cout << "2. load the model" << std::endl;
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    CHECK_NE(model_graph, "");
    CHECK_NE(model_params, "");
    cout << "model path: " << model_graph << endl;
    cout << "model param path: " << model_params << endl;

    paddle_infer::Config config;
    config.SetModel(model_graph, model_params);
    config.SwitchIrOptim(false);
    cout << "SwitchIrOptim: " << false << endl;
    config.DisableFCPadding();
    cout << "DisableFCPadding." << endl;
    auto predictor = paddle_infer::CreatePredictor(config);

    std::cout << "3. feat shape, row=" << feats.size()
              << ",col=" << feats[0].size() << std::endl;
    std::vector<float> pp_input_mat;
    for (const auto& item : feats) {
        pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
    }

    std::cout << "4. feed the data to the model" << std::endl;
    int row = feats.size();
    int col = feats[0].size();
    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();
    for (auto name : input_names) {
        cout << "model input names: " << name << endl;
    }
    for (auto name : output_names) {
        cout << "model output names: " << name << endl;
    }

    // input
    std::unique_ptr<paddle_infer::Tensor> input_tensor =
        predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, row, col};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(pp_input_mat.data());

    // input length
    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(row);
    input_len->CopyFromCpu(audio_len.data());

    // state_h: 5 layers x 1 batch x 1024 hidden (aishell ds2 online model)
    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
        predictor->GetInputHandle(input_names[2]);
    std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
    int chunk_state_h_box_size =
        std::accumulate(chunk_state_h_box_shape.begin(),
                        chunk_state_h_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());

    // state_c
    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
        predictor->GetInputHandle(input_names[3]);
    std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
    int chunk_state_c_box_size =
        std::accumulate(chunk_state_c_box_shape.begin(),
                        chunk_state_c_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());

    // run
    bool success = predictor->Run();
    CHECK(success);

    // state_h out
    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    std::vector<int> h_out_shape = h_out->shape();
    int h_out_size = std::accumulate(
        h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> h_out_data(h_out_size);
    h_out->CopyToCpu(h_out_data.data());

    // state_c out
    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    std::vector<int> c_out_shape = c_out->shape();
    int c_out_size = std::accumulate(
        c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> c_out_data(c_out_size);
    c_out->CopyToCpu(c_out_data.data());

    // output tensor
    std::unique_ptr<paddle_infer::Tensor> output_tensor =
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    std::vector<float> output_probs;
    int output_size = std::accumulate(
        output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
    output_probs.resize(output_size);
    output_tensor->CopyToCpu(output_probs.data());
    row = output_shape[1];
    col = output_shape[2];

    // probs
    std::vector<std::vector<float>> probs;
    probs.reserve(row);
    for (int i = 0; i < row; i++) {
        probs.push_back(std::vector<float>());
        probs.back().reserve(col);

        for (int j = 0; j < col; j++) {
            probs.back().push_back(output_probs[i * col + j]);
        }
    }

    std::vector<std::vector<float>> log_feat = probs;
    std::cout << "probs, row: " << log_feat.size()
              << " col: " << log_feat[0].size() << std::endl;
    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
             ++col_idx) {
            std::cout << log_feat[row_idx][col_idx] << " ";
        }
        std::cout << std::endl;
    }
}

int main(int argc, char* argv[]) {
    gflags::SetUsageMessage("Usage:");
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);
    google::InstallFailureSignalHandler();
    FLAGS_logtostderr = 1;

    model_forward_test();
    return 0;
}
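
// Example invocation (a sketch; the flag values reuse the checkpoint paths
// downloaded by websocket_server.sh above):
//   ds2_model_test_main \
//     --model_path=data/model/exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel \
//     --param_path=data/model/exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams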