From acf1d27230bdeb3144dfa88da7843cb22ea0aa9c Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Fri, 30 Dec 2022 15:54:26 +0800 Subject: [PATCH] [speechx] rm ds2 && rm boost (#2786) * fix openfst download error * add acknowledgments of openfst * refactor directory * clean ctc_decoders dir * add nnet cache && make 2 thread work * do not compile websocket * rm ds2 && rm boost * rm ds2 example --- .pre-commit-config.yaml | 4 +- speechx/CMakeLists.txt | 18 - speechx/build.sh | 17 +- speechx/examples/ds2_ol/README.md | 7 - speechx/examples/ds2_ol/aishell/.gitignore | 3 - speechx/examples/ds2_ol/aishell/README.md | 133 - .../ds2_ol/aishell/local/aishell_train_lms.sh | 71 - .../ds2_ol/aishell/local/run_build_tlg.sh | 145 - .../ds2_ol/aishell/local/split_data.sh | 30 - speechx/examples/ds2_ol/aishell/path.sh | 24 - speechx/examples/ds2_ol/aishell/run.sh | 180 -- speechx/examples/ds2_ol/aishell/run_fbank.sh | 177 -- speechx/examples/ds2_ol/aishell/utils | 1 - speechx/examples/ds2_ol/onnx/.gitignore | 3 - speechx/examples/ds2_ol/onnx/README.md | 57 - .../examples/ds2_ol/onnx/local/infer_check.py | 100 - speechx/examples/ds2_ol/onnx/local/netron.sh | 14 - .../examples/ds2_ol/onnx/local/onnx_clone.sh | 7 - .../ds2_ol/onnx/local/onnx_convert_opset.py | 37 - .../ds2_ol/onnx/local/onnx_infer_shape.py | 2514 ----------------- .../examples/ds2_ol/onnx/local/onnx_opt.sh | 20 - .../ds2_ol/onnx/local/onnx_prune_model.py | 128 - .../ds2_ol/onnx/local/onnx_rename_model.py | 111 - .../ds2_ol/onnx/local/ort_dyanmic_quant.py | 48 - speechx/examples/ds2_ol/onnx/local/ort_opt.py | 45 - speechx/examples/ds2_ol/onnx/local/tonnx.sh | 26 - speechx/examples/ds2_ol/onnx/path.sh | 14 - speechx/examples/ds2_ol/onnx/run.sh | 91 - speechx/examples/ds2_ol/onnx/utils | 1 - speechx/examples/ds2_ol/websocket/.gitignore | 2 - speechx/examples/ds2_ol/websocket/README.md | 78 - speechx/examples/ds2_ol/websocket/path.sh | 14 - .../ds2_ol/websocket/websocket_client.sh | 35 - .../ds2_ol/websocket/websocket_server.sh | 55 - speechx/examples/u2pp_ol/wenetspeech/path.sh | 4 +- speechx/speechx/asr/decoder/CMakeLists.txt | 59 +- .../asr/decoder/ctc_beam_search_decoder.cc | 313 -- .../asr/decoder/ctc_beam_search_decoder.h | 73 - .../decoder/ctc_beam_search_decoder_main.cc | 167 -- .../asr/decoder/ctc_decoders/.gitignore | 9 - .../ctc_decoders/ctc_beam_search_decoder.cpp | 607 ---- .../ctc_decoders/ctc_beam_search_decoder.h | 175 -- .../ctc_decoders/ctc_greedy_decoder.cpp | 61 - .../decoder/ctc_decoders/ctc_greedy_decoder.h | 35 - .../decoder/ctc_decoders/decoder_utils.cpp | 193 -- .../asr/decoder/ctc_decoders/decoder_utils.h | 111 - .../asr/decoder/ctc_decoders/path_trie.cpp | 164 -- .../asr/decoder/ctc_decoders/path_trie.h | 82 - .../asr/decoder/ctc_decoders/scorer.cpp | 232 -- .../speechx/asr/decoder/ctc_decoders/scorer.h | 114 - .../asr/decoder/nnet_logprob_decoder_main.cc | 77 - speechx/speechx/asr/decoder/param.h | 3 +- speechx/speechx/asr/nnet/CMakeLists.txt | 24 +- speechx/speechx/asr/nnet/ds2_nnet.cc | 218 -- speechx/speechx/asr/nnet/ds2_nnet.h | 97 - speechx/speechx/asr/nnet/ds2_nnet_main.cc | 142 - speechx/speechx/asr/nnet/nnet_producer.cc | 1 - speechx/speechx/asr/recognizer/CMakeLists.txt | 50 +- speechx/speechx/asr/recognizer/recognizer.cc | 70 - speechx/speechx/asr/recognizer/recognizer.h | 70 - .../speechx/asr/recognizer/recognizer_main.cc | 105 - speechx/speechx/codelab/CMakeLists.txt | 1 - speechx/speechx/codelab/nnet/CMakeLists.txt | 6 - .../codelab/nnet/ds2_model_test_main.cc | 207 
-- .../frontend/audio/cmvn_json2kaldi_main.cc | 46 +- speechx/speechx/common/utils/picojson.h | 1202 ++++++++ 66 files changed, 1265 insertions(+), 7663 deletions(-) delete mode 100644 speechx/examples/ds2_ol/README.md delete mode 100644 speechx/examples/ds2_ol/aishell/.gitignore delete mode 100644 speechx/examples/ds2_ol/aishell/README.md delete mode 100755 speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh delete mode 100755 speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh delete mode 100755 speechx/examples/ds2_ol/aishell/local/split_data.sh delete mode 100755 speechx/examples/ds2_ol/aishell/path.sh delete mode 100755 speechx/examples/ds2_ol/aishell/run.sh delete mode 100755 speechx/examples/ds2_ol/aishell/run_fbank.sh delete mode 120000 speechx/examples/ds2_ol/aishell/utils delete mode 100644 speechx/examples/ds2_ol/onnx/.gitignore delete mode 100644 speechx/examples/ds2_ol/onnx/README.md delete mode 100755 speechx/examples/ds2_ol/onnx/local/infer_check.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/netron.sh delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_clone.sh delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_opt.sh delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/ort_opt.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/tonnx.sh delete mode 100755 speechx/examples/ds2_ol/onnx/path.sh delete mode 100755 speechx/examples/ds2_ol/onnx/run.sh delete mode 120000 speechx/examples/ds2_ol/onnx/utils delete mode 100644 speechx/examples/ds2_ol/websocket/.gitignore delete mode 100644 speechx/examples/ds2_ol/websocket/README.md delete mode 100755 speechx/examples/ds2_ol/websocket/path.sh delete mode 100755 speechx/examples/ds2_ol/websocket/websocket_client.sh delete mode 100755 speechx/examples/ds2_ol/websocket/websocket_server.sh delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/.gitignore delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.h delete mode 100644 speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet.cc delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet.h delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet_main.cc delete mode 100644 
speechx/speechx/asr/recognizer/recognizer.cc delete mode 100644 speechx/speechx/asr/recognizer/recognizer.h delete mode 100644 speechx/speechx/asr/recognizer/recognizer_main.cc delete mode 100644 speechx/speechx/codelab/nnet/CMakeLists.txt delete mode 100644 speechx/speechx/codelab/nnet/ds2_model_test_main.cc create mode 100644 speechx/speechx/common/utils/picojson.h diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15b842d55..994619478 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,13 +57,13 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ - id: cpplint name: cpplint description: Static code analysis of C/C++ files language: python files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 45bf54194..cfce63dd9 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -44,9 +44,6 @@ option(TEST_DEBUG "option for debug" OFF) option(USE_PROFILING "enable c++ profling" OFF) option(WITH_TESTING "unit test" ON) -option(USING_U2 "compile u2 model." ON) -option(USING_DS2 "compile with ds2 model." OFF) - option(USING_GPU "u2 compute on GPU." OFF) ############################################################################### @@ -56,21 +53,6 @@ include(gflags) include(glog) -# boost -# include(boost) # not work -set(boost_SOURCE_DIR ${fc_patch}/boost-src) -set(BOOST_ROOT ${boost_SOURCE_DIR}) -include_directories(${boost_SOURCE_DIR}) -link_directories(${boost_SOURCE_DIR}/stage/lib) - -# Eigen -include(eigen) -find_package(Eigen3 REQUIRED) - -# Kenlm -include(kenlm) -add_dependencies(kenlm eigen boost) - #openblas include(openblas) diff --git a/speechx/build.sh b/speechx/build.sh index 7655f9635..94d250f5a 100755 --- a/speechx/build.sh +++ b/speechx/build.sh @@ -4,20 +4,5 @@ set -xe # the build script had verified in the paddlepaddle docker image. # please follow the instruction below to install PaddlePaddle image. # https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html -boost_SOURCE_DIR=$PWD/fc_patch/boost-src -if [ ! 
-d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
-    tar xzfv boost_1_75_0.tar.gz
-    mkdir -p $PWD/fc_patch
-    mv boost_1_75_0 ${boost_SOURCE_DIR}
-    cd ${boost_SOURCE_DIR}
-    bash ./bootstrap.sh
-    ./b2
-    cd -
-    echo -e "\n"
-fi
-
-#rm -rf build
-mkdir -p build
-
-cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
+cmake -B build
 cmake --build build -j
diff --git a/speechx/examples/ds2_ol/README.md b/speechx/examples/ds2_ol/README.md
deleted file mode 100644
index d1da96cc9..000000000
--- a/speechx/examples/ds2_ol/README.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Deepspeech2 Streaming ASR
-
-## Examples
-
-* `websocket` - Streaming ASR with websocket for deepspeech2_aishell.
-* `aishell` - Streaming decoding on the AISHELL dataset, for local WER testing.
-* `onnx` - Example of converting deepspeech2 to ONNX format.
diff --git a/speechx/examples/ds2_ol/aishell/.gitignore b/speechx/examples/ds2_ol/aishell/.gitignore
deleted file mode 100644
index 68f993b47..000000000
--- a/speechx/examples/ds2_ol/aishell/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-data
-exp
-aishell_*
diff --git a/speechx/examples/ds2_ol/aishell/README.md b/speechx/examples/ds2_ol/aishell/README.md
deleted file mode 100644
index 2ee0bbca9..000000000
--- a/speechx/examples/ds2_ol/aishell/README.md
+++ /dev/null
@@ -1,133 +0,0 @@
-# Aishell - Deepspeech2 Streaming
-
-> We recommend using the U2/U2++ model instead of DS2; please see [here](../../u2pp_ol/wenetspeech/).
-
-A C++ deployment example that uses the deepspeech2 model to recognize `wav` files and compute `CER`. We use AISHELL-1 as the test data.
-
-## Source path.sh
-
-```bash
-. path.sh
-```
-
-The SpeechX binaries are under `$SPEECHX_BUILD`; for more info please see `path.sh`.
-
-## Recognize with linear feature
-
-```bash
-bash run.sh
-```
-
-`run.sh` has multiple stages; for details please see `run.sh`:
-
-1. download dataset, model and lm
-2. convert cmvn format and compute feature
-3. decode w/o lm by feature
-4. decode w/ ngram lm by feature
-5. decode w/ TLG graph by feature
-6. recognize w/ TLG graph by wav input
-
-### Recognize with `.scp` file for wav
-
-This script uses `recognizer_main` to recognize wav files.
-
-The input is an `scp` file, which looks like this:
-```text
-# head data/split1/1/aishell_test.scp
-BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
-BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
-...
-BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
-```
-
-If you want to recognize one wav, you can make an `scp` file like this:
-```text
-key path/to/wav/file
-```
-
-Then specify the `--wav_rspecifier=` param for the `recognizer_main` binary. For the meaning of the other flags, please see the help:
-```bash
-recognizer_main --help
-```
-
-For an example of using `recognizer_main`, please see `run.sh`.
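As a concrete illustration, a minimal single-utterance run could look like the sketch below; the wav, model, cmvn, and graph paths are placeholders, and the flags mirror the `recognizer_main` invocation in `run.sh`:

```bash
# build a one-line scp: "<utt_key> <wav_path>"
echo "BAC009S0764W0121 /path/to/BAC009S0764W0121.wav" > one_utt.scp

# recognize it with the TLG graph (paths below are placeholders)
recognizer_main \
    --wav_rspecifier=scp:one_utt.scp \
    --cmvn_file=data/cmvn.ark \
    --model_path=model/avg_1.jit.pdmodel \
    --param_path=model/avg_1.jit.pdiparams \
    --word_symbol_table=wfst/words.txt \
    --graph_path=wfst/TLG.fst \
    --result_wspecifier=ark,t:one_utt.result

# each result line is "<utt_key> <decoded text>"
cat one_utt.result
```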
-
-
-### CTC Prefix Beam Search w/o LM
-
-```
-Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
-Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
-Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
-```
-
-### CTC Prefix Beam Search w/ LM
-
-LM: zh_giga.no_cna_cmn.prune01244.klm
-```
-Overall -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
-Mandarin -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
-Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
-```
-
-### CTC TLG WFST
-
-LM: [aishell train](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/aishell_graph.zip)
---acoustic_scale=1.2
-```
-Overall -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1819
-Mandarin -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1818
-Other -> 0.00 % N=0 C=0 S=0 D=0 I=1
-```
-
-LM: [wenetspeech](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/wenetspeech_graph.zip)
---acoustic_scale=1.5
-```
-Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95
-Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95
-Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
-```
-
-## Recognize with fbank feature
-
-This script is the same as `run.sh`, but uses fbank features.
-
-```bash
-bash run_fbank.sh
-```
-
-### CTC Prefix Beam Search w/o LM
-
-```
-Overall -> 10.44 % N=104765 C=94194 S=10174 D=397 I=369
-Mandarin -> 10.44 % N=104762 C=94194 S=10171 D=397 I=369
-Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
-```
-
-### CTC Prefix Beam Search w/ LM
-
-LM: zh_giga.no_cna_cmn.prune01244.klm
-
-```
-Overall -> 5.82 % N=104765 C=99386 S=4944 D=435 I=720
-Mandarin -> 5.82 % N=104762 C=99386 S=4941 D=435 I=720
-English -> 0.00 % N=0 C=0 S=0 D=0 I=0
-```
-
-### CTC TLG WFST
-
-LM: [aishell train](https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip)
-```
-Overall -> 9.58 % N=104765 C=94817 S=4326 D=5622 I=84
-Mandarin -> 9.57 % N=104762 C=94817 S=4325 D=5620 I=84
-Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
-```
-
-## Build TLG WFST graph
-
-This script builds the TLG WFST graph. It depends on `srilm`, so please make sure it is installed.
-For more information please see the script below.
-
-```bash
- bash ./local/run_build_tlg.sh
-```
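For orientation, the heart of `run_build_tlg.sh` is the following three-step pipeline; the commands are taken from the script itself (shown in full below) and assume its default paths:

```bash
# 1) filter the word lexicon by the model's vocab -> char/subword lexicon
utils/fst/prepare_dict.py \
    --unit_file $unit \
    --in_lexicon $lexicon \
    --out_lexicon data/local/dict/lexicon.txt

# 2) train the 3-gram LM with srilm (ngram-count must be on PATH)
local/aishell_train_lms.sh

# 3) compile T (token) and L (lexicon), then compose with G (grammar) into TLG.fst
utils/fst/compile_lexicon_token_fst.sh data/local/dict data/local/tmp data/local/lang
utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test
```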
diff --git a/speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh b/speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh
deleted file mode 100755
index 544a1f59a..000000000
--- a/speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh
+++ /dev/null
@@ -1,71 +0,0 @@
-#!/bin/bash
-
-# To be run from one directory above this script.
-. ./path.sh
-
-nj=40
-text=data/local/lm/text
-lexicon=data/local/dict/lexicon.txt
-
-for f in "$text" "$lexicon"; do
-    [ ! -f $f ] && echo "$0: No such file $f" && exit 1;
-done
-
-# Check SRILM tools
-if ! which ngram-count > /dev/null; then
-    echo "srilm tools are not found, please download and install them from: "
-    echo "http://www.speech.sri.com/projects/srilm/download.html"
-    echo "Then add the tools to your PATH"
-    exit 1
-fi
-
-# This script takes no arguments. It assumes you have already run
-# aishell_data_prep.sh.
-# It takes as input the files
-# data/local/lm/text
-# data/local/dict/lexicon.txt
-dir=data/local/lm
-mkdir -p $dir
-
-cleantext=$dir/text.no_oov
-
-# oov to <SPOKEN_NOISE>
-# lexicon line: word char0 ... charn
-# text line: utt word0 ... wordn -> line: word0 ... wordn
-text_dir=$(dirname $text)
-split_name=$(basename $text)
-./local/split_data.sh $text_dir $text $split_name $nj
-
-utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
-    cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
-    {for(n=1; n<=NF;n++) {  if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
-    \> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
-cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext
-
-# compute word counts, sort in descending order
-# line: count word
-cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
-    sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;
-
-# Get counts from acoustic training transcripts, and add one-count
-# for each word in the lexicon (but not silence, we don't want it
-# in the LM-- we'll add it optionally later).
-cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
-    cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
-    sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;
-
-# word with <s> </s>
-cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
-
-# hold out to compute ppl
-heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results
-
-mkdir -p $dir
-cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | head -$heldout_sent > $dir/heldout
-cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | tail -n +$heldout_sent > $dir/train
-
-ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
-    -map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
-ngram -lm $dir/lm.arpa -ppl $dir/heldout
\ No newline at end of file
diff --git a/speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh b/speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh
deleted file mode 100755
index 07f47c7ea..000000000
--- a/speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh
+++ /dev/null
@@ -1,145 +0,0 @@
-#!/bin/bash
-set -eo pipefail
-
-. path.sh
-
-# Attention: the vocab below is only for this script; please replace it as needed.
-# Different acoustic models have different vocabs.
-ckpt_dir=data/fbank_model
-unit=$ckpt_dir/data/lang_char/vocab.txt # vocab file, line: char/spm_piece
-model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
-
-stage=-1
-stop_stage=100
-corpus=aishell
-lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt
-text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
-
-. utils/parse_options.sh
-
-data=$PWD/data
-mkdir -p $data
-
-if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-    if [ ! -f $data/speech.ngram.zh.tar.gz ];then
-        # download ngram
-        pushd $data
-        wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz
-        tar xvzf speech.ngram.zh.tar.gz
-        popd
-    fi
-
-    if [ ! -f $ckpt_dir/data/mean_std.json ]; then
-        # download model
-        mkdir -p $ckpt_dir
-        pushd $ckpt_dir
-        wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
-        tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
-        popd
-    fi
-fi
-
-if [ ! -f $unit ]; then
-    echo "$0: No such file $unit"
-    exit 1;
-fi
-
-if !
which ngram-count; then - # need srilm install - pushd $MAIN_ROOT/tools - make srilm.done - popd -fi - -mkdir -p data/local/dict -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # Prepare dict - # line: char/spm_pices - cp $unit data/local/dict/units.txt - - if [ ! -f $lexicon ];then - utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon - echo "Generate $lexicon from $text" - fi - - # filter by vocab - # line: word ph0 ... phn -> line: word char0 ... charn - utils/fst/prepare_dict.py \ - --unit_file $unit \ - --in_lexicon ${lexicon} \ - --out_lexicon data/local/dict/lexicon.txt -fi - -lm=data/local/lm -mkdir -p $lm - -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # Train ngram lm - cp $text $lm/text - local/aishell_train_lms.sh - echo "build LM done." -fi - -# build TLG -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # build T & L - utils/fst/compile_lexicon_token_fst.sh \ - data/local/dict data/local/tmp data/local/lang - - # build G & TLG - utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; - -fi - -aishell_wav_scp=aishell_test.scp -nj=40 -cmvn=$data/cmvn_fbank.ark -wfst=$data/lang_test - -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - if [ ! -d $data/test ]; then - # download test dataset - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp - fi - - ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj - - # convert cmvn format - cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn -fi - -wer=aishell_wer -label_file=aishell_result -export GLOG_logtostderr=1 - -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - # recognize w/ TLG graph - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/check_tlg.log \ - recognizer_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --streaming_chunk=30 \ - --use_fbank=true \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_check_tlg - - cat $data/split${nj}/*/result_check_tlg > $exp/${label_file}_check_tlg - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_check_tlg > $exp/${wer}.check_tlg - echo "recognizer test have finished!!!" - echo "please checkout in ${exp}/${wer}.check_tlg" -fi - -exit 0 diff --git a/speechx/examples/ds2_ol/aishell/local/split_data.sh b/speechx/examples/ds2_ol/aishell/local/split_data.sh deleted file mode 100755 index 2af6fc5ab..000000000 --- a/speechx/examples/ds2_ol/aishell/local/split_data.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash - -set -eo pipefail - -data=$1 -scp=$2 -split_name=$3 -numsplit=$4 - -# save in $data/split{n} -# $scp to split -# - -if [[ ! 
$numsplit -gt 0 ]]; then
-    echo "Invalid num-split argument";
-    exit 1;
-fi
-
-directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
-scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
-
-# if this mkdir fails due to argument-list being too long, iterate.
-if ! mkdir -p $directories >&/dev/null; then
-    for n in `seq $numsplit`; do
-        mkdir -p $data/split${numsplit}/$n
-    done
-fi
-
-echo "utils/split_scp.pl $scp $scp_splits"
-utils/split_scp.pl $scp $scp_splits
diff --git a/speechx/examples/ds2_ol/aishell/path.sh b/speechx/examples/ds2_ol/aishell/path.sh
deleted file mode 100755
index 6e8039350..000000000
--- a/speechx/examples/ds2_ol/aishell/path.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-# This contains the locations of the binaries built and required for running the examples.
-
-MAIN_ROOT=`realpath $PWD/../../../../`
-SPEECHX_ROOT=$PWD/../../../
-SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
-
-SPEECHX_TOOLS=$SPEECHX_ROOT/tools
-TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
-
-[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project built successfully"; }
-
-export LC_ALL=C
-
-# openfst bin & kaldi bin
-KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
-OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src
-
-# srilm
-export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
-export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
-export SRILM=${MAIN_ROOT}/tools/srilm
-
-SPEECHX_BIN=$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio
-export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin
diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh
deleted file mode 100755
index 49438cb25..000000000
--- a/speechx/examples/ds2_ol/aishell/run.sh
+++ /dev/null
@@ -1,180 +0,0 @@
-#!/bin/bash
-set -x
-set -e
-
-. path.sh
-
-nj=40
-stage=0
-stop_stage=100
-
-. utils/parse_options.sh
-
-# 1. compile
-if [ ! -d ${SPEECHX_BUILD} ]; then
-    pushd ${SPEECHX_ROOT}
-    bash build.sh
-    popd
-fi
-
-# input
-mkdir -p data
-data=$PWD/data
-
-ckpt_dir=$data/model
-model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
-vocb_dir=$ckpt_dir/data/lang_char/
-
-# output
-mkdir -p exp
-exp=$PWD/exp
-
-aishell_wav_scp=aishell_test.scp
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
-    if [ ! -d $data/test ]; then
-        # download dataset
-        pushd $data
-        wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
-        unzip aishell_test.zip
-        popd
-
-        realpath $data/test/*/*.wav > $data/wavlist
-        awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
-        paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
-    fi
-
-    if [ ! -f $ckpt_dir/data/mean_std.json ]; then
-        # download model
-        mkdir -p $ckpt_dir
-        pushd $ckpt_dir
-        wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
-        tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
-        popd
-    fi
-
-    lm=$data/zh_giga.no_cna_cmn.prune01244.klm
-    if [ ! -f $lm ]; then
-        # download kenlm bin
-        pushd $data
-        wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
-        popd
-    fi
-fi
-
-# 2. make feature
-text=$data/test/text
-label_file=./aishell_result
-wer=./aishell_wer
-
-export GLOG_logtostderr=1
-
-
-cmvn=$data/cmvn.ark
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    # 3.
convert cmvn format and compute linear feat - cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn - - ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj - - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ - compute_linear_spectrogram_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ - --cmvn_file=$cmvn \ - echo "feature make have finished!!!" -fi - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # decode w/o lm - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result - - cat $data/split${nj}/*/result > $exp/${label_file} - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer} - echo "ctc-prefix-beam-search-decoder-ol without lm has finished!!!" - echo "please checkout in ${exp}/${wer}" - tail -n 7 $exp/${wer} -fi - -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - # decode w/ ngram lm with feature input - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --lm_path=$lm \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm - - cat $data/split${nj}/*/result_lm > $exp/${label_file}_lm - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm - echo "ctc-prefix-beam-search-decoder-ol with lm test has finished!!!" - echo "please checkout in ${exp}/${wer}.lm" - tail -n 7 $exp/${wer}.lm -fi - -wfst=$data/wfst/ -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - mkdir -p $wfst - if [ ! -f $wfst/aishell_graph.zip ]; then - # download TLG graph - pushd $wfst - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip - unzip aishell_graph.zip - mv aishell_graph/* $wfst - popd - fi -fi - -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - # decoder w/ TLG graph with feature input - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ - ctc_tlg_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --nnet_decoder_chunk=8 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg - - cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg - echo "wfst-decoder-ol have finished!!!" 
- echo "please checkout in ${exp}/${wer}.tlg" - tail -n 7 $exp/${wer}.tlg -fi - -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then - # recognize from wav file w/ TLG graph - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \ - recognizer_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --nnet_decoder_chunk=8 \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer - - cat $data/split${nj}/*/result_recognizer > $exp/${label_file}_recognizer - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer - echo "recognizer test have finished!!!" - echo "please checkout in ${exp}/${wer}.recognizer" - tail -n 7 $exp/${wer}.recognizer -fi \ No newline at end of file diff --git a/speechx/examples/ds2_ol/aishell/run_fbank.sh b/speechx/examples/ds2_ol/aishell/run_fbank.sh deleted file mode 100755 index b93d6944d..000000000 --- a/speechx/examples/ds2_ol/aishell/run_fbank.sh +++ /dev/null @@ -1,177 +0,0 @@ -#!/bin/bash -set +x -set -e - -. path.sh - -nj=40 -stage=0 -stop_stage=5 - -. utils/parse_options.sh - -# 1. compile -if [ ! -d ${SPEECHX_EXAMPLES} ]; then - pushd ${SPEECHX_ROOT} - bash build.sh - popd -fi - -# input -mkdir -p data -data=$PWD/data - -ckpt_dir=$data/fbank_model -model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ -vocb_dir=$ckpt_dir/data/lang_char/ - -# output -mkdir -p exp -exp=$PWD/exp - -lm=$data/zh_giga.no_cna_cmn.prune01244.klm -aishell_wav_scp=aishell_test.scp -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then - if [ ! -d $data/test ]; then - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp - fi - - if [ ! -f $ckpt_dir/data/mean_std.json ]; then - mkdir -p $ckpt_dir - pushd $ckpt_dir - wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz - tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz - popd - fi - - if [ ! -f $lm ]; then - pushd $data - wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm - popd - fi -fi - -# 3. make feature -text=$data/test/text -label_file=./aishell_result_fbank -wer=./aishell_wer_fbank - -export GLOG_logtostderr=1 - - -cmvn=$data/cmvn_fbank.ark -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # 3. 
convert cmvn format and compute fbank feat - cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false - - ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj - - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ - compute_fbank_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank_feat.ark,$data/split${nj}/JOB/fbank_feat.scp \ - --cmvn_file=$cmvn \ - --streaming_chunk=36 -fi - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # decode w/ lm by feature - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank - - cat $data/split${nj}/*/result_fbank > $exp/${label_file} - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer} - tail -n 7 $exp/${wer} -fi - -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - # decode with ngram lm by feature - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --lm_path=$lm \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm - - cat $data/split${nj}/*/fbank_result_lm > $exp/${label_file}_lm - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm - tail -n 7 $exp/${wer}.lm -fi - -wfst=$data/wfst_fbank/ -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - mkdir -p $wfst - if [ ! -f $wfst/aishell_graph2.zip ]; then - pushd $wfst - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip - unzip aishell_graph2.zip - mv aishell_graph2/* $wfst - popd - fi -fi - -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - # decode w/ TLG graph by feature - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ - ctc_tlg_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg - - cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg - echo "wfst-decoder-ol have finished!!!" 
- echo "please checkout in ${exp}/${wer}.tlg" - tail -n 7 $exp/${wer}.tlg -fi - -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then - # recgonize w/ TLG graph by wav - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \ - recognizer_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --use_fbank=true \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank_recognizer - - cat $data/split${nj}/*/result_fbank_recognizer > $exp/${label_file}_recognizer - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer - echo "recognizer test have finished!!!" - echo "please checkout in ${exp}/${wer}.recognizer" - tail -n 7 $exp/${wer}.recognizer -fi diff --git a/speechx/examples/ds2_ol/aishell/utils b/speechx/examples/ds2_ol/aishell/utils deleted file mode 120000 index c2519a9dd..000000000 --- a/speechx/examples/ds2_ol/aishell/utils +++ /dev/null @@ -1 +0,0 @@ -../../../../utils/ \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/.gitignore b/speechx/examples/ds2_ol/onnx/.gitignore deleted file mode 100644 index f862f73e2..000000000 --- a/speechx/examples/ds2_ol/onnx/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -data -log -exp diff --git a/speechx/examples/ds2_ol/onnx/README.md b/speechx/examples/ds2_ol/onnx/README.md deleted file mode 100644 index b98b74b6f..000000000 --- a/speechx/examples/ds2_ol/onnx/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# Convert DeepSpeech2 model to ONNX format - -> We recommend using U2/U2++ model instead of DS2, please see [here](../../u2pp_ol/wenetspeech/). - -This example demonstrate converting ds2 model to ONNX fromat. - -Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct. - -The example test with these packages installed: -``` -paddle2onnx 0.9.8 # develop 62c5424e22cd93968dc831216fc9e0f0fce3d819 -paddleaudio 0.2.1 -paddlefsl 1.1.0 -paddlenlp 2.2.6 -paddlepaddle-gpu 2.2.2 -paddlespeech 0.0.0 # develop -paddlespeech-ctcdecoders 0.2.0 -paddlespeech-feat 0.1.0 -onnx 1.11.0 -onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape -onnxoptimizer 0.2.7 -onnxruntime 1.11.0 -``` - - -## Using - -``` -bash run.sh --stage 0 --stop_stage 5 -``` - -1. convert deepspeech2 model to ONNX, using Paddle2ONNX. -2. check paddleinference and onnxruntime output equal. -3. optimize onnx model -4. check paddleinference and optimized onnxruntime output equal. -5. quantize onnx model -6. check paddleinference and optimized onnxruntime output equal. - -For more details please see `run.sh`. - -## Outputs -The optimized onnx model is `exp/model.opt.onnx`, quanted model is `exp/model.optset11.quant.onnx`. 
-
-
-## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr)
-
-Machine hardware: `CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz`
-Test script: `Streaming Server`
-
-| Acoustic Model | Model Size | engine | decoding_method | ctc_weight | decoding_chunk_size | num_decoding_left_chunk | RTF |
-|:-------------:| :-----: | :-----: | :------------:| :-----: | :-----: | :-----: |:-----:|
-| deepspeech2online_wenetspeech | 659MB | inference | ctc_prefix_beam_search | - | 1 | - | 1.9108175171428279 (utts=80) |
-| deepspeech2online_wenetspeech | 659MB | onnx | ctc_prefix_beam_search | - | 1 | - | 0.5617182449999291 (utts=80) |
-| deepspeech2online_wenetspeech | 166MB | onnx quant | ctc_prefix_beam_search | - | 1 | - | 0.44507715475808385 (utts=80) |
-
-> Quantization is machine-dependent; not all machines support it. Instruction sets supported by the ONNX quant test machine:
-> Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology eagerfpu pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat umip pku ospke avx512_vnni spec_ctrl
diff --git a/speechx/examples/ds2_ol/onnx/local/infer_check.py b/speechx/examples/ds2_ol/onnx/local/infer_check.py
deleted file mode 100755
index f821baa12..000000000
--- a/speechx/examples/ds2_ol/onnx/local/infer_check.py
+++ /dev/null
@@ -1,100 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import os
-import pickle
-
-import numpy as np
-import onnxruntime
-import paddle
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument(
-        '--input_file',
-        type=str,
-        default="static_ds2online_inputs.pickle",
-        help="aishell ds2 input data file.
For wenetspeech, we only feed for infer model", - ) - parser.add_argument( - '--model_type', - type=str, - default="aishell", - help="aishell(1024) or wenetspeech(2048)", ) - parser.add_argument( - '--model_dir', type=str, default=".", help="paddle model dir.") - parser.add_argument( - '--model_prefix', - type=str, - default="avg_1.jit", - help="paddle model prefix.") - parser.add_argument( - '--onnx_model', - type=str, - default='./model.old.onnx', - help="onnx model.") - - return parser.parse_args() - - -if __name__ == '__main__': - FLAGS = parse_args() - - # input and output - with open(FLAGS.input_file, 'rb') as f: - iodict = pickle.load(f) - print(iodict.keys()) - - audio_chunk = iodict['audio_chunk'] - audio_chunk_lens = iodict['audio_chunk_lens'] - chunk_state_h_box = iodict['chunk_state_h_box'] - chunk_state_c_box = iodict['chunk_state_c_bos'] - print("raw state shape: ", chunk_state_c_box.shape) - - if FLAGS.model_type == 'wenetspeech': - chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1) - chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1) - print("state shape: ", chunk_state_c_box.shape) - - # paddle - model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix)) - res_chunk, res_lens, chunk_state_h, chunk_state_c = model( - paddle.to_tensor(audio_chunk), - paddle.to_tensor(audio_chunk_lens), - paddle.to_tensor(chunk_state_h_box), - paddle.to_tensor(chunk_state_c_box), ) - - # onnxruntime - options = onnxruntime.SessionOptions() - options.enable_profiling = True - sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options) - ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run( - ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], { - "audio_chunk": audio_chunk, - "audio_chunk_lens": audio_chunk_lens, - "chunk_state_h_box": chunk_state_h_box, - "chunk_state_c_box": chunk_state_c_box - }) - - print(sess.end_profiling()) - - # assert paddle equal ort - print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6)) - print(np.allclose(ort_res_lens, res_lens, atol=1e-6)) - - if FLAGS.model_type == 'aishell': - print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6)) - print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) diff --git a/speechx/examples/ds2_ol/onnx/local/netron.sh b/speechx/examples/ds2_ol/onnx/local/netron.sh deleted file mode 100755 index 6dd9a39c9..000000000 --- a/speechx/examples/ds2_ol/onnx/local/netron.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -# show model - -if [ $# != 1 ];then - echo "usage: $0 model_path" - exit 1 -fi - - -file=$1 - -pip install netron -netron -p 8082 --host $(hostname -i) $file \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh b/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh deleted file mode 100755 index bce22dbc8..000000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh +++ /dev/null @@ -1,7 +0,0 @@ - -#!/bin/bash - -# clone onnx repos -git clone https://github.com/onnx/onnx.git -git clone https://github.com/microsoft/onnxruntime.git -git clone https://github.com/PaddlePaddle/Paddle2ONNX.git \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py b/speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py deleted file mode 100755 index 00b5cf775..000000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -import argparse - -import onnx -from onnx import version_converter - -if 
__name__ == '__main__': - parser = argparse.ArgumentParser(prog=__doc__) - parser.add_argument( - "--model-file", type=str, required=True, help='path/to/the/model.onnx.') - parser.add_argument( - "--save-model", - type=str, - required=True, - help='path/to/saved/model.onnx.') - # Models must be opset10 or higher to be quantized. - parser.add_argument( - "--target-opset", type=int, default=11, help='path/to/the/model.onnx.') - - args = parser.parse_args() - - print(f"to opset: {args.target_opset}") - - # Preprocessing: load the model to be converted. - model_path = args.model_file - original_model = onnx.load(model_path) - - # print('The model before conversion:\n{}'.format(original_model)) - - # A full list of supported adapters can be found here: - # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21 - # Apply the version conversion on the original model - converted_model = version_converter.convert_version(original_model, - args.target_opset) - - # print('The model after conversion:\n{}'.format(converted_model)) - onnx.save(converted_model, args.save_model) diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py deleted file mode 100755 index c53e9ec92..000000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py +++ /dev/null @@ -1,2514 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# flake8: noqa -import argparse -import logging - -import numpy as np -import onnx -import sympy -from onnx import helper -from onnx import numpy_helper -from onnx import shape_inference -from packaging import version -assert version.parse(onnx.__version__) >= version.parse("1.8.0") - -logger = logging.getLogger(__name__) - - -def get_attribute(node, attr_name, default_value=None): - found = [attr for attr in node.attribute if attr.name == attr_name] - if found: - return helper.get_attribute_value(found[0]) - return default_value - - -def get_dim_from_proto(dim): - return getattr(dim, dim.WhichOneof('value')) if type( - dim.WhichOneof('value')) == str else None - - -def is_sequence(type_proto): - cls_type = type_proto.WhichOneof('value') - assert cls_type in ['tensor_type', 'sequence_type'] - return cls_type == 'sequence_type' - - -def get_shape_from_type_proto(type_proto): - assert not is_sequence(type_proto) - if type_proto.tensor_type.HasField('shape'): - return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] - else: - return None # note no shape is different from shape without dim (scalar) - - -def get_shape_from_value_info(vi): - cls_type = vi.type.WhichOneof('value') - if cls_type is None: - return None - if is_sequence(vi.type): - if 'tensor_type' == vi.type.sequence_type.elem_type.WhichOneof('value'): - return get_shape_from_type_proto(vi.type.sequence_type.elem_type) - else: - return None - else: - return get_shape_from_type_proto(vi.type) - - -def make_named_value_info(name): - vi = onnx.ValueInfoProto() - vi.name = name - return vi - - -def get_shape_from_sympy_shape(sympy_shape): - return [ - None if i is None else (int(i) if is_literal(i) else str(i)) - for i in sympy_shape - ] - - -def is_literal(dim): - return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr( - dim, 'is_number') and dim.is_number) - - -def handle_negative_axis(axis, rank): - assert axis < rank and axis >= -rank - return axis if axis >= 0 else rank + axis - - -def get_opset(mp, domain=None): - domain = domain or ['', 'onnx', 'ai.onnx'] - if 
type(domain) != list: - domain = [domain] - for opset in mp.opset_import: - if opset.domain in domain: - return opset.version - - return None - - -def as_scalar(x): - if type(x) == list: - assert len(x) == 1 - return x[0] - elif type(x) == np.ndarray: - return x.item() - else: - return x - - -def as_list(x, keep_none): - if type(x) == list: - return x - elif type(x) == np.ndarray: - return list(x) - elif keep_none and x is None: - return None - else: - return [x] - - -def sympy_reduce_product(x): - if type(x) == list: - value = sympy.Integer(1) - for v in x: - value = value * v - else: - value = x - return value - - -class SymbolicShapeInference: - def __init__(self, - int_max, - auto_merge, - guess_output_rank, - verbose, - prefix=''): - self.dispatcher_ = { - 'Add': - self._infer_symbolic_compute_ops, - 'ArrayFeatureExtractor': - self._infer_ArrayFeatureExtractor, - 'AveragePool': - self._infer_Pool, - 'BatchNormalization': - self._infer_BatchNormalization, - 'Cast': - self._infer_Cast, - 'CategoryMapper': - self._infer_CategoryMapper, - 'Compress': - self._infer_Compress, - 'Concat': - self._infer_Concat, - 'ConcatFromSequence': - self._infer_ConcatFromSequence, - 'Constant': - self._infer_Constant, - 'ConstantOfShape': - self._infer_ConstantOfShape, - 'Conv': - self._infer_Conv, - 'CumSum': - self._pass_on_shape_and_type, - 'Div': - self._infer_symbolic_compute_ops, - 'Einsum': - self._infer_Einsum, - 'Expand': - self._infer_Expand, - 'Equal': - self._infer_symbolic_compute_ops, - 'Floor': - self._infer_symbolic_compute_ops, - 'Gather': - self._infer_Gather, - 'GatherElements': - self._infer_GatherElements, - 'GatherND': - self._infer_GatherND, - 'Gelu': - self._pass_on_shape_and_type, - 'If': - self._infer_If, - 'Loop': - self._infer_Loop, - 'MatMul': - self._infer_MatMul, - 'MatMulInteger16': - self._infer_MatMulInteger, - 'MaxPool': - self._infer_Pool, - 'Max': - self._infer_symbolic_compute_ops, - 'Min': - self._infer_symbolic_compute_ops, - 'Mul': - self._infer_symbolic_compute_ops, - 'NonMaxSuppression': - self._infer_NonMaxSuppression, - 'NonZero': - self._infer_NonZero, - 'OneHot': - self._infer_OneHot, - 'Pad': - self._infer_Pad, - 'Range': - self._infer_Range, - 'Reciprocal': - self._pass_on_shape_and_type, - 'ReduceSum': - self._infer_ReduceSum, - 'ReduceProd': - self._infer_ReduceProd, - 'Reshape': - self._infer_Reshape, - 'Resize': - self._infer_Resize, - 'Round': - self._pass_on_shape_and_type, - 'Scan': - self._infer_Scan, - 'ScatterElements': - self._infer_ScatterElements, - 'SequenceAt': - self._infer_SequenceAt, - 'SequenceInsert': - self._infer_SequenceInsert, - 'Shape': - self._infer_Shape, - 'Size': - self._infer_Size, - 'Slice': - self._infer_Slice, - 'SoftmaxCrossEntropyLoss': - self._infer_SoftmaxCrossEntropyLoss, - 'SoftmaxCrossEntropyLossInternal': - self._infer_SoftmaxCrossEntropyLoss, - 'NegativeLogLikelihoodLossInternal': - self._infer_SoftmaxCrossEntropyLoss, - 'Split': - self._infer_Split, - 'SplitToSequence': - self._infer_SplitToSequence, - 'Squeeze': - self._infer_Squeeze, - 'Sub': - self._infer_symbolic_compute_ops, - 'Tile': - self._infer_Tile, - 'TopK': - self._infer_TopK, - 'Transpose': - self._infer_Transpose, - 'Unsqueeze': - self._infer_Unsqueeze, - 'Where': - self._infer_symbolic_compute_ops, - 'ZipMap': - self._infer_ZipMap, - 'Neg': - self._infer_symbolic_compute_ops, - # contrib ops: - 'Attention': - self._infer_Attention, - 'BiasGelu': - self._infer_BiasGelu, - 'EmbedLayerNormalization': - self._infer_EmbedLayerNormalization, - 'FastGelu': 
- self._infer_FastGelu, - 'Gelu': - self._infer_Gelu, - 'LayerNormalization': - self._infer_LayerNormalization, - 'LongformerAttention': - self._infer_LongformerAttention, - 'PythonOp': - self._infer_PythonOp, - 'SkipLayerNormalization': - self._infer_SkipLayerNormalization - } - self.aten_op_dispatcher_ = { - 'aten::embedding': self._infer_Gather, - 'aten::bitwise_or': self._infer_aten_bitwise_or, - 'aten::diagonal': self._infer_aten_diagonal, - 'aten::max_pool2d_with_indices': self._infer_aten_pool2d, - 'aten::multinomial': self._infer_aten_multinomial, - 'aten::unfold': self._infer_aten_unfold, - 'aten::argmax': self._infer_aten_argmax, - 'aten::avg_pool2d': self._infer_aten_pool2d, - 'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d, - 'aten::binary_cross_entropy_with_logits': self._infer_aten_bce, - 'aten::numpy_T': self._infer_Transpose, - } - self.run_ = True - self.suggested_merge_ = {} - self.symbolic_dims_ = {} - self.input_symbols_ = {} - self.auto_merge_ = auto_merge - self.guess_output_rank_ = guess_output_rank - self.verbose_ = verbose - self.int_max_ = int_max - self.subgraph_id_ = 0 - self.prefix_ = prefix - - def _add_suggested_merge(self, symbols, apply=False): - assert all([(type(s) == str and s in self.symbolic_dims_) or - is_literal(s) for s in symbols]) - symbols = set(symbols) - for k, v in self.suggested_merge_.items(): - if k in symbols: - symbols.remove(k) - symbols.add(v) - map_to = None - # if there is literal, map to it first - for s in symbols: - if is_literal(s): - map_to = s - break - # when no literals, map to input symbolic dims, then existing symbolic dims - if map_to is None: - for s in symbols: - if s in self.input_symbols_: - map_to = s - break - if map_to is None: - for s in symbols: - if type(self.symbolic_dims_[s]) == sympy.Symbol: - map_to = s - break - # when nothing to map to, use the shorter one - if map_to is None: - if self.verbose_ > 0: - logger.warning( - 'Potential unsafe merge between symbolic expressions: ({})'. 
- format(','.join(symbols))) - symbols_list = list(symbols) - lens = [len(s) for s in symbols_list] - map_to = symbols_list[lens.index(min(lens))] - symbols.remove(map_to) - - for s in symbols: - if s == map_to: - continue - if is_literal(map_to) and is_literal(s): - assert int(map_to) == int(s) - self.suggested_merge_[s] = int(map_to) if is_literal( - map_to) else map_to - for k, v in self.suggested_merge_.items(): - if v == s: - self.suggested_merge_[k] = map_to - if apply and self.auto_merge_: - self._apply_suggested_merge() - - def _apply_suggested_merge(self, graph_input_only=False): - if not self.suggested_merge_: - return - for i in list(self.out_mp_.graph.input) + ( - [] if graph_input_only else list(self.out_mp_.graph.value_info)): - for d in i.type.tensor_type.shape.dim: - if d.dim_param in self.suggested_merge_: - v = self.suggested_merge_[d.dim_param] - if is_literal(v): - d.dim_value = int(v) - else: - d.dim_param = v - - def _preprocess(self, in_mp): - self.out_mp_ = onnx.ModelProto() - self.out_mp_.CopyFrom(in_mp) - self.graph_inputs_ = dict( - [(i.name, i) for i in list(self.out_mp_.graph.input)]) - self.initializers_ = dict( - [(i.name, i) for i in self.out_mp_.graph.initializer]) - self.known_vi_ = dict( - [(i.name, i) for i in list(self.out_mp_.graph.input)]) - self.known_vi_.update( - dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type, - list(i.dims))) - for i in self.out_mp_.graph.initializer])) - - def _merge_symbols(self, dims): - if not all([type(d) == str for d in dims]): - if self.auto_merge_: - unique_dims = list(set(dims)) - is_int = [is_literal(d) for d in unique_dims] - assert sum( - is_int - ) <= 1 # if there are more than 1 unique ints, something is wrong - if sum(is_int) == 1: - int_dim = is_int.index(1) - if self.verbose_ > 0: - logger.debug('dim {} has been merged with value {}'. 
- format(unique_dims[:int_dim] + unique_dims[ - int_dim + 1:], unique_dims[int_dim])) - self._check_merged_dims(unique_dims, allow_broadcast=False) - return unique_dims[int_dim] - else: - if self.verbose_ > 0: - logger.debug('dim {} has been mergd with dim {}'.format( - unique_dims[1:], unique_dims[0])) - return dims[0] - else: - return None - if all([d == dims[0] for d in dims]): - return dims[0] - merged = [ - self.suggested_merge_[d] if d in self.suggested_merge_ else d - for d in dims - ] - if all([d == merged[0] for d in merged]): - assert merged[0] in self.symbolic_dims_ - return merged[0] - else: - return None - - # broadcast from right to left, and merge symbolic dims if needed - def _broadcast_shapes(self, shape1, shape2): - new_shape = [] - rank1 = len(shape1) - rank2 = len(shape2) - new_rank = max(rank1, rank2) - for i in range(new_rank): - dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1 - dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1 - if dim1 == 1 or dim1 == dim2: - new_dim = dim2 - elif dim2 == 1: - new_dim = dim1 - else: - new_dim = self._merge_symbols([dim1, dim2]) - if not new_dim: - # warning about unsupported broadcast when not auto merge - # note that auto merge has the risk of incorrectly merge symbols while one of them being 1 - # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b' - if self.auto_merge_: - self._add_suggested_merge([dim1, dim2], apply=True) - else: - logger.warning('unsupported broadcast between ' + str( - dim1) + ' ' + str(dim2)) - new_shape = [new_dim] + new_shape - return new_shape - - def _get_shape(self, node, idx): - name = node.input[idx] - if name in self.known_vi_: - vi = self.known_vi_[name] - return get_shape_from_value_info(vi) - else: - assert name in self.initializers_ - return list(self.initializers_[name].dims) - - def _get_shape_rank(self, node, idx): - return len(self._get_shape(node, idx)) - - def _get_sympy_shape(self, node, idx): - sympy_shape = [] - for d in self._get_shape(node, idx): - if type(d) == str: - sympy_shape.append(self.symbolic_dims_[d] if d in - self.symbolic_dims_ else sympy.Symbol( - d, integer=True, nonnegative=True)) - else: - assert None != d - sympy_shape.append(d) - return sympy_shape - - def _get_value(self, node, idx): - name = node.input[idx] - assert name in self.sympy_data_ or name in self.initializers_ - return self.sympy_data_[ - name] if name in self.sympy_data_ else numpy_helper.to_array( - self.initializers_[name]) - - def _try_get_value(self, node, idx): - if idx >= len(node.input): - return None - name = node.input[idx] - if name in self.sympy_data_ or name in self.initializers_: - return self._get_value(node, idx) - return None - - def _update_computed_dims(self, new_sympy_shape): - for i, new_dim in enumerate(new_sympy_shape): - if not is_literal(new_dim) and not type(new_dim) == str: - str_dim = str(new_dim) - if str_dim in self.suggested_merge_: - if is_literal(self.suggested_merge_[str_dim]): - continue # no need to create dim for literals - new_sympy_shape[i] = self.symbolic_dims_[ - self.suggested_merge_[str_dim]] - else: - # add new_dim if it's a computational expression - if not str(new_dim) in self.symbolic_dims_: - self.symbolic_dims_[str(new_dim)] = new_dim - - def _onnx_infer_single_node(self, node): - # skip onnx shape inference for some ops, as they are handled in _infer_* - skip_infer = node.op_type in [ - 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', 'Attention', - 'BiasGelu', 'EmbedLayerNormalization', 'FastGelu', 'Gelu', - 
'LayerNormalization', 'LongformerAttention', - 'SkipLayerNormalization', 'PythonOp' - ] - - if not skip_infer: - # Only pass initializers that satisfy the following condition: - # (1) Operator need value of some input for shape inference. - # For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output. - # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec. - # (3) The initializer is not in graph input. The means the node input is "constant" in inference. - initializers = [] - if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']: - initializers = [ - self.initializers_[name] for name in node.input - if (name in self.initializers_ and name not in - self.graph_inputs_) - ] - - # run single node inference with self.known_vi_ shapes - tmp_graph = helper.make_graph( - [node], 'tmp', [self.known_vi_[i] for i in node.input if i], - [make_named_value_info(i) for i in node.output], initializers) - - self.tmp_mp_.graph.CopyFrom(tmp_graph) - - self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) - - for i_o in range(len(node.output)): - o = node.output[i_o] - vi = self.out_mp_.graph.value_info.add() - if not skip_infer: - vi.CopyFrom(self.tmp_mp_.graph.output[i_o]) - else: - vi.name = o - self.known_vi_[o] = vi - - def _onnx_infer_subgraph(self, - node, - subgraph, - use_node_input=True, - inc_subgraph_id=True): - if self.verbose_ > 2: - logger.debug( - 'Inferencing subgraph of node {} with output({}...): {}'.format( - node.name, node.output[0], node.op_type)) - # node inputs are not passed directly to the subgraph - # it's up to the node dispatcher to prepare subgraph input - # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape - # besides, inputs in subgraph could shadow implicit inputs - subgraph_inputs = set( - [i.name for i in list(subgraph.initializer) + list(subgraph.input)]) - subgraph_implicit_input = set([ - name for name in self.known_vi_.keys() - if not name in subgraph_inputs - ]) - tmp_graph = helper.make_graph( - list(subgraph.node), 'tmp', - list(subgraph.input) + - [self.known_vi_[i] for i in subgraph_implicit_input], - [make_named_value_info(i.name) for i in subgraph.output]) - tmp_graph.initializer.extend([ - i for i in self.out_mp_.graph.initializer - if i.name in subgraph_implicit_input - ]) - tmp_graph.initializer.extend(subgraph.initializer) - self.tmp_mp_.graph.CopyFrom(tmp_graph) - - symbolic_shape_inference = SymbolicShapeInference( - self.int_max_, - self.auto_merge_, - self.guess_output_rank_, - self.verbose_, - prefix=self.prefix_ + '_' + str(self.subgraph_id_)) - if inc_subgraph_id: - self.subgraph_id_ += 1 - - all_shapes_inferred = False - symbolic_shape_inference._preprocess(self.tmp_mp_) - symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() - while symbolic_shape_inference.run_: - all_shapes_inferred = symbolic_shape_inference._infer_impl( - self.sympy_data_.copy()) - symbolic_shape_inference._update_output_from_vi() - if use_node_input: - # if subgraph uses node input, it needs to update to merged dims - subgraph.ClearField('input') - subgraph.input.extend( - symbolic_shape_inference.out_mp_.graph.input[:len(node.input)]) - subgraph.ClearField('output') - subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) - subgraph.ClearField('value_info') - subgraph.value_info.extend( - symbolic_shape_inference.out_mp_.graph.value_info) - subgraph.ClearField('node') - 
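# [Illustrative sketch, not part of the patch] The temporary-graph trick used
# by _onnx_infer_single_node above can be reproduced standalone: wrap a single
# node in a throwaway graph and let stock onnx shape inference annotate the
# output. Tensor names and dims below are made up:
import onnx
from onnx import TensorProto, helper, shape_inference

node = helper.make_node('Relu', ['x'], ['y'])
tmp_graph = helper.make_graph(
    [node], 'tmp',
    [helper.make_tensor_value_info('x', TensorProto.FLOAT, ['batch', 16])],
    [helper.make_tensor_value_info('y', TensorProto.FLOAT, None)])
tmp_model = shape_inference.infer_shapes(helper.make_model(tmp_graph))
print(tmp_model.graph.output[0])  # y: FLOAT, 'batch' x 16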
subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) - # for new symbolic dims from subgraph output, add to main graph symbolic dims - subgraph_shapes = [ - get_shape_from_value_info(o) - for o in symbolic_shape_inference.out_mp_.graph.output - ] - subgraph_new_symbolic_dims = set([ - d for s in subgraph_shapes - if s for d in s if type(d) == str and not d in self.symbolic_dims_ - ]) - new_dims = {} - for d in subgraph_new_symbolic_dims: - assert d in symbolic_shape_inference.symbolic_dims_ - new_dims[d] = symbolic_shape_inference.symbolic_dims_[d] - self.symbolic_dims_.update(new_dims) - return symbolic_shape_inference - - def _get_int_values(self, node, broadcast=False): - values = [self._try_get_value(node, i) for i in range(len(node.input))] - if all([v is not None for v in values]): - # some shape compute is in floating point, cast to int for sympy - for i, v in enumerate(values): - if type(v) != np.ndarray: - continue - if len(v.shape) > 1: - new_v = None # ignore value for rank > 1 - elif len(v.shape) == 0: - new_v = int(v.item()) - else: - assert len(v.shape) == 1 - new_v = [int(vv) for vv in v] - values[i] = new_v - values_len = [len(v) if type(v) == list else 0 for v in values] - max_len = max(values_len) - if max_len >= 1 and broadcast: - # broadcast - for i, v in enumerate(values): - if v is None: - continue # don't broadcast if value is unknown - if type(v) == list: - if len(v) < max_len: - values[i] = v * max_len - else: - assert len(v) == max_len - else: - values[i] = [v] * max_len - return values - - def _compute_on_sympy_data(self, node, op_func): - assert len(node.output) == 1 - values = self._get_int_values(node, broadcast=True) - if all([v is not None for v in values]): - is_list = [type(v) == list for v in values] - as_list = any(is_list) - if as_list: - self.sympy_data_[node.output[ - 0]] = [op_func(vs) for vs in zip(*values)] - else: - self.sympy_data_[node.output[0]] = op_func(values) - - def _pass_on_sympy_data(self, node): - assert len( - node. 
- input) == 1 or node.op_type in ['Reshape', 'Unsqueeze', 'Squeeze'] - self._compute_on_sympy_data(node, lambda x: x[0]) - - def _pass_on_shape_and_type(self, node): - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - self._get_shape(node, 0))) - - def _new_symbolic_dim(self, prefix, dim): - new_dim = '{}_d{}'.format(prefix, dim) - if new_dim in self.suggested_merge_: - v = self.suggested_merge_[new_dim] - new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v - else: - new_symbolic_dim = sympy.Symbol( - new_dim, integer=True, nonnegative=True) - self.symbolic_dims_[new_dim] = new_symbolic_dim - return new_symbolic_dim - - def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): - return self._new_symbolic_dim('{}{}_{}_o{}_'.format( - node.op_type, self.prefix_, - list(self.out_mp_.graph.node).index(node), out_idx), dim) - - def _new_symbolic_shape(self, rank, node, out_idx=0): - return [ - self._new_symbolic_dim_from_output(node, out_idx, i) - for i in range(rank) - ] - - def _compute_conv_pool_shape(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - if len(node.input) > 1: - W_shape = self._get_sympy_shape(node, 1) - rank = len(W_shape) - 2 # number of spatial axes - kernel_shape = W_shape[-rank:] - sympy_shape[1] = W_shape[0] - else: - W_shape = None - kernel_shape = get_attribute(node, 'kernel_shape') - rank = len(kernel_shape) - - assert len(sympy_shape) == rank + 2 - - # only need to symbolic shape inference if input has symbolic dims in spatial axes - is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]] - - if not any(is_symbolic_dims): - shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) - if len(shape) > 0: - assert len(sympy_shape) == len(shape) - sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] - return sympy_shape - - dilations = get_attribute(node, 'dilations', [1] * rank) - strides = get_attribute(node, 'strides', [1] * rank) - effective_kernel_shape = [(k - 1) * d + 1 - for k, d in zip(kernel_shape, dilations)] - pads = get_attribute(node, 'pads') - if pads is None: - pads = [0] * (2 * rank) - auto_pad = get_attribute(node, 'auto_pad', - b'NOTSET').decode('utf-8') - if auto_pad != 'VALID' and auto_pad != 'NOTSET': - try: - residual = [ - sympy.Mod(d, s) - for d, s in zip(sympy_shape[-rank:], strides) - ] - total_pads = [ - max(0, (k - s) if r == 0 else (k - r)) - for k, s, r in zip(effective_kernel_shape, strides, - residual) - ] - except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational - total_pads = [ - max(0, (k - s)) - for k, s in zip(effective_kernel_shape, strides) - ] # assuming no residual if sympy throws error - elif auto_pad == 'VALID': - total_pads = [] - else: - total_pads = [0] * rank - else: - assert len(pads) == 2 * rank - total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] - - ceil_mode = get_attribute(node, 'ceil_mode', 0) - for i in range(rank): - effective_input_size = sympy_shape[-rank + i] - if len(total_pads) > 0: - effective_input_size = effective_input_size + total_pads[i] - if ceil_mode: - strided_kernel_positions = sympy.ceiling( - (effective_input_size - effective_kernel_shape[i]) / - strides[i]) - else: - strided_kernel_positions = ( - effective_input_size - effective_kernel_shape[i] - ) // strides[i] - sympy_shape[-rank + i] = strided_kernel_positions + 1 - return sympy_shape - - def _check_merged_dims(self, dims, 
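# [Illustrative sketch, not part of the patch] The spatial arithmetic in
# _compute_conv_pool_shape above is the standard convolution output-size
# formula; a scalar version for one axis (hypothetical helper):
import math

def conv_out_size(in_size, kernel, stride=1, pad_total=0, dilation=1, ceil_mode=False):
    effective_kernel = (kernel - 1) * dilation + 1
    positions = (in_size + pad_total - effective_kernel) / stride
    positions = math.ceil(positions) if ceil_mode else math.floor(positions)
    return positions + 1

# e.g. a 7x7 / stride-2 conv with 3+3 padding maps 224 -> 112
assert conv_out_size(224, 7, stride=2, pad_total=6) == 112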
allow_broadcast=True): - if allow_broadcast: - dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] - if not all([d == dims[0] for d in dims]): - self._add_suggested_merge(dims, apply=True) - - def _compute_matmul_shape(self, node, output_dtype=None): - lhs_shape = self._get_shape(node, 0) - rhs_shape = self._get_shape(node, 1) - lhs_rank = len(lhs_shape) - rhs_rank = len(rhs_shape) - lhs_reduce_dim = 0 - rhs_reduce_dim = 0 - assert lhs_rank > 0 and rhs_rank > 0 - if lhs_rank == 1 and rhs_rank == 1: - new_shape = [] - elif lhs_rank == 1: - rhs_reduce_dim = -2 - new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]] - elif rhs_rank == 1: - lhs_reduce_dim = -1 - new_shape = lhs_shape[:lhs_reduce_dim] - else: - lhs_reduce_dim = -1 - rhs_reduce_dim = -2 - new_shape = self._broadcast_shapes( - lhs_shape[:-2], - rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]] - # merge reduce dim - self._check_merged_dims( - [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], - allow_broadcast=False) - if output_dtype is None: - # infer output_dtype from input type when not specified - output_dtype = self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, - new_shape)) - - def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): - ''' - update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches - ''' - dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence( - dst_type) else dst_type.tensor_type - src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence( - src_type) else src_type.tensor_type - if dst_tensor_type.elem_type != src_tensor_type.elem_type: - node_id = node.name if node.name else node.op_type - raise ValueError( - f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " - f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " - f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" - ) - if dst_tensor_type.HasField('shape'): - for di, ds in enumerate( - zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): - if ds[0] != ds[1]: - # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type - # for sequence_type, clear the dimension - new_dim = onnx.TensorShapeProto.Dimension() - if not is_sequence(dst_type): - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, out_idx, - di)) - dst_tensor_type.shape.dim[di].CopyFrom(new_dim) - else: - dst_tensor_type.CopyFrom(src_tensor_type) - - def _infer_ArrayFeatureExtractor(self, node): - data_shape = self._get_shape(node, 0) - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, data_shape[:-1] + - indices_shape)) - - def _infer_symbolic_compute_ops(self, node): - funcs = { - 'Add': - lambda l: l[0] + l[1], - 'Div': - lambda l: l[0] // l[1], # integer div in sympy - 'Equal': - lambda l: l[0] == l[1], - 'Floor': - lambda l: sympy.floor(l[0]), - 'Max': - lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), - 'Min': - lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), - 'Mul': - 
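# [Illustrative sketch, not part of the patch] _compute_matmul_shape above
# mirrors numpy.matmul semantics: a 1-D operand drops its reduce dim, batch
# dims broadcast, and the two reduce dims must agree. A quick numpy cross-check:
import numpy as np

assert np.matmul(np.zeros((5, 1, 3, 4)), np.zeros((7, 4, 2))).shape == (5, 7, 3, 2)
assert np.matmul(np.zeros(4), np.zeros((4, 2))).shape == (2,)
assert np.matmul(np.zeros((3, 4)), np.zeros(4)).shape == (3,)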
lambda l: l[0] * l[1],
-            'Sub':
-            lambda l: l[0] - l[1],
-            'Where':
-            lambda l: l[1] if l[0] else l[2],
-            'Neg':
-            lambda l: -l[0]
-        }
-        assert node.op_type in funcs
-        self._compute_on_sympy_data(node, funcs[node.op_type])
-
-    def _infer_Cast(self, node):
-        self._pass_on_sympy_data(node)
-
-    def _infer_CategoryMapper(self, node):
-        input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
-        if input_type == onnx.TensorProto.STRING:
-            output_type = onnx.TensorProto.INT64
-        else:
-            output_type = onnx.TensorProto.STRING
-        vi = self.known_vi_[node.output[0]]
-        vi.CopyFrom(
-            helper.make_tensor_value_info(node.output[0], output_type,
-                                          self._get_shape(node, 0)))
-
-    def _infer_Compress(self, node):
-        input_shape = self._get_shape(node, 0)
-        # create a new symbolic dimension for Compress output
-        compress_len = str(self._new_symbolic_dim_from_output(node))
-        axis = get_attribute(node, 'axis')
-        if axis is None:
-            # when axis is not specified, input is flattened before compress so output is 1D
-            output_shape = [compress_len]
-        else:
-            output_shape = input_shape
-            output_shape[handle_negative_axis(axis, len(
-                input_shape))] = compress_len
-        vi = self.known_vi_[node.output[0]]
-        vi.CopyFrom(
-            helper.make_tensor_value_info(node.output[0], self.known_vi_[
-                node.input[0]].type.tensor_type.elem_type, output_shape))
-
-    def _infer_Concat(self, node):
-        if any([
-                i in self.sympy_data_ or i in self.initializers_
-                for i in node.input
-        ]):
-            values = self._get_int_values(node)
-            if all([v is not None for v in values]):
-                axis = get_attribute(node, 'axis')
-                if axis < 0:
-                    axis = axis + len(values[0])
-                assert 0 == axis
-                self.sympy_data_[node.output[0]] = []
-                for i in range(len(node.input)):
-                    value = values[i]
-                    if type(value) == list:
-                        self.sympy_data_[node.output[0]].extend(value)
-                    else:
-                        self.sympy_data_[node.output[0]].append(value)
-
-        sympy_shape = self._get_sympy_shape(node, 0)
-        axis = handle_negative_axis(
-            get_attribute(node, 'axis'), len(sympy_shape))
-        for i_idx in range(1, len(node.input)):
-            input_shape = self._get_sympy_shape(node, i_idx)
-            if input_shape:
-                sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
-        self._update_computed_dims(sympy_shape)
-        # merge symbolic dims for non-concat axes
-        for d in range(len(sympy_shape)):
-            if d == axis:
-                continue
-            dims = [
-                self._get_shape(node, i_idx)[d]
-                for i_idx in range(len(node.input))
-                if self._get_shape(node, i_idx)
-            ]
-            if all([d == dims[0] for d in dims]):
-                continue
-            merged = self._merge_symbols(dims)
-            if type(merged) == str:
-                sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
-            else:
-                sympy_shape[d] = merged
-        vi = self.known_vi_[node.output[0]]
-        vi.CopyFrom(
-            helper.make_tensor_value_info(
-                node.output[0], self.known_vi_[node.input[0]].type.tensor_type.
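# [Illustrative sketch, not part of the patch] The op table above folds shape
# arithmetic symbolically; with sympy the expressions simplify on
# construction, so computed dims stay compact. Symbol names are made up:
import sympy

seq = sympy.Symbol('seq', integer=True, nonnegative=True)
add = lambda l: l[0] + l[1]
div = lambda l: l[0] // l[1]  # integer division, matching 'Div' above
assert add([seq, 1]) - 1 == seq
assert div([sympy.Integer(7), 2]) == 3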
- elem_type, get_shape_from_sympy_shape(sympy_shape))) - - def _infer_ConcatFromSequence(self, node): - seq_shape = self._get_shape(node, 0) - new_axis = 1 if get_attribute(node, 'new_axis') else 0 - axis = handle_negative_axis( - get_attribute(node, 'axis'), len(seq_shape) + new_axis) - concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) - new_shape = seq_shape - if new_axis: - new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:] - else: - new_shape[axis] = concat_dim - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[0]] - .type.sequence_type.elem_type.tensor_type.elem_type, new_shape)) - - def _infer_Constant(self, node): - t = get_attribute(node, 'value') - self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) - - def _infer_ConstantOfShape(self, node): - sympy_shape = self._get_int_values(node)[0] - vi = self.known_vi_[node.output[0]] - if sympy_shape is not None: - if type(sympy_shape) != list: - sympy_shape = [sympy_shape] - self._update_computed_dims(sympy_shape) - # update sympy data if output type is int, and shape is known - if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all( - [is_literal(x) for x in sympy_shape]): - self.sympy_data_[node.output[0]] = np.ones( - [int(x) for x in sympy_shape], - dtype=np.int64) * numpy_helper.to_array( - get_attribute(node, 'value', 0)) - else: - # create new dynamic shape - # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length - sympy_shape = self._new_symbolic_shape( - self._get_shape(node, 0)[0], node) - - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) - - def _infer_Conv(self, node): - sympy_shape = self._compute_conv_pool_shape(node) - self._update_computed_dims(sympy_shape) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) - - def _infer_Einsum(self, node): - # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 - equation = get_attribute(node, 'equation') - equation = equation.replace(b' ', b'') - mid_index = equation.find(b'->') - left_equation = equation[:mid_index] if mid_index != -1 else equation - - num_operands = 0 - num_ellipsis = 0 - num_ellipsis_indices = 0 - - letter_to_dim = {} - - terms = left_equation.split(b',') - for term in terms: - ellipsis_index = term.find(b'...') - shape = self._get_shape(node, num_operands) - rank = len(shape) - if ellipsis_index != -1: - if num_ellipsis == 0: - num_ellipsis_indices = rank - len(term) + 3 - num_ellipsis = num_ellipsis + 1 - for i in range(1, rank + 1): - letter = term[-i] - if letter != 46: # letter != b'.' - dim = shape[-i] - if letter not in letter_to_dim.keys(): - letter_to_dim[letter] = dim - elif type(dim) != sympy.Symbol: - letter_to_dim[letter] = dim - num_operands = num_operands + 1 - - new_sympy_shape = [] - from collections import OrderedDict - num_letter_occurrences = OrderedDict() - if mid_index != -1: - right_equation = equation[mid_index + 2:] - right_ellipsis_index = right_equation.find(b'...') - if right_ellipsis_index != -1: - for i in range(num_ellipsis_indices): - new_sympy_shape.append(shape[i]) - for c in right_equation: - if c != 46: # c != b'.' 
- new_sympy_shape.append(letter_to_dim[c]) - else: - for i in range(num_ellipsis_indices): - new_sympy_shape.append(shape[i]) - for c in left_equation: - if c != 44 and c != 46: # c != b',' and c != b'.': - if c in num_letter_occurrences: - num_letter_occurrences[c] = num_letter_occurrences[ - c] + 1 - else: - num_letter_occurrences[c] = 1 - for key, value in num_letter_occurrences.items(): - if value == 1: - new_sympy_shape.append(letter_to_dim[key]) - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, - new_sympy_shape)) - - def _infer_Expand(self, node): - expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) - if expand_to_shape is not None: - # new_shape's dim can come from shape value - self._update_computed_dims(expand_to_shape) - shape = self._get_shape(node, 0) - new_shape = self._broadcast_shapes( - shape, get_shape_from_sympy_shape(expand_to_shape)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, new_shape)) - - def _infer_Gather(self, node): - data_shape = self._get_shape(node, 0) - axis = handle_negative_axis( - get_attribute(node, 'axis', 0), len(data_shape)) - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, data_shape[:axis] + - indices_shape + data_shape[axis + - 1:])) - # for 1D input, do some sympy compute - if node.input[0] in self.sympy_data_ and len( - data_shape) == 1 and 0 == get_attribute(node, 'axis', 0): - idx = self._try_get_value(node, 1) - if idx is not None: - data = self.sympy_data_[node.input[0]] - if type(data) == list: - if type(idx) == np.ndarray and len(idx.shape) == 1: - self.sympy_data_[node.output[ - 0]] = [data[int(i)] for i in idx] - else: - self.sympy_data_[node.output[0]] = data[int(idx)] - else: - assert idx == 0 or idx == -1 - self.sympy_data_[node.output[0]] = data - - def _infer_GatherElements(self, node): - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, indices_shape)) - - def _infer_GatherND(self, node): - data_shape = self._get_shape(node, 0) - data_rank = len(data_shape) - indices_shape = self._get_shape(node, 1) - indices_rank = len(indices_shape) - last_index_dimension = indices_shape[-1] - assert is_literal( - last_index_dimension) and last_index_dimension <= data_rank - new_shape = indices_shape[:-1] + data_shape[last_index_dimension:] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, new_shape)) - - def _infer_If(self, node): - # special case for constant condition, in case there are mismatching shape from the non-executed branch - subgraphs = [ - get_attribute(node, 'then_branch'), get_attribute(node, - 'else_branch') - ] - cond = self._try_get_value(node, 0) - if cond is not None: - if as_scalar(cond) > 0: - subgraphs[1].CopyFrom(subgraphs[0]) - else: - subgraphs[0].CopyFrom(subgraphs[1]) - - for i_sub, subgraph in enumerate(subgraphs): - subgraph_infer = self._onnx_infer_subgraph( - node, subgraph, use_node_input=False) - for i_out in 
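# [Illustrative sketch, not part of the patch] _infer_Gather above splices
# the indices shape into the data shape at the gather axis:
# out = data[:axis] + indices + data[axis+1:]. numpy.take follows the same rule:
import numpy as np

data = np.zeros((2, 3, 5))
indices = np.zeros((4, 6), dtype=np.int64)
assert np.take(data, indices, axis=1).shape == (2, 4, 6, 5)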
range(len(node.output)): - vi = self.known_vi_[node.output[i_out]] - if i_sub == 0: - vi.CopyFrom(subgraph.output[i_out]) - vi.name = node.output[i_out] - else: - self._fuse_tensor_type(node, i_out, vi.type, - subgraph.output[i_out].type) - - # pass on sympy data from subgraph, if cond is constant - if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else - 1): - if subgraph.output[ - i_out].name in subgraph_infer.sympy_data_: - self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[ - subgraph.output[i_out].name] - - def _infer_Loop(self, node): - subgraph = get_attribute(node, 'body') - assert len(subgraph.input) == len(node.input) - num_loop_carried = len( - node.input) - 2 # minus the length and initial loop condition - # when sequence_type is used as loop carried input - # needs to run subgraph infer twice if the tensor shape in sequence contains None - for i, si in enumerate(subgraph.input): - si_name = si.name - si.CopyFrom(self.known_vi_[node.input[i]]) - si.name = si_name - - self._onnx_infer_subgraph(node, subgraph) - - # check subgraph input/output for shape changes in loop carried variables - # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a) - # for sequence_type, propagate from output to input - need_second_infer = False - for i_out in range(1, num_loop_carried + 1): - so = subgraph.output[i_out] - so_shape = get_shape_from_value_info(so) - if is_sequence(so.type): - if so_shape and None in so_shape: - # copy shape from output to input - # note that loop input is [loop_len, cond, input_0, input_1, ...] - # while loop output is [cond, output_0, output_1, ...] - subgraph.input[i_out + - 1].type.sequence_type.elem_type.CopyFrom( - so.type.sequence_type.elem_type) - need_second_infer = True - else: - si = subgraph.input[i_out + 1] - si_shape = get_shape_from_value_info(si) - for di, dims in enumerate(zip(si_shape, so_shape)): - if dims[0] != dims[1]: - new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, i_out, di)) - si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - need_second_infer = True - - if need_second_infer: - if self.verbose_ > 2: - logger.debug( - "Rerun Loop: {}({}...), because of sequence in loop carried variables". 
- format(node.name, node.output[0])) - self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) - - # create a new symbolic dimension for iteration dependent dimension - loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) - for i in range(len(node.output)): - vi = self.known_vi_[node.output[i]] - vi.CopyFrom(subgraph.output[ - i + - 1]) # first subgraph output is condition, not in node output - if i >= num_loop_carried: - assert not is_sequence( - vi.type) # TODO: handle loop accumulation in sequence_type - subgraph_vi_dim = subgraph.output[i + - 1].type.tensor_type.shape.dim - vi.type.tensor_type.shape.ClearField('dim') - vi_dim = vi.type.tensor_type.shape.dim - vi_dim.add().dim_param = loop_iter_dim - vi_dim.extend(list(subgraph_vi_dim)) - vi.name = node.output[i] - - def _infer_MatMul(self, node): - self._compute_matmul_shape(node) - - def _infer_MatMulInteger(self, node): - self._compute_matmul_shape(node, onnx.TensorProto.INT32) - - def _infer_NonMaxSuppression(self, node): - selected = str(self._new_symbolic_dim_from_output(node)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], onnx.TensorProto.INT64, [selected, 3])) - - def _infer_NonZero(self, node): - input_rank = self._get_shape_rank(node, 0) - # create a new symbolic dimension for NonZero output - nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) - - def _infer_OneHot(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - depth = self._try_get_value(node, 1) - axis = get_attribute(node, 'axis', -1) - axis = handle_negative_axis(axis, len(sympy_shape) + 1) - new_shape = get_shape_from_sympy_shape(sympy_shape[:axis] + [ - self._new_symbolic_dim_from_output(node) - if not is_literal(depth) else depth - ] + sympy_shape[axis:]) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[2]].type.tensor_type.elem_type, new_shape)) - - def _infer_Pad(self, node): - if get_opset(self.out_mp_) <= 10: - pads = get_attribute(node, 'pads') - else: - pads = self._try_get_value(node, 1) - - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - - if pads is not None: - assert len(pads) == 2 * rank - new_sympy_shape = [ - d + pad_up + pad_down - for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[ - rank:]) - ] - self._update_computed_dims(new_sympy_shape) - else: - # dynamic pads, create new symbolic dimensions - new_sympy_shape = self._new_symbolic_shape(rank, node) - output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))) - - def _infer_Pool(self, node): - sympy_shape = self._compute_conv_pool_shape(node) - self._update_computed_dims(sympy_shape) - for o in node.output: - if not o: - continue - vi = self.known_vi_[o] - vi.CopyFrom( - helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape( - sympy_shape))) - - def _infer_aten_bitwise_or(self, node): - shape0 = self._get_shape(node, 0) - shape1 = self._get_shape(node, 1) - new_shape = self._broadcast_shapes(shape0, shape1) - t0 = self.known_vi_[node.input[0]] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - 
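# [Illustrative sketch, not part of the patch] _infer_Pad above adds the
# begin and end pads per axis; the ONNX pads vector is laid out as
# [x1_begin, x2_begin, ..., x1_end, x2_end, ...]:
def pad_shape(shape, pads):
    rank = len(shape)
    assert len(pads) == 2 * rank
    return [d + b + e for d, b, e in zip(shape, pads[:rank], pads[rank:])]

assert pad_shape([3, 4], [1, 0, 1, 2]) == [5, 6]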
helper.make_tensor_value_info(node.output[ - 0], t0.type.tensor_type.elem_type, new_shape)) - - def _infer_aten_diagonal(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - offset = self._try_get_value(node, 1) - dim1 = self._try_get_value(node, 2) - dim2 = self._try_get_value(node, 3) - - assert offset is not None and dim1 is not None and dim2 is not None - dim1 = handle_negative_axis(dim1, rank) - dim2 = handle_negative_axis(dim2, rank) - - new_shape = [] - for dim, val in enumerate(sympy_shape): - if dim not in [dim1, dim2]: - new_shape.append(val) - - shape1 = sympy_shape[dim1] - shape2 = sympy_shape[dim2] - if offset >= 0: - diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset)) - else: - diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2)) - new_shape.append(diag_shape) - - if node.output[0]: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - new_shape))) - - def _infer_aten_multinomial(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - assert rank in [1, 2] - num_samples = self._try_get_value(node, 1) - di = rank - 1 - last_dim = num_samples if num_samples else str( - self._new_symbolic_dim_from_output(node, 0, di)) - output_shape = sympy_shape[:-1] + [last_dim] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, - get_shape_from_sympy_shape(output_shape))) - - def _infer_aten_pool2d(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - assert len(sympy_shape) == 4 - sympy_shape[-2:] = [ - self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3] - ] - self._update_computed_dims(sympy_shape) - for i, o in enumerate(node.output): - if not o: - continue - vi = self.known_vi_[o] - elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[ - node.input[0]].type.tensor_type.elem_type - vi.CopyFrom( - helper.make_tensor_value_info( - o, elem_type, get_shape_from_sympy_shape(sympy_shape))) - - def _infer_aten_unfold(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - dimension = self._try_get_value(node, 1) - size = self._try_get_value(node, 2) - step = self._try_get_value(node, 3) - if dimension is not None and size is not None and step is not None: - assert dimension < len(sympy_shape) - sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1 - sympy_shape.append(size) - else: - rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape(rank + 1, node) - self._update_computed_dims(sympy_shape) - if node.output[0]: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - sympy_shape))) - - def _infer_aten_argmax(self, node): - new_shape = None - if node.input[1] == '': - # The argmax of the flattened input is returned. 
- new_shape = [] - else: - dim = self._try_get_value(node, 1) - keepdim = self._try_get_value(node, 2) - if keepdim is not None: - sympy_shape = self._get_sympy_shape(node, 0) - if dim is not None: - dim = handle_negative_axis(dim, len(sympy_shape)) - if keepdim: - sympy_shape[dim] = 1 - else: - del sympy_shape[dim] - else: - rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape(rank if keepdim else - rank - 1, node) - self._update_computed_dims(sympy_shape) - new_shape = get_shape_from_sympy_shape(sympy_shape) - if node.output[0] and new_shape is not None: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], onnx.TensorProto.INT64, new_shape)) - - def _infer_aten_bce(self, node): - reduction = self._try_get_value(node, 4) - if reduction is None: - reduction = 1 - elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - if reduction == 0: - vi.type.tensor_type.elem_type = elem_type - vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) - else: - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, elem_type, - self._get_shape(node, 0))) - - def _infer_BatchNormalization(self, node): - self._propagate_shape_and_type(node) - - # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop - for i in [1, 2, 3, 4]: - if i < len(node.output) and node.output[i] != "": - # all of these parameters have the same shape as the 1st input - self._propagate_shape_and_type( - node, input_index=1, output_index=i) - - def _infer_Range(self, node): - vi = self.known_vi_[node.output[0]] - input_data = self._get_int_values(node) - if all([i is not None for i in input_data]): - start = as_scalar(input_data[0]) - limit = as_scalar(input_data[1]) - delta = as_scalar(input_data[2]) - new_sympy_shape = [ - sympy.Max(sympy.ceiling((limit - start) / delta), 0) - ] - else: - new_sympy_shape = [self._new_symbolic_dim_from_output(node)] - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[0]].type.tensor_type. 
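# [Illustrative sketch, not part of the patch] _infer_Range above sizes the
# output as max(ceil((limit - start) / delta), 0), which matches Python's
# range semantics for either sign of delta:
import math

def range_len(start, limit, delta):
    return max(math.ceil((limit - start) / delta), 0)

assert range_len(0, 10, 3) == len(range(0, 10, 3)) == 4
assert range_len(10, 0, -3) == len(range(10, 0, -3)) == 4
assert range_len(5, 5, 1) == 0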
- elem_type, get_shape_from_sympy_shape(new_sympy_shape))) - - def _infer_ReduceSum(self, node): - keep_dims = get_attribute(node, 'keepdims', 1) - if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: - # ReduceSum changes axes to input[1] in opset 13 - axes = self._try_get_value(node, 1) - vi = self.known_vi_[node.output[0]] - if axes is None: - assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape( - self._get_shape_rank(node, 0), node)))) - else: - shape = self._get_shape(node, 0) - output_shape = [] - axes = [handle_negative_axis(a, len(shape)) for a in axes] - for i, d in enumerate(shape): - if i in axes: - if keep_dims: - output_shape.append(1) - else: - output_shape.append(d) - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type, output_shape)) - - def _infer_ReduceProd(self, node): - axes = get_attribute(node, 'axes') - keep_dims = get_attribute(node, 'keepdims', 1) - if keep_dims == 0 and axes == [0]: - data = self._get_int_values(node)[0] - if data is not None: - self.sympy_data_[node.output[0]] = sympy_reduce_product(data) - - def _infer_Reshape(self, node): - shape_value = self._try_get_value(node, 1) - vi = self.known_vi_[node.output[0]] - if shape_value is None: - shape_shape = self._get_shape(node, 1) - assert len(shape_shape) == 1 - shape_rank = shape_shape[0] - assert is_literal(shape_rank) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape(shape_rank, node)))) - else: - input_sympy_shape = self._get_sympy_shape(node, 0) - total = int(1) - for d in input_sympy_shape: - total = total * d - new_sympy_shape = [] - deferred_dim_idx = -1 - non_deferred_size = int(1) - for i, d in enumerate(shape_value): - if type(d) == sympy.Symbol: - new_sympy_shape.append(d) - elif d == 0: - new_sympy_shape.append(input_sympy_shape[i]) - non_deferred_size = non_deferred_size * input_sympy_shape[i] - else: - new_sympy_shape.append(d) - if d == -1: - deferred_dim_idx = i - elif d != 0: - non_deferred_size = non_deferred_size * d - - assert new_sympy_shape.count(-1) < 2 - if -1 in new_sympy_shape: - new_dim = total // non_deferred_size - new_sympy_shape[deferred_dim_idx] = new_dim - - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - - self._pass_on_sympy_data(node) - - def _infer_Resize(self, node): - vi = self.known_vi_[node.output[0]] - input_sympy_shape = self._get_sympy_shape(node, 0) - if get_opset(self.out_mp_) <= 10: - scales = self._try_get_value(node, 1) - if scales is not None: - new_sympy_shape = [ - sympy.simplify(sympy.floor(d * s)) - for d, s in zip(input_sympy_shape, scales) - ] - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - else: - roi = self._try_get_value(node, 1) - scales = self._try_get_value(node, 2) - sizes = self._try_get_value(node, 3) - if sizes is not None: - new_sympy_shape = [ - sympy.simplify(sympy.floor(s)) for s in sizes - ] - 
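# [Illustrative sketch, not part of the patch] _infer_Reshape above resolves
# the ONNX Reshape conventions: 0 copies the input dim at that position, and
# a single -1 is deferred to total_size / product(other dims). A literal-dim
# version (hypothetical helper):
def resolve_reshape(input_shape, target):
    total = 1
    for d in input_shape:
        total *= d
    out, known = [], 1
    for i, d in enumerate(target):
        out.append(input_shape[i] if d == 0 else d)
        if out[-1] != -1:
            known *= out[-1]
    assert out.count(-1) < 2
    if -1 in out:
        out[out.index(-1)] = total // known
    return out

assert resolve_reshape([2, 3, 4], [0, -1]) == [2, 12]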
self._update_computed_dims(new_sympy_shape) - elif scales is not None: - rank = len(scales) - if get_attribute(node, 'coordinate_transformation_mode' - ) == 'tf_crop_and_resize': - assert len(roi) == 2 * rank - roi_start = list(roi)[:rank] - roi_end = list(roi)[rank:] - else: - roi_start = [0] * rank - roi_end = [1] * rank - scales = list(scales) - new_sympy_shape = [ - sympy.simplify(sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in zip(input_sympy_shape, - roi_start, roi_end, scales) - ] - self._update_computed_dims(new_sympy_shape) - else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node) - - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - new_sympy_shape))) - - def _infer_Scan(self, node): - subgraph = get_attribute(node, 'body') - num_scan_inputs = get_attribute(node, 'num_scan_inputs') - scan_input_axes = get_attribute(node, 'scan_input_axes', - [0] * num_scan_inputs) - num_scan_states = len(node.input) - num_scan_inputs - scan_input_axes = [ - handle_negative_axis( - ax, self._get_shape_rank(node, i + num_scan_states)) - for i, ax in enumerate(scan_input_axes) - ] - # We may have cases where the subgraph has optionial inputs that appear in both subgraph's input and initializer, - # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. - assert len(subgraph.input) >= len(node.input) - subgraph_inputs = subgraph.input[:len(node.input)] - for i, si in enumerate(subgraph_inputs): - subgraph_name = si.name - si.CopyFrom(self.known_vi_[node.input[i]]) - if i >= num_scan_states: - scan_input_dim = si.type.tensor_type.shape.dim - scan_input_dim.remove( - scan_input_dim[scan_input_axes[i - num_scan_states]]) - si.name = subgraph_name - self._onnx_infer_subgraph(node, subgraph) - num_scan_outputs = len(node.output) - num_scan_states - scan_output_axes = get_attribute(node, 'scan_output_axes', - [0] * num_scan_outputs) - scan_input_dim = get_shape_from_type_proto( - self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] - for i, o in enumerate(node.output): - vi = self.known_vi_[o] - if i >= num_scan_states: - shape = get_shape_from_type_proto(subgraph.output[i].type) - new_dim = handle_negative_axis( - scan_output_axes[i - num_scan_states], len(shape) + 1) - shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] - vi.CopyFrom( - helper.make_tensor_value_info(o, subgraph.output[ - i].type.tensor_type.elem_type, shape)) - else: - vi.CopyFrom(subgraph.output[i]) - vi.name = o - - def _infer_ScatterElements(self, node): - data_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, data_shape)) - - def _infer_SequenceAt(self, node): - # need to create new symbolic dimension if sequence shape has None: - seq_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[0]] - if seq_shape is not None: - for di, d in enumerate(seq_shape): - if d is not None: - continue - new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, 0, di)) - vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - - def _infer_SequenceInsert(self, node): - # workaround bug in onnx's shape inference - vi_seq = self.known_vi_[node.input[0]] - vi_tensor = self.known_vi_[node.input[1]] - vi_out_seq = 
self.known_vi_[node.output[0]] - vi_out_seq.CopyFrom(vi_seq) - vi_out_seq.name = node.output[0] - self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) - - def _infer_Shape(self, node): - self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) - - def _infer_Size(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) - self.known_vi_[node.output[0]].CopyFrom( - helper.make_tensor_value_info(node.output[0], - onnx.TensorProto.INT64, [])) - - def _infer_Slice(self, node): - def less_equal(x, y): - try: - return bool(x <= y) - except TypeError: - pass - try: - return bool(y >= x) - except TypeError: - pass - try: - return bool(-x >= -y) - except TypeError: - pass - try: - return bool(-y <= -x) - except TypeError: - # the last attempt; this may raise TypeError - return bool(y - x >= 0) - - def handle_negative_index(index, bound): - """ normalizes a negative index to be in [0, bound) """ - try: - if not less_equal(0, index): - if is_literal(index) and index <= -self.int_max_: - # this case is handled separately - return index - return bound + index - except TypeError: - logger.warning("Cannot determine if {} < 0".format(index)) - return index - - if get_opset(self.out_mp_) <= 9: - axes = get_attribute(node, 'axes') - starts = get_attribute(node, 'starts') - ends = get_attribute(node, 'ends') - if not axes: - axes = list(range(len(starts))) - steps = [1] * len(axes) - else: - starts = as_list(self._try_get_value(node, 1), keep_none=True) - ends = as_list(self._try_get_value(node, 2), keep_none=True) - axes = self._try_get_value(node, 3) - steps = self._try_get_value(node, 4) - if axes is None and not (starts is None and ends is None): - axes = list( - range(0, len(starts if starts is not None else ends))) - if steps is None and not (starts is None and ends is None): - steps = [1] * len(starts if starts is not None else ends) - axes = as_list(axes, keep_none=True) - steps = as_list(steps, keep_none=True) - - new_sympy_shape = self._get_sympy_shape(node, 0) - if starts is None or ends is None: - if axes is None: - for i in range(len(new_sympy_shape)): - new_sympy_shape[i] = self._new_symbolic_dim_from_output( - node, 0, i) - else: - new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape) - for i in axes: - new_sympy_shape[i] = self._new_symbolic_dim_from_output( - node, 0, i) - else: - for i, s, e, t in zip(axes, starts, ends, steps): - e = handle_negative_index(e, new_sympy_shape[i]) - if is_literal(e): - if e >= self.int_max_: - e = new_sympy_shape[i] - elif e <= -self.int_max_: - e = 0 if s > 0 else -1 - elif is_literal(new_sympy_shape[i]): - if e < 0: - e = max(0, e + new_sympy_shape[i]) - e = min(e, new_sympy_shape[i]) - else: - if e > 0: - e = sympy.Min( - e, new_sympy_shape[i] - ) if e > 1 else e #special case for slicing first to make computation easier - else: - if is_literal(new_sympy_shape[i]): - e = sympy.Min(e, new_sympy_shape[i]) - else: - try: - if not less_equal(e, new_sympy_shape[i]): - e = new_sympy_shape[i] - except Exception: - logger.warning( - 'Unable to determine if {} <= {}, treat as equal'. 
- format(e, new_sympy_shape[i])) - e = new_sympy_shape[i] - - s = handle_negative_index(s, new_sympy_shape[i]) - if is_literal(new_sympy_shape[i]) and is_literal(s): - s = max(0, min(s, new_sympy_shape[i])) - - new_sympy_shape[i] = sympy.simplify( - (e - s + t + (-1 if t > 0 else 1)) // t) - - self._update_computed_dims(new_sympy_shape) - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - - # handle sympy_data if needed, for slice in shape computation - if (node.input[0] in self.sympy_data_ and [0] == axes and - len(starts) == 1 and len(ends) == 1 and len(steps) == 1): - input_sympy_data = self.sympy_data_[node.input[0]] - if type(input_sympy_data) == list or ( - type(input_sympy_data) == np.array and - len(input_sympy_data.shape) == 1): - self.sympy_data_[node.output[0]] = input_sympy_data[starts[ - 0]:ends[0]:steps[0]] - - def _infer_SoftmaxCrossEntropyLoss(self, node): - vi = self.known_vi_[node.output[0]] - elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi.type.tensor_type.elem_type = elem_type - vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) - - if len(node.output) > 1: - data_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, elem_type, data_shape)) - - def _infer_Split_Common(self, node, make_value_info_func): - input_sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis( - get_attribute(node, 'axis', 0), len(input_sympy_shape)) - split = get_attribute(node, 'split') - if not split: - num_outputs = len(node.output) - split = [input_sympy_shape[axis] / - sympy.Integer(num_outputs)] * num_outputs - self._update_computed_dims(split) - else: - split = [sympy.Integer(s) for s in split] - - for i_o in range(len(split)): - vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - make_value_info_func(node.output[i_o], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - input_sympy_shape[:axis] + [ - split[i_o] - ] + input_sympy_shape[axis + 1:]))) - self.known_vi_[vi.name] = vi - - def _infer_Split(self, node): - self._infer_Split_Common(node, helper.make_tensor_value_info) - - def _infer_SplitToSequence(self, node): - self._infer_Split_Common(node, helper.make_sequence_value_info) - - def _infer_Squeeze(self, node): - input_shape = self._get_shape(node, 0) - op_set = get_opset(self.out_mp_) - - # Depending on op-version 'axes' are provided as attribute or via 2nd input - if op_set < 13: - axes = get_attribute(node, 'axes') - assert self._try_get_value(node, 1) is None - else: - axes = self._try_get_value(node, 1) - assert get_attribute(node, 'axes') is None - - if axes is None: - # No axes have been provided (neither via attribute nor via input). - # In this case the 'Shape' op should remove all axis with dimension 1. - # For symbolic dimensions we guess they are !=1. - output_shape = [s for s in input_shape if s != 1] - if self.verbose_ > 0: - symbolic_dimensions = [s for s in input_shape if type(s) != int] - if len(symbolic_dimensions) > 0: - logger.debug( - f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. 
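# [Illustrative sketch, not part of the patch] The slice-length expression
# above, (e - s + t - sign(t)) // t, is a sign-aware ceiling division of
# (e - s) by the step t; it agrees with Python slicing once start/end are clamped:
def slice_len(s, e, t):
    return (e - s + t + (-1 if t > 0 else 1)) // t

assert slice_len(0, 10, 3) == len(range(0, 10, 3)) == 4
assert slice_len(10, 0, -3) == len(range(10, 0, -3)) == 4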
" - + - f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}" - ) - else: - axes = [handle_negative_axis(a, len(input_shape)) for a in axes] - output_shape = [] - for i in range(len(input_shape)): - if i not in axes: - output_shape.append(input_shape[i]) - else: - assert input_shape[i] == 1 or type(input_shape[i]) != int - if self.verbose_ > 0 and type(input_shape[i]) != int: - logger.debug( - f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " - + - f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1." - ) - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, output_shape)) - self._pass_on_sympy_data(node) - - def _infer_Tile(self, node): - repeats_value = self._try_get_value(node, 1) - new_sympy_shape = [] - if repeats_value is not None: - input_sympy_shape = self._get_sympy_shape(node, 0) - for i, d in enumerate(input_sympy_shape): - new_dim = d * repeats_value[i] - new_sympy_shape.append(new_dim) - self._update_computed_dims(new_sympy_shape) - else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - - def _infer_TopK(self, node): - rank = self._get_shape_rank(node, 0) - axis = handle_negative_axis(get_attribute(node, 'axis', -1), rank) - new_shape = self._get_shape(node, 0) - - if get_opset(self.out_mp_) <= 9: - k = get_attribute(node, 'k') - else: - k = self._get_int_values(node)[1] - - if k == None: - k = self._new_symbolic_dim_from_output(node) - else: - k = as_scalar(k) - - if type(k) in [int, str]: - new_shape[axis] = k - else: - new_sympy_shape = self._get_sympy_shape(node, 0) - new_sympy_shape[axis] = k - self._update_computed_dims( - new_sympy_shape - ) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape - new_shape = get_shape_from_sympy_shape(new_sympy_shape) - - for i_o in range(len(node.output)): - vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - i_o], vi.type.tensor_type.elem_type, new_shape)) - - def _infer_Transpose(self, node): - if node.input[0] in self.sympy_data_: - data_shape = self._get_shape(node, 0) - perm = get_attribute(node, 'perm', - reversed(list(range(len(data_shape))))) - input_data = self.sympy_data_[node.input[0]] - self.sympy_data_[node.output[0]] = np.transpose( - np.array(input_data).reshape(*data_shape), - axes=tuple(perm)).flatten().tolist() - - def _infer_Unsqueeze(self, node): - input_shape = self._get_shape(node, 0) - op_set = get_opset(self.out_mp_) - - # Depending on op-version 'axes' are provided as attribute or via 2nd input - if op_set < 13: - axes = get_attribute(node, 'axes') - assert self._try_get_value(node, 1) is None - else: - axes = self._try_get_value(node, 1) - assert get_attribute(node, 'axes') is None - - output_rank = len(input_shape) + len(axes) - axes = [handle_negative_axis(a, output_rank) for a in axes] - - input_axis = 0 - output_shape = [] - for i in range(output_rank): - if i in axes: - output_shape.append(1) - else: - output_shape.append(input_shape[input_axis]) - input_axis += 1 - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - 
node.input[0]].type.tensor_type.elem_type, output_shape)) - - self._pass_on_sympy_data(node) - - def _infer_ZipMap(self, node): - map_key_type = None - if get_attribute(node, 'classlabels_int64s') is not None: - map_key_type = onnx.TensorProto.INT64 - elif get_attribute(node, 'classlabels_strings') is not None: - map_key_type = onnx.TensorProto.STRING - - assert map_key_type is not None - new_vi = onnx.ValueInfoProto() - new_vi.name = node.output[0] - new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT - new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(new_vi) - - def _infer_Attention(self, node): - shape = self._get_shape(node, 0) - shape_bias = self._get_shape(node, 2) - assert len(shape) == 3 and len(shape_bias) == 1 - qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes') - if qkv_hidden_sizes_attr is not None: - assert len(qkv_hidden_sizes_attr) == 3 - shape[2] = int(qkv_hidden_sizes_attr[2]) - else: - shape[2] = int(shape_bias[0] / 3) - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, shape)) - - if len(node.output) > 1: - # input shape: (batch_size, sequence_length, hidden_size) - # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) - # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) - # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length - input_shape = self._get_shape(node, 0) - past_shape = self._get_shape(node, 4) - mask_shape = self._get_shape(node, 3) - if len(past_shape) == 5: - if len(mask_shape) in [2, 3]: - past_shape[3] = mask_shape[-1] - elif isinstance(input_shape[1], int) and isinstance( - past_shape[3], int): - past_shape[3] = input_shape[1] + past_shape[3] - else: - past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, - past_shape)) - - def _infer_BiasGelu(self, node): - self._propagate_shape_and_type(node) - - def _infer_FastGelu(self, node): - self._propagate_shape_and_type(node) - - def _infer_Gelu(self, node): - self._propagate_shape_and_type(node) - - def _infer_LayerNormalization(self, node): - self._propagate_shape_and_type(node) - - def _infer_LongformerAttention(self, node): - self._propagate_shape_and_type(node) - - def _infer_EmbedLayerNormalization(self, node): - input_ids_shape = self._get_shape(node, 0) - word_embedding_shape = self._get_shape(node, 2) - assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 - output_shape = input_ids_shape + [word_embedding_shape[1]] - - word_embedding_dtype = self.known_vi_[node.input[ - 2]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], word_embedding_dtype, - output_shape)) - - mask_index_shape = [input_ids_shape[0]] - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 1], onnx.TensorProto.INT32, mask_index_shape)) - - if len(node.output) > 2: - # Optional output of add before layer nomalization is done - # shape is same as the output - vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - 
helper.make_tensor_value_info(node.output[
-                    2], word_embedding_dtype, output_shape))
-
-    def _infer_SkipLayerNormalization(self, node):
-        self._propagate_shape_and_type(node)
-
-    def _infer_PythonOp(self, node):
-        output_tensor_types = get_attribute(node, 'output_tensor_types')
-        assert output_tensor_types
-        output_tensor_ranks = get_attribute(node, 'output_tensor_ranks')
-        assert output_tensor_ranks
-
-        # set the context output separately.
-        # The first output is autograd's context.
-        vi = self.known_vi_[node.output[0]]
-        vi.CopyFrom(
-            helper.make_tensor_value_info(node.output[0],
-                                          onnx.TensorProto.INT64, []))
-
-        # Outputs after autograd's context are tensors.
-        # We assume their ranks are fixed for different model inputs.
-        for i in range(len(node.output) - 1):
-            # Process the i-th tensor outputs.
-            vi = self.known_vi_[node.output[i + 1]]
-            sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node)
-            shape = get_shape_from_sympy_shape(sympy_shape)
-            value_info = helper.make_tensor_value_info(
-                node.output[i + 1], output_tensor_types[i], shape)
-            vi.CopyFrom(value_info)
-
-    def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
-        shape = self._get_shape(node, input_index)
-        output_dtype = self.known_vi_[node.input[
-            input_index]].type.tensor_type.elem_type
-        vi = self.known_vi_[node.output[output_index]]
-        vi.CopyFrom(
-            helper.make_tensor_value_info(node.output[output_index],
-                                          output_dtype, shape))
-
-    def _is_none_dim(self, dim_value):
-        if type(dim_value) != str:
-            return False
-        if "unk__" not in dim_value:
-            return False
-        if dim_value in self.symbolic_dims_.keys():
-            return False
-        return True
-
-    def _is_shape_contains_none_dim(self, out_shape):
-        for out in out_shape:
-            if self._is_none_dim(out):
-                return out
-        return None
-
-    def _infer_impl(self, start_sympy_data=None):
-        self.sympy_data_ = start_sympy_data or {}
-        self.out_mp_.graph.ClearField('value_info')
-        self._apply_suggested_merge(graph_input_only=True)
-        self.input_symbols_ = set()
-        for i in self.out_mp_.graph.input:
-            input_shape = get_shape_from_value_info(i)
-            if input_shape is None:
-                continue
-
-            if is_sequence(i.type):
-                input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
-            else:
-                input_dims = i.type.tensor_type.shape.dim
-
-            for i_dim, dim in enumerate(input_shape):
-                if dim is None:
-                    # some models use None for symbolic dim in input, replace it with a string
-                    input_dims[i_dim].dim_param = str(
-                        self._new_symbolic_dim(i.name, i_dim))
-
-            self.input_symbols_.update(
-                [d for d in input_shape if type(d) == str])
-
-        for s in self.input_symbols_:
-            if s in self.suggested_merge_:
-                s_merge = self.suggested_merge_[s]
-                assert s_merge in self.symbolic_dims_
-                self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
-            else:
-                # Since inputs are not produced by other ops, we can assume positivity
-                self.symbolic_dims_[s] = sympy.Symbol(
-                    s, integer=True, positive=True)
-        # create a temporary ModelProto for single node inference
-        # note that we remove initializer to have faster inference
-        # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways
-        self.tmp_mp_ = onnx.ModelProto()
-        self.tmp_mp_.CopyFrom(self.out_mp_)
-        self.tmp_mp_.graph.ClearField('initializer')
-
-        # compute prerequisite for node for topological sort
-        # node with subgraphs may have dependency on implicit inputs, which will affect topological sort
-        prereq_for_node = {
-        }  # map from node to all its inputs, including implicit ones in
subgraph - - def get_prereq(node): - names = set(i for i in node.input if i) - subgraphs = [] - if 'If' == node.op_type: - subgraphs = [ - get_attribute(node, 'then_branch'), - get_attribute(node, 'else_branch') - ] - elif node.op_type in ['Loop', 'Scan']: - subgraphs = [get_attribute(node, 'body')] - for g in subgraphs: - g_outputs_and_initializers = {i.name for i in g.initializer} - g_prereq = set() - for n in g.node: - g_outputs_and_initializers.update(n.output) - for n in g.node: - g_prereq.update([ - i for i in get_prereq(n) - if i not in g_outputs_and_initializers - ]) - names.update(g_prereq) - # remove subgraph inputs from g_prereq since those are local-only - for i in g.input: - if i.name in names: - names.remove(i.name) - return names - - for n in self.tmp_mp_.graph.node: - prereq_for_node[n.output[0]] = get_prereq(n) - - # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate - sorted_nodes = [] - sorted_known_vi = set([ - i.name - for i in list(self.out_mp_.graph.input) + list( - self.out_mp_.graph.initializer) - ]) - if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): - # Loop/Scan will have some graph output in graph inputs, so don't do topological sort - sorted_nodes = self.out_mp_.graph.node - else: - while not all( - [o.name in sorted_known_vi for o in self.out_mp_.graph.output]): - old_sorted_nodes_len = len(sorted_nodes) - for node in self.out_mp_.graph.node: - if (node.output[0] not in sorted_known_vi) and all([ - i in sorted_known_vi - for i in prereq_for_node[node.output[0]] if i - ]): - sorted_known_vi.update(node.output) - sorted_nodes.append(node) - if old_sorted_nodes_len == len(sorted_nodes) and not all([ - o.name in sorted_known_vi - for o in self.out_mp_.graph.output - ]): - raise Exception('Invalid model with cyclic graph') - - for node in sorted_nodes: - assert all([i in self.known_vi_ for i in node.input if i]) - self._onnx_infer_single_node(node) - known_aten_op = False - if node.op_type in self.dispatcher_: - self.dispatcher_[node.op_type](node) - elif node.op_type in ['ConvTranspose']: - # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input - # before adding symbolic compute for them - # mark the output type as UNDEFINED to allow guessing of rank - vi = self.known_vi_[node.output[0]] - if len(vi.type.tensor_type.shape.dim) == 0: - vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - elif node.op_type == 'ATen' and node.domain == 'org.pytorch.aten': - for attr in node.attribute: - # TODO: Is overload_name needed? - if attr.name == 'operator': - aten_op_name = attr.s.decode('utf-8') if isinstance( - attr.s, bytes) else attr.s - if aten_op_name in self.aten_op_dispatcher_: - known_aten_op = True - self.aten_op_dispatcher_[aten_op_name](node) - break - - if self.verbose_ > 2: - logger.debug(node.op_type + ': ' + node.name) - for i, name in enumerate(node.input): - logger.debug(' Input {}: {} {}'.format( - i, name, 'initializer' - if name in self.initializers_ else '')) - - # onnx automatically merge dims with value, i.e. 
Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] - # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case - if node.op_type in [ - 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', - 'MatMulInteger16', 'Where', 'Sum' - ]: - vi = self.known_vi_[node.output[0]] - out_rank = len(get_shape_from_type_proto(vi.type)) - in_shapes = [ - self._get_shape(node, i) for i in range(len(node.input)) - ] - for d in range(out_rank - (2 if node.op_type in [ - 'MatMul', 'MatMulInteger', 'MatMulInteger16' - ] else 0)): - in_dims = [ - s[len(s) - out_rank + d] for s in in_shapes - if len(s) + d >= out_rank - ] - if len(in_dims) > 1: - self._check_merged_dims(in_dims, allow_broadcast=True) - - for i_o in range(len(node.output)): - vi = self.known_vi_[node.output[i_o]] - out_type = vi.type - out_type_kind = out_type.WhichOneof('value') - - # do not process shape for non-tensors - if out_type_kind not in [ - 'tensor_type', 'sparse_tensor_type', None - ]: - if self.verbose_ > 2: - if out_type_kind == 'sequence_type': - seq_cls_type = out_type.sequence_type.elem_type.WhichOneof( - 'value') - if 'tensor_type' == seq_cls_type: - logger.debug(' {}: sequence of {} {}'.format( - node.output[i_o], - str(get_shape_from_value_info(vi)), - onnx.TensorProto.DataType.Name( - vi.type.sequence_type.elem_type. - tensor_type.elem_type))) - else: - logger.debug(' {}: sequence of {}'.format( - node.output[i_o], seq_cls_type)) - else: - logger.debug(' {}: {}'.format(node.output[i_o], - out_type_kind)) - continue - - out_shape = get_shape_from_value_info(vi) - out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED - if self.verbose_ > 2: - logger.debug(' {}: {} {}'.format( - node.output[i_o], - str(out_shape), - onnx.TensorProto.DataType.Name( - vi.type.tensor_type.elem_type))) - if node.output[i_o] in self.sympy_data_: - logger.debug(' Sympy Data: ' + str(self.sympy_data_[ - node.output[i_o]])) - - # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain - if (out_shape is not None and - (None in out_shape or - self._is_shape_contains_none_dim(out_shape)) - ) or out_type_undefined: - if self.auto_merge_: - if node.op_type in [ - 'Add', 'Sub', 'Mul', 'Div', 'MatMul', - 'MatMulInteger', 'MatMulInteger16', 'Concat', - 'Where', 'Sum', 'Equal', 'Less', 'Greater', - 'LessOrEqual', 'GreaterOrEqual' - ]: - shapes = [ - self._get_shape(node, i) - for i in range(len(node.input)) - ] - if node.op_type in [ - 'MatMul', 'MatMulInteger', 'MatMulInteger16' - ]: - if None in out_shape or self._is_shape_contains_none_dim( - out_shape): - if None in out_shape: - idx = out_shape.index(None) - else: - idx = out_shape.index( - self._is_shape_contains_none_dim( - out_shape)) - dim_idx = [ - len(s) - len(out_shape) + idx - for s in shapes - ] - # only support auto merge for MatMul for dim < rank-2 when rank > 2 - assert len( - shapes[0]) > 2 and dim_idx[0] < len( - shapes[0]) - 2 - assert len( - shapes[1]) > 2 and dim_idx[1] < len( - shapes[1]) - 2 - elif node.op_type == 'Expand': - # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) - shapes = [ - self._get_shape(node, 0), self._get_value(node, - 1) - ] - else: - shapes = [] - - if shapes: - for idx in range(len(out_shape)): - if out_shape[ - idx] is not None and not self._is_none_dim( - out_shape[idx]): - continue - # note that the broadcasting rule aligns from right to left - # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge - dim_idx = [ - 
len(s) - len(out_shape) + idx - for s in shapes - ] - if len(dim_idx) > 0: - self._add_suggested_merge([ - s[i] if is_literal(s[i]) else str(s[i]) - for s, i in zip(shapes, dim_idx) - if i >= 0 - ]) - self.run_ = True - else: - self.run_ = False - else: - self.run_ = False - - # create new dynamic dims for ops not handled by symbolic shape inference - if self.run_ == False and not node.op_type in self.dispatcher_ and not known_aten_op: - is_unknown_op = out_type_undefined and ( - out_shape is None or len(out_shape) == 0) - if is_unknown_op: - # unknown op to ONNX, maybe from higher opset or other domain - # only guess the output rank from input 0 when using guess_output_rank option - out_rank = self._get_shape_rank( - node, 0) if self.guess_output_rank_ else -1 - else: - # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape - out_rank = len(out_shape) - - if out_rank >= 0: - new_shape = self._new_symbolic_shape(out_rank, node, - i_o) - if out_type_undefined: - # guess output data type from input vi if not defined - out_dtype = self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type - else: - # otherwise, use original data type - out_dtype = vi.type.tensor_type.elem_type - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, out_dtype, - get_shape_from_sympy_shape(new_shape))) - - if self.verbose_ > 0: - if is_unknown_op: - logger.debug( - "Possible unknown op: {} node: {}, guessing {} shape". - format(node.op_type, node.name, - vi.name)) - if self.verbose_ > 2: - logger.debug(' {}: {} {}'.format( - node.output[i_o], - str(new_shape), - vi.type.tensor_type.elem_type)) - - self.run_ = True - continue # continue the inference after guess, no need to stop as no merge is needed - - if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: - logger.debug( - 'Stopping at incomplete shape inference at ' + - node.op_type + ': ' + node.name) - logger.debug('node inputs:') - for i in node.input: - logger.debug(self.known_vi_[i]) - logger.debug('node outputs:') - for o in node.output: - logger.debug(self.known_vi_[o]) - if self.auto_merge_ and not out_type_undefined: - logger.debug('Merging: ' + str( - self.suggested_merge_)) - return False - - self.run_ = False - return True - - def _update_output_from_vi(self): - for output in self.out_mp_.graph.output: - if output.name in self.known_vi_: - output.CopyFrom(self.known_vi_[output.name]) - - @staticmethod - def infer_shapes(in_mp, - int_max=2**31 - 1, - auto_merge=False, - guess_output_rank=False, - verbose=0): - onnx_opset = get_opset(in_mp) - if (not onnx_opset) or onnx_opset < 7: - logger.warning('Only support models of onnx opset 7 and above.') - return None - symbolic_shape_inference = SymbolicShapeInference( - int_max, auto_merge, guess_output_rank, verbose) - all_shapes_inferred = False - symbolic_shape_inference._preprocess(in_mp) - while symbolic_shape_inference.run_: - all_shapes_inferred = symbolic_shape_inference._infer_impl() - symbolic_shape_inference._update_output_from_vi() - if not all_shapes_inferred: - raise Exception("Incomplete symbolic shape inference") - return symbolic_shape_inference.out_mp_ - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, help='The input model file') - parser.add_argument('--output', help='The output model file') - parser.add_argument( - '--auto_merge', - help='Automatically merge symbolic dims when confliction happens', - action='store_true', - default=False) - parser.add_argument( - 
'--int_max',
-        help='maximum value for integer to be treated as boundless for ops like slice',
-        type=int,
-        default=2**31 - 1)
-    parser.add_argument(
-        '--guess_output_rank',
-        help='guess output rank to be the same as input 0 for unknown ops',
-        action='store_true',
-        default=False)
-    parser.add_argument(
-        '--verbose',
-        help='Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed',
-        type=int,
-        default=0)
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    args = parse_arguments()
-    logger.info('input model: ' + args.input)
-    if args.output:
-        logger.info('output model: ' + args.output)
-    logger.info('Doing symbolic shape inference...')
-    out_mp = SymbolicShapeInference.infer_shapes(
-        onnx.load(args.input), args.int_max, args.auto_merge,
-        args.guess_output_rank, args.verbose)
-    if args.output and out_mp:
-        onnx.save(out_mp, args.output)
-    logger.info('Done!')
diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh b/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh
deleted file mode 100755
index ce2f24e58..000000000
--- a/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-set -e
-
-if [ $# != 3 ];then
-    # ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
-    echo "usage: $0 onnx.model.in onnx.model.out input_shape"
-    exit 1
-fi
-
-# onnx simplifier
-pip install onnx-simplifier
-
-in=$1
-out=$2
-input_shape=$3
-
-check_n=3
-
-onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape
\ No newline at end of file
diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py
deleted file mode 100755
index 5b85eef3e..000000000
--- a/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py
+++ /dev/null
@@ -1,128 +0,0 @@
-#!/usr/bin/env python3 -W ignore::DeprecationWarning
-# prune model by output names
-import argparse
-import copy
-import sys
-
-import onnx
-
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--model',
-        required=True,
-        help='Path of the input ONNX model file.')
-    parser.add_argument(
-        '--output_names',
-        required=True,
-        nargs='+',
-        help='The output names of the pruned model.')
-    parser.add_argument(
-        '--save_file', required=True, help='Path to save the new onnx model.')
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    args = parse_arguments()
-
-    if len(set(args.output_names)) < len(args.output_names):
-        print(
-            "[ERROR] There are duplicate names in --output_names, which is not allowed."
-        )
-        sys.exit(-1)
-
-    model = onnx.load(args.model)
-
-    # collect all node outputs and graph outputs
-    output_tensor_names = set()
-    for node in model.graph.node:
-        for out in node.output:
-            # may contain model outputs
-            output_tensor_names.add(out)
-
-    # for out in model.graph.output:
-    #     output_tensor_names.add(out.name)
-
-    for output_name in args.output_names:
-        if output_name not in output_tensor_names:
-            print(
-                "[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
-                format(output_name))
-            sys.exit(-1)
-
-    output_node_indices = set()  # indices of nodes producing the kept outputs
-    output_to_node = dict()  # map from output tensor name to node index
-    for i, node in enumerate(model.graph.node):
-        for out in node.output:
-            output_to_node[out] = i
-            if out in args.output_names:
-                output_node_indices.add(i)
-
-    # from the outputs, find all ancestor nodes
-    reserved_node_indices = copy.deepcopy(
-        output_node_indices)  # nodes to keep
-    reserved_inputs = set()  # model inputs to keep
-    new_output_node_indices = copy.deepcopy(output_node_indices)
-
-    while len(new_output_node_indices) > 0:
-        output_node_indices = copy.deepcopy(new_output_node_indices)
-
-        new_output_node_indices = set()
-
-        for out_node_idx in output_node_indices:
-            # backtrack to parent nodes
-            for ipt in model.graph.node[out_node_idx].input:
-                if ipt in output_to_node:
-                    reserved_node_indices.add(output_to_node[ipt])
-                    new_output_node_indices.add(output_to_node[ipt])
-                else:
-                    reserved_inputs.add(ipt)
-
-    num_inputs = len(model.graph.input)
-    num_outputs = len(model.graph.output)
-    num_nodes = len(model.graph.node)
-    print(
-        f"old graph has {num_inputs} inputs, {num_outputs} outputs, {num_nodes} nodes"
-    )
-    print(f"{len(reserved_node_indices)} nodes to keep.")
-
-    # delete nodes that are not kept
-    for idx in range(num_nodes - 1, -1, -1):
-        if idx not in reserved_node_indices:
-            del model.graph.node[idx]
-
-    # delete graph inputs that are not kept
-    for idx in range(num_inputs - 1, -1, -1):
-        if model.graph.input[idx].name not in reserved_inputs:
-            del model.graph.input[idx]
-
-    # delete the old graph outputs
-    for i in range(num_outputs):
-        del model.graph.output[0]
-
-    # add the user-specified names as the new graph outputs
-    for out in args.output_names:
-        model.graph.output.extend([onnx.ValueInfoProto(name=out)])
-
-    # infer shape
-    try:
-        from onnx_infer_shape import SymbolicShapeInference
-        model = SymbolicShapeInference.infer_shapes(
-            model,
-            int_max=2**31 - 1,
-            auto_merge=True,
-            guess_output_rank=False,
-            verbose=1)
-    except Exception as e:
-        print(f"skip infer shape step: {e}")
-
-    # check onnx model
-    onnx.checker.check_model(model)
-    # save onnx model
-    onnx.save(model, args.save_file)
-    print("[Finished] The new model is saved to {}.".format(args.save_file))
-    print("[DEBUG INFO] The inputs of new model: {}".format(
-        [x.name for x in model.graph.input]))
-    print("[DEBUG INFO] The outputs of new model: {}".format(
-        [x.name for x in model.graph.output]))
diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py
deleted file mode 100755
index fc00a82ec..000000000
--- a/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py
+++ /dev/null
@@ -1,111 +0,0 @@
-#!/usr/bin/env python3 -W ignore::DeprecationWarning
-# rename tensors to new names
-import argparse
-import sys
-
-import onnx
-
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--model',
-        required=True,
-        help='Path of the input ONNX model file.')
-    parser.add_argument(
-        '--origin_names',
-        required=True,
-        nargs='+',
-        help='The original tensor names you want to modify.')
-    parser.add_argument(
-        '--new_names',
-        required=True,
-        nargs='+',
-        help='The new names to change to; the number of new_names must match the number of origin_names.'
-    )
-    parser.add_argument(
-        '--save_file', required=True, help='Path to save the new onnx model.')
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    args = parse_arguments()
-
-    if len(set(args.origin_names)) < len(args.origin_names):
-        print(
-            "[ERROR]
There are duplicate names in --origin_names, which is not allowed."
-        )
-        sys.exit(-1)
-
-    if len(set(args.new_names)) < len(args.new_names):
-        print(
-            "[ERROR] There are duplicate names in --new_names, which is not allowed."
-        )
-        sys.exit(-1)
-
-    if len(args.new_names) != len(args.origin_names):
-        print(
-            "[ERROR] The number of --new_names must match the number of --origin_names."
-        )
-        sys.exit(-1)
-
-    model = onnx.load(args.model)
-
-    # collect graph inputs and all node outputs
-    output_tensor_names = set()
-    for ipt in model.graph.input:
-        output_tensor_names.add(ipt.name)
-
-    for node in model.graph.node:
-        for out in node.output:
-            output_tensor_names.add(out)
-
-    for origin_name in args.origin_names:
-        if origin_name not in output_tensor_names:
-            print(
-                f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
-            )
-            sys.exit(-1)
-
-    for new_name in args.new_names:
-        if new_name in output_tensor_names:
-            print(
-                "[ERROR] The defined new_name '{}' already exists in the onnx model, which is not allowed.".
-                format(new_name))
-            sys.exit(-1)
-
-    # rename graph inputs
-    for i, ipt in enumerate(model.graph.input):
-        if ipt.name in args.origin_names:
-            idx = args.origin_names.index(ipt.name)
-            model.graph.input[i].name = args.new_names[idx]
-
-    # rename node inputs and outputs
-    for i, node in enumerate(model.graph.node):
-        for j, ipt in enumerate(node.input):
-            if ipt in args.origin_names:
-                idx = args.origin_names.index(ipt)
-                model.graph.node[i].input[j] = args.new_names[idx]
-
-        for j, out in enumerate(node.output):
-            if out in args.origin_names:
-                idx = args.origin_names.index(out)
-                model.graph.node[i].output[j] = args.new_names[idx]
-
-    # rename graph outputs
-    for i, out in enumerate(model.graph.output):
-        if out.name in args.origin_names:
-            idx = args.origin_names.index(out.name)
-            model.graph.output[i].name = args.new_names[idx]
-
-    # check onnx model
-    onnx.checker.check_model(model)
-
-    # save model
-    onnx.save(model, args.save_file)
-
-    print("[Finished] The new model is saved to {}.".format(args.save_file))
-    print("[DEBUG INFO] The inputs of new model: {}".format(
-        [x.name for x in model.graph.input]))
-    print("[DEBUG INFO] The outputs of new model: {}".format(
-        [x.name for x in model.graph.output]))
diff --git a/speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py b/speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py
deleted file mode 100755
index 2c5692369..000000000
--- a/speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-
-from onnxruntime.quantization import quantize_dynamic
-from onnxruntime.quantization import QuantType
-
-
-def quantize_onnx_model(onnx_model_path,
-                        quantized_model_path,
-                        nodes_to_exclude=[]):
-    print("Starting quantization...")
-
-    quantize_dynamic(
-        onnx_model_path,
-        quantized_model_path,
-        weight_type=QuantType.QInt8,
-        nodes_to_exclude=nodes_to_exclude)
-
-    print(f"Quantized model saved to: {quantized_model_path}")
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--model-in",
-        type=str,
-        required=True,
-        help="input ONNX model", )
-    parser.add_argument(
-        "--model-out",
-        type=str,
-        required=True,
-        default='model.quant.onnx',
-        help="output quantized ONNX model", )
-    parser.add_argument(
-        "--nodes-to-exclude",
-        type=str,
-        required=True,
-        help="nodes to exclude, e.g. conv,linear.", )
conv,linear.", ) - - args = parser.parse_args() - - nodes_to_exclude = args.nodes_to_exclude.split(',') - quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude) - - -if __name__ == "__main__": - main() diff --git a/speechx/examples/ds2_ol/onnx/local/ort_opt.py b/speechx/examples/ds2_ol/onnx/local/ort_opt.py deleted file mode 100755 index 8e995bcf0..000000000 --- a/speechx/examples/ds2_ol/onnx/local/ort_opt.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python3 -import argparse - -import onnxruntime as ort - -# onnxruntime optimizer. -# https://onnxruntime.ai/docs/performance/graph-optimizations.html -# https://onnxruntime.ai/docs/api/python/api_summary.html#api - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--model_in', required=True, type=str, help='Path to onnx model.') - parser.add_argument( - '--opt_level', - required=True, - type=int, - default=0, - choices=[0, 1, 2], - help='Path to onnx model.') - parser.add_argument( - '--model_out', required=True, help='path to save the optimized model.') - parser.add_argument('--debug', default=False, help='output debug info.') - return parser.parse_args() - - -if __name__ == '__main__': - args = parse_arguments() - - sess_options = ort.SessionOptions() - - # Set graph optimization level - print(f"opt level: {args.opt_level}") - if args.opt_level == 0: - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC - elif args.opt_level == 1: - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED - else: - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - - # To enable model serialization after graph optimization set this - sess_options.optimized_model_filepath = args.model_out - - session = ort.InferenceSession(args.model_in, sess_options) diff --git a/speechx/examples/ds2_ol/onnx/local/tonnx.sh b/speechx/examples/ds2_ol/onnx/local/tonnx.sh deleted file mode 100755 index 104872303..000000000 --- a/speechx/examples/ds2_ol/onnx/local/tonnx.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -if [ $# != 4 ];then - # local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx - echo "usage: $0 model_dir model_name param_name onnx_output_name" - exit 1 -fi - -dir=$1 -model=$2 -param=$3 -output=$4 - -pip install paddle2onnx -pip install onnx - -# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2 - # opset10 support quantize -paddle2onnx --model_dir $dir \ - --model_filename $model \ - --params_filename $param \ - --save_file $output \ - --enable_dev_version True \ - --opset_version 11 \ - --enable_onnx_checker True - \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/path.sh b/speechx/examples/ds2_ol/onnx/path.sh deleted file mode 100755 index 97d487379..000000000 --- a/speechx/examples/ds2_ol/onnx/path.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -MAIN_ROOT=`realpath $PWD/../../../../` -SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx - -SPEECHX_TOOLS=$SPEECHX_ROOT/tools -TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. 
-
-export LC_ALL=C
-
-export PATH=$PATH:$TOOLS_BIN
diff --git a/speechx/examples/ds2_ol/onnx/run.sh b/speechx/examples/ds2_ol/onnx/run.sh
deleted file mode 100755
index 3dc5e9100..000000000
--- a/speechx/examples/ds2_ol/onnx/run.sh
+++ /dev/null
@@ -1,91 +0,0 @@
-#!/bin/bash
-
-set -e
-
-. path.sh
-
-stage=0
-stop_stage=50
-tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
-#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
-model_prefix=avg_10.jit
-#model_prefix=avg_1.jit
-model=${model_prefix}.pdmodel
-param=${model_prefix}.pdiparams
-
-. utils/parse_options.sh
-
-data=data
-exp=exp
-
-mkdir -p $data $exp
-
-dir=$data/exp/deepspeech2_online/checkpoints
-
-# wenetspeech or aishell
-model_type=$(echo $tarfile | cut -d '_' -f 4)
-
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
-    test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile
-
-    # wenetspeech ds2 model
-    pushd $data
-    tar zxvf $tarfile
-    popd
-
-    # ds2 model demo inputs
-    pushd $exp
-    wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
-    popd
-fi
-
-input_file=$exp/static_ds2online_inputs.pickle
-test -e $input_file
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
-    # to onnx
-    ./local/tonnx.sh $dir $model $param $exp/model.onnx
-
-    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ] ;then
-    # ort graph optimize
-    ./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx
-
-    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
-    # convert opset_num to 11
-    ./local/onnx_convert_opset.py --target-opset 11 --model-file $exp/model.ort.opt.onnx --save-model $exp/model.optset11.onnx
-
-    # quant model
-    nodes_to_exclude='p2o.Conv.0,p2o.Conv.2'
-    ./local/ort_dyanmic_quant.py --model-in $exp/model.optset11.onnx --model-out $exp/model.optset11.quant.onnx --nodes-to-exclude "${nodes_to_exclude}"
-
-    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.optset11.quant.onnx
-fi
-
-
-# aishell rnn hidden is 1024
-# wenetspeech rnn hidden is 2048
-if [ $model_type == 'aishell' ];then
-    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
-elif [ $model_type == 'wenetspeech' ];then
-    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048"
-else
-    echo "not supported: $model_type"
-    exit -1
-fi
-
-
-if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then
-    # the wenetspeech ds2 model exceeds the 2GB limit and will error.
-    # simplifying onnx model
-    ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"
-
-    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx
-fi
diff --git a/speechx/examples/ds2_ol/onnx/utils b/speechx/examples/ds2_ol/onnx/utils
deleted file mode 120000
index c2519a9dd..000000000
--- a/speechx/examples/ds2_ol/onnx/utils
+++ /dev/null
@@ -1 +0,0 @@
-../../../../utils/
\ No newline at end of file
diff --git a/speechx/examples/ds2_ol/websocket/.gitignore b/speechx/examples/ds2_ol/websocket/.gitignore
deleted file mode 100644
index bbd86a25b..000000000
--- a/speechx/examples/ds2_ol/websocket/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-data
-exp
diff --git a/speechx/examples/ds2_ol/websocket/README.md b/speechx/examples/ds2_ol/websocket/README.md
deleted file mode 100644
index 3fa84135f..000000000
--- a/speechx/examples/ds2_ol/websocket/README.md
+++ /dev/null
@@ -1,78 +0,0 @@
-# Streaming DeepSpeech2 Server with WebSocket
-
-This example shows how to use `websocket` as a streaming deepspeech2 server. For deepspeech2 model training, please see [here](../../../../examples/aishell/asr0/).
-
-The websocket protocol is the same as [PaddleSpeech Server](../../../../demos/streaming_asr_server/);
-for details of the implementation, please see [here](../../../speechx/protocol/websocket/).
-
-
-## Source path.sh
-
-```bash
-. path.sh
-```
-
-SpeechX binaries are under `echo $SPEECHX_BUILD`; for more info, please see `path.sh`.
-
-
-## Start WebSocket Server
-
-```bash
-bash websocket_server.sh
-```
-
-The output looks like below:
-
-```text
-I1130 02:19:32.029882 12856 cmvn_json2kaldi_main.cc:39] cmvn josn path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/data/mean_std.json
-I1130 02:19:32.032230 12856 cmvn_json2kaldi_main.cc:73] nframe: 907497
-I1130 02:19:32.032564 12856 cmvn_json2kaldi_main.cc:85] cmvn stats have write into: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
-I1130 02:19:32.032579 12856 cmvn_json2kaldi_main.cc:86] Binary: 1
-I1130 02:19:32.798342 12937 feature_pipeline.h:53] cmvn file: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
-I1130 02:19:32.798542 12937 feature_pipeline.h:58] dither: 0
-I1130 02:19:32.798583 12937 feature_pipeline.h:60] frame shift ms: 10
-I1130 02:19:32.798588 12937 feature_pipeline.h:62] feature type: linear
-I1130 02:19:32.798596 12937 feature_pipeline.h:80] frame length ms: 20
-I1130 02:19:32.798601 12937 feature_pipeline.h:88] subsampling rate: 4
-I1130 02:19:32.798606 12937 feature_pipeline.h:90] nnet receptive filed length: 7
-I1130 02:19:32.798611 12937 feature_pipeline.h:92] nnet chunk size: 1
-I1130 02:19:32.798615 12937 feature_pipeline.h:94] frontend fill zeros: 0
-I1130 02:19:32.798630 12937 nnet_itf.h:52] subsampling rate: 4
-I1130 02:19:32.798635 12937 nnet_itf.h:54] model path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdmodel
-I1130 02:19:32.798640 12937 nnet_itf.h:57] param path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdiparams
-I1130 02:19:32.798643 12937 nnet_itf.h:59] DS2 param:
-I1130 02:19:32.798647 12937 nnet_itf.h:61] cache names: chunk_state_h_box,chunk_state_c_box
-I1130 02:19:32.798652 12937 nnet_itf.h:63] cache shape: 5-1-1024,5-1-1024
-I1130 02:19:32.798656 12937 nnet_itf.h:65] input names: audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box
-I1130 02:19:32.798660 12937 nnet_itf.h:67] output names: softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
-I1130 02:19:32.798664 12937 ctc_tlg_decoder.h:41] fst path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//TLG.fst
-I1130 02:19:32.798669 12937 ctc_tlg_decoder.h:42] fst symbole table: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//words.txt
-I1130 02:19:32.798673 12937 ctc_tlg_decoder.h:47] LatticeFasterDecoder max active: 7500
-I1130 02:19:32.798677 12937 ctc_tlg_decoder.h:49] LatticeFasterDecoder beam: 15
-I1130 02:19:32.798681 12937 ctc_tlg_decoder.h:50] LatticeFasterDecoder lattice_beam: 7.5
-I1130 02:19:32.798708 12937 websocket_server_main.cc:37] Listening at port 8082
-```
-
-## Start WebSocket Client
-
-```bash
-bash websocket_client.sh
-```
-
-This script uses AISHELL-1 test data to call the websocket server.
-
-The input is specified by `--wav_rspecifier=scp:$data/$aishell_wav_scp`.
-
-The `scp` file looks like this:
-```text
-# head data/split1/1/aishell_test.scp
-BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
-BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
-...
-BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
-```
-
-If you want to recognize a single wav file, you can make an `scp` file like this:
-```text
-key path/to/wav/file
-```
diff --git a/speechx/examples/ds2_ol/websocket/path.sh b/speechx/examples/ds2_ol/websocket/path.sh
deleted file mode 100755
index 6dd6bddbf..000000000
--- a/speechx/examples/ds2_ol/websocket/path.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-# This contains the locations of the binaries built for running the examples.
-
-SPEECHX_ROOT=$PWD/../../../
-SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
-
-SPEECHX_TOOLS=$SPEECHX_ROOT/tools
-TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
-
-[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project built successfully"; }
-
-export LC_ALL=C
-
-SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket:$SPEECHX_BUILD/frontend/audio
-export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/ds2_ol/websocket/websocket_client.sh b/speechx/examples/ds2_ol/websocket/websocket_client.sh
deleted file mode 100755
index a508adfbc..000000000
--- a/speechx/examples/ds2_ol/websocket/websocket_client.sh
+++ /dev/null
@@ -1,35 +0,0 @@
-#!/bin/bash
-set +x
-set -e
-
-. path.sh
-
-# 1. compile
-if [ ! -d ${SPEECHX_EXAMPLES} ]; then
-    pushd ${SPEECHX_ROOT}
-    bash build.sh
-    popd
-fi
-
-# input
-mkdir -p data
-data=$PWD/data
-
-# output
-aishell_wav_scp=aishell_test.scp
-if [ ! -d $data/test ]; then
-    pushd $data
-    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
-    unzip aishell_test.zip
-    popd
-
-    realpath $data/test/*/*.wav > $data/wavlist
-    awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.'
'{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp -fi - -export GLOG_logtostderr=1 - -# websocket client -websocket_client_main \ - --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.5 diff --git a/speechx/examples/ds2_ol/websocket/websocket_server.sh b/speechx/examples/ds2_ol/websocket/websocket_server.sh deleted file mode 100755 index 18d29857c..000000000 --- a/speechx/examples/ds2_ol/websocket/websocket_server.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -set +x -set -e - -. path.sh - -# 1. compile -if [ ! -d ${SPEECHX_EXAMPLES} ]; then - pushd ${SPEECHX_ROOT} - bash build.sh - popd -fi - -# input -mkdir -p data -data=$PWD/data -ckpt_dir=$data/model -model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ -vocb_dir=$ckpt_dir/data/lang_char/ - - -if [ ! -f $ckpt_dir/data/mean_std.json ]; then - mkdir -p $ckpt_dir - pushd $ckpt_dir - wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - popd -fi - -export GLOG_logtostderr=1 - -# 3. gen cmvn -cmvn=$data/cmvn.ark -cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn - - -wfst=$data/wfst/ -mkdir -p $wfst -if [ ! -f $wfst/aishell_graph.zip ]; then - pushd $wfst - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip - unzip aishell_graph.zip - mv aishell_graph/* $wfst - popd -fi - -# 5. test websocket server -websocket_server_main \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 diff --git a/speechx/examples/u2pp_ol/wenetspeech/path.sh b/speechx/examples/u2pp_ol/wenetspeech/path.sh index ec278bd3d..9518db116 100644 --- a/speechx/examples/u2pp_ol/wenetspeech/path.sh +++ b/speechx/examples/u2pp_ol/wenetspeech/path.sh @@ -3,7 +3,7 @@ unset GREP_OPTIONS SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx/asr SPEECHX_TOOLS=$SPEECHX_ROOT/tools TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin @@ -12,7 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer +export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/../common/frontend/audio:$SPEECHX_BUILD/recognizer PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/speechx/speechx/asr/decoder/CMakeLists.txt b/speechx/speechx/asr/decoder/CMakeLists.txt index 93014fb90..b2f507080 100644 --- a/speechx/speechx/asr/decoder/CMakeLists.txt +++ b/speechx/speechx/asr/decoder/CMakeLists.txt @@ -1,55 +1,22 @@ -include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) - set(srcs) - -if (USING_DS2) list(APPEND srcs - ctc_decoders/decoder_utils.cpp - ctc_decoders/path_trie.cpp - ctc_decoders/scorer.cpp - ctc_beam_search_decoder.cc - ctc_tlg_decoder.cc + ctc_prefix_beam_search_decoder.cc ) 
-endif() - -if (USING_U2) - list(APPEND srcs - ctc_prefix_beam_search_decoder.cc - ) -endif() add_library(decoder STATIC ${srcs}) -target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) +target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder) # test -if (USING_DS2) - set(BINS - ctc_beam_search_decoder_main - nnet_logprob_decoder_main - ctc_tlg_decoder_main - ) - - foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - endforeach() -endif() - - -if (USING_U2) - set(TEST_BINS - ctc_prefix_beam_search_decoder_main - ) - - foreach(bin_name IN LISTS TEST_BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) - endforeach() +set(TEST_BINS + ctc_prefix_beam_search_decoder_main +) -endif() +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +endforeach() diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc deleted file mode 100644 index 6e3a0d136..000000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- - -#include "decoder/ctc_beam_search_decoder.h" - -#include "base/common.h" -#include "decoder/ctc_decoders/decoder_utils.h" -#include "utils/file_utils.h" - -namespace ppspeech { - -using std::vector; -using FSTMATCH = fst::SortedMatcher; - -CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) - : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) { - LOG(INFO) << "dict path: " << opts_.dict_file; - if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { - LOG(INFO) << "load the dict failed"; - } - LOG(INFO) << "read the vocabulary success, dict size: " - << vocabulary_.size(); - - LOG(INFO) << "language model path: " << opts_.lm_path; - if (opts_.lm_path != "") { - init_ext_scorer_ = std::make_shared( - opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); - } - - CHECK_EQ(opts_.blank, 0); - - auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); - space_id_ = it - vocabulary_.begin(); - // if no space in vocabulary - if (static_cast(space_id_) >= vocabulary_.size()) { - space_id_ = -2; - } -} - -void CTCBeamSearch::Reset() { - // num_frame_decoded_ = 0; - // ResetPrefixes(); - InitDecoder(); -} - -void CTCBeamSearch::InitDecoder() { - num_frame_decoded_ = 0; - // ResetPrefixes(); - prefixes_.clear(); - - root_ = std::make_shared(); - root_->score = root_->log_prob_b_prev = 0.0; - prefixes_.push_back(root_.get()); - if (init_ext_scorer_ != nullptr && - !init_ext_scorer_->is_character_based()) { - auto fst_dict = - static_cast(init_ext_scorer_->dictionary); - fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); - root_->set_dictionary(dict_ptr); - - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root_->set_matcher(matcher); - } -} - -void CTCBeamSearch::Decode( - std::shared_ptr decodable) { - return; -} - -// todo rename, refactor -void CTCBeamSearch::AdvanceDecode( - const std::shared_ptr& decodable) { - while (1) { - vector> likelihood; - vector frame_prob; - bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); - if (flag == false) break; - likelihood.push_back(frame_prob); - AdvanceDecoding(likelihood); - } -} - -void CTCBeamSearch::ResetPrefixes() { - for (size_t i = 0; i < prefixes_.size(); i++) { - if (prefixes_[i] != nullptr) { - delete prefixes_[i]; - prefixes_[i] = nullptr; - } - } - prefixes_.clear(); -} - -int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, - const vector& nbest_words) { - kaldi::Timer timer; - AdvanceDecoding(probs); - LOG(INFO) << "ctc decoding elapsed time(s) " - << static_cast(timer.Elapsed()) / 1000.0f; - return 0; -} - -vector> CTCBeamSearch::GetNBestPath(int n) { - int beam_size = n == -1 ? 
opts_.beam_size : std::min(n, opts_.beam_size); - return get_beam_search_result(prefixes_, vocabulary_, beam_size); -} - -vector> CTCBeamSearch::GetNBestPath() { - return GetNBestPath(-1); -} - -string CTCBeamSearch::GetBestPath() { - std::vector> result; - result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); - return result[0].second; -} - -string CTCBeamSearch::GetFinalBestPath() { - CalculateApproxScore(); - LMRescore(); - return GetBestPath(); -} - -void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { - size_t num_time_steps = probs.size(); - size_t beam_size = opts_.beam_size; - double cutoff_prob = opts_.cutoff_prob; - size_t cutoff_top_n = opts_.cutoff_top_n; - - vector> probs_seq(probs.size(), - vector(probs[0].size(), 0)); - - int row = probs.size(); - int col = probs[0].size(); - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j++) { - probs_seq[i][j] = static_cast(probs[i][j]); - } - } - - for (size_t time_step = 0; time_step < num_time_steps; time_step++) { - const auto& prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (init_ext_scorer_ != nullptr) { - size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); - std::sort(prefixes_.begin(), - prefixes_.begin() + num_prefixes_, - prefix_compare); - - if (num_prefixes_ == 0) { - continue; - } - min_cutoff = prefixes_[num_prefixes_ - 1]->score + - std::log(prob[opts_.blank]) - - std::max(0.0, init_ext_scorer_->beta); - - full_beam = (num_prefixes_ == beam_size); - } - - vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - - // loop over chars - size_t log_prob_idx_len = log_prob_idx.size(); - for (size_t index = 0; index < log_prob_idx_len; index++) { - SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); - } - - prefixes_.clear(); - - // update log probs - root_->iterate_to_vec(prefixes_); - // only preserve top beam_size prefixes_ - if (prefixes_.size() >= beam_size) { - std::nth_element(prefixes_.begin(), - prefixes_.begin() + beam_size, - prefixes_.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes_.size(); ++i) { - prefixes_[i]->remove(); - } - } // end if - num_frame_decoded_++; - } // end for probs_seq -} - -int32 CTCBeamSearch::SearchOneChar( - const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff) { - size_t beam_size = opts_.beam_size; - const auto& c = log_prob_idx.first; - const auto& log_prob_c = log_prob_idx.second; - size_t prefixes_len = std::min(prefixes_.size(), beam_size); - - for (size_t i = 0; i < prefixes_len; ++i) { - auto prefix = prefixes_[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - - if (c == opts_.blank) { - prefix->log_prob_b_cur = - log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - - // repeated character - if (c == prefix->character) { - // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) - prefix->log_prob_nb_cur = log_sum_exp( - prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); - } - - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1}) - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1}) - log_p = log_prob_c + prefix->score; - } - - // language model scoring 
- if (init_ext_scorer_ != nullptr && - (c == space_id_ || init_ext_scorer_->is_character_based())) { - PathTrie* prefix_to_score = nullptr; - // skip scoring the space - if (init_ext_scorer_->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - vector ngram; - ngram = init_ext_scorer_->make_ngram(prefix_to_score); - // lm score: p_{lm}(W)^{\alpha} + \beta - score = init_ext_scorer_->get_log_cond_prob(ngram) * - init_ext_scorer_->alpha; - log_p += score; - log_p += init_ext_scorer_->beta; - } - // p_{nb}(l;x_{1:t}) - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - return 0; -} - -void CTCBeamSearch::CalculateApproxScore() { - size_t beam_size = opts_.beam_size; - size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); - std::sort( - prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { - double approx_ctc = prefixes_[i]->score; - if (init_ext_scorer_ != nullptr) { - vector output; - prefixes_[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = init_ext_scorer_->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; - // remove language model weight: - approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) * - init_ext_scorer_->alpha; - } - prefixes_[i]->approx_ctc = approx_ctc; - } -} - -void CTCBeamSearch::LMRescore() { - size_t beam_size = opts_.beam_size; - if (init_ext_scorer_ != nullptr && - !init_ext_scorer_->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { - auto prefix = prefixes_[i]; - if (!prefix->is_empty() && prefix->character != space_id_) { - float score = 0.0; - vector ngram = init_ext_scorer_->make_ngram(prefix); - score = init_ext_scorer_->get_log_cond_prob(ngram) * - init_ext_scorer_->alpha; - score += init_ext_scorer_->beta; - prefix->score += score; - } - } - } -} - -} // namespace ppspeech diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h deleted file mode 100644 index f06d88e32..000000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
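What `SearchOneChar` and `LMRescore` above implement is standard shallow fusion: when a prefix crosses a word boundary (or at every step, for a character-based scorer), the decoder adds `alpha * log P_lm(ngram) + beta` to the prefix's log-probability, and competing hypotheses for the same prefix are merged with log-sum-exp. A minimal Python sketch of just this arithmetic; the helper names here are illustrative, not part of this codebase:

```python
import math

NEG_INF = -math.inf


def log_sum_exp(a, b):
    # numerically stable log(exp(a) + exp(b)), used to merge the scores of
    # different paths that collapse to the same prefix
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log1p(math.exp(min(a, b) - m))


def apply_lm_score(log_p, lm_log_cond_prob, alpha, beta):
    # shallow fusion: log_p += alpha * log P_lm(w | history) + beta,
    # mirroring init_ext_scorer_->alpha and init_ext_scorer_->beta above
    return log_p + alpha * lm_log_cond_prob + beta


# merge two hypotheses for one prefix, then add a word-level LM term
merged = log_sum_exp(-4.2, -3.7)
print(apply_lm_score(merged, lm_log_cond_prob=-2.1, alpha=0.5, beta=0.3))
```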
- -// used by deepspeech2 - -#pragma once - -#include "decoder/ctc_beam_search_opt.h" -#include "decoder/ctc_decoders/path_trie.h" -#include "decoder/ctc_decoders/scorer.h" -#include "decoder/decoder_itf.h" - -namespace ppspeech { - -class CTCBeamSearch : public DecoderBase { - public: - explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); - ~CTCBeamSearch() {} - - void InitDecoder(); - - void Reset(); - - void AdvanceDecode( - const std::shared_ptr& decodable); - - void Decode(std::shared_ptr decodable); - - std::string GetBestPath(); - std::vector> GetNBestPath(); - std::vector> GetNBestPath(int n); - std::string GetFinalBestPath(); - - std::string GetPartialResult() { - CHECK(false) << "Not implement."; - return {}; - } - - int DecodeLikelihoods(const std::vector>& probs, - const std::vector& nbest_words); - - private: - void ResetPrefixes(); - - int32 SearchOneChar(const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff); - void CalculateApproxScore(); - void LMRescore(); - void AdvanceDecoding(const std::vector>& probs); - - CTCBeamSearchOptions opts_; - std::shared_ptr init_ext_scorer_; // todo separate later - std::vector vocabulary_; // todo remove later - int space_id_; - std::shared_ptr root_; - std::vector prefixes_; - - DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); -}; - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc deleted file mode 100644 index ab0376b6b..000000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
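The header above fixes the decoder's streaming contract: `InitDecoder` once per utterance, `AdvanceDecode` whenever new frames are ready, `GetFinalBestPath` (approximate scoring plus LM rescoring) at utterance end, then `Reset`. A hedged Python-style sketch of that call order, assuming hypothetical Python wrappers around `CTCBeamSearch` and the decodable (no such bindings exist in this tree):

```python
def recognize_utterance(decoder, decodable, feature_chunks):
    # decoder/decodable are assumed wrappers mirroring the C++ API above
    decoder.InitDecoder()
    for i, chunk in enumerate(feature_chunks):
        decodable.Accept(chunk)            # feed one chunk of features
        if i == len(feature_chunks) - 1:
            decodable.SetFinished()        # signal end of input
        decoder.AdvanceDecode(decodable)   # consume all frames ready so far
    text = decoder.GetFinalBestPath()      # approx CTC score + LM rescore
    decoder.Reset()
    return text
```

The binary that follows drives the real classes in exactly this order.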
- -// used by deepspeech2 - -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "frontend/audio/data_cache.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); -DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "", "language model"); -DEFINE_int32(receptive_field_length, - 7, - "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(subsampling_rate, - 4, - "two CNN(kernel=3) module downsampling rate."); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); -DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", - "model output names"); -DEFINE_string(model_cache_names, - "chunk_state_h_box,chunk_state_c_box", - "model cache names"); -DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); -DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -// test ds2 online decoder by feeding speech feature -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - CHECK_NE(FLAGS_result_wspecifier, ""); - CHECK_NE(FLAGS_feature_rspecifier, ""); - - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); - kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - std::string model_path = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "model path: " << model_path; - LOG(INFO) << "model param: " << model_params; - LOG(INFO) << "dict path: " << dict_file; - LOG(INFO) << "lm path: " << lm_path; - - int32 num_done = 0, num_err = 0; - - ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; - ppspeech::CTCBeamSearch decoder(opts); - - ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); - std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; - decoder.InitDecoder(); - - kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 
padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - decoder.AdvanceDecode(decodable); - } - std::string result; - result = decoder.GetFinalBestPath(); - decodable->Reset(); - decoder.Reset(); - if (result.empty()) { - // the TokenWriter can not write empty string. - ++num_err; - KALDI_LOG << " the result of " << utt << " is empty"; - continue; - } - KALDI_LOG << " the result of " << utt << " is " << result; - result_writer.Write(utt, result); - ++num_done; - } - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - double elapsed = timer.Elapsed(); - KALDI_LOG << " cost:" << elapsed << " s"; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/.gitignore b/speechx/speechx/asr/decoder/ctc_decoders/.gitignore deleted file mode 100644 index 0b1046ae8..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/.gitignore +++ /dev/null @@ -1,9 +0,0 @@ -ThreadPool/ -build/ -dist/ -kenlm/ -openfst-1.6.3/ -openfst-1.6.3.tar.gz -swig_decoders.egg-info/ -decoders_wrap.cxx -swig_decoders.py diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp deleted file mode 100644 index ebea5c222..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp +++ /dev/null @@ -1,607 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
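The chunked feeding loop in the deleted binary above turns on three quantities: the chunk size (the receptive field plus extra subsampled steps), the stride between chunks, and the padding that lets the chunks tile the feature matrix exactly. A small Python sketch of that arithmetic under the default flags (`receptive_field_length=7`, `subsampling_rate=4`, `nnet_decoder_chunk=1`); the function itself is illustrative:

```python
def plan_chunks(num_frames, receptive_field=7, subsampling=4, decoder_chunk=1):
    # one chunk covers the receptive field plus (decoder_chunk - 1)
    # additional subsampled steps; the stride is one decoder chunk
    chunk_size = receptive_field + (decoder_chunk - 1) * subsampling
    chunk_stride = subsampling * decoder_chunk
    padding = 0
    rem = (num_frames - chunk_size) % chunk_stride
    if rem != 0:
        padding = chunk_stride - rem  # pad so the chunks tile the input
    num_chunks = (num_frames + padding - chunk_size) // chunk_stride + 1
    return chunk_size, chunk_stride, padding, num_chunks


# 100 input frames with the default flags -> (7, 4, 3, 25)
print(plan_chunks(100))
```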
- -#include "ctc_beam_search_decoder.h" - -#include -#include -#include -#include -#include -#include - -#include "ThreadPool.h" -#include "fst/fstlib.h" - -#include "decoder_utils.h" -#include "path_trie.h" - -using FSTMATCH = fst::SortedMatcher; - - -std::vector> ctc_beam_search_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - // vocabulary.size() + 1, - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // init prefixes' root - PathTrie root; - root.score = root.log_prob_b_prev = 0.0; - std::vector prefixes; - prefixes.push_back(&root); - - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - auto fst_dict = - static_cast(ext_scorer->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); - root.set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root.set_matcher(matcher); - } - - // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { - auto &prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), - prefixes.begin() + num_prefixes, - prefix_compare); - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } - - std::vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - // loop over chars - for (size_t index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - auto log_prob_c = log_prob_idx[index].second; - - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->log_prob_b_cur = log_sum_exp( - prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - // repeated character - if (c == prefix->character) { - prefix->log_prob_nb_cur = - log_sum_exp(prefix->log_prob_nb_cur, - log_prob_c + prefix->log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * - ext_scorer->alpha; - log_p += score; - log_p 
+= ext_scorer->beta; - } - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over vocabulary - - - prefixes.clear(); - // update log probs - root.iterate_to_vec(prefixes); - - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } - } - } // end of loop over time - - // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score += score; - } - } - } - - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - - // compute approximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; - if (ext_scorer != nullptr) { - std::vector output; - prefixes[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= - (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; - } - prefixes[i]->approx_ctc = approx_ctc; - } - - return get_beam_search_result(prefixes, vocabulary, beam_size); -} - - -std::vector>> -ctc_beam_search_decoding_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(num_processes); - // number of samples - size_t batch_size = probs_split.size(); - - // enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < batch_size; ++i) { - res.emplace_back(pool.enqueue(ctc_beam_search_decoding, - probs_split[i], - vocabulary, - beam_size, - cutoff_prob, - cutoff_top_n, - ext_scorer, - blank_id)); - } - - // get decoding results - std::vector>> batch_results; - for (size_t i = 0; i < batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; -} - -void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) { - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - auto fst_dict = - static_cast(ext_scorer->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); - root->set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root->set_matcher(matcher); - } -} - -void ctc_beam_search_decode_chunk( - PathTrie *root, - std::vector &prefixes, - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - // dimension check - size_t 
num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - // vocabulary.size() + 1, - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // init prefixes' root - // - // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { - auto &prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), - prefixes.begin() + num_prefixes, - prefix_compare); - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } - - std::vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - // loop over chars - for (size_t index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - auto log_prob_c = log_prob_idx[index].second; - - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->log_prob_b_cur = log_sum_exp( - prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - // repeated character - if (c == prefix->character) { - prefix->log_prob_nb_cur = - log_sum_exp(prefix->log_prob_nb_cur, - log_prob_c + prefix->log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * - ext_scorer->alpha; - log_p += score; - log_p += ext_scorer->beta; - } - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over vocabulary - - prefixes.clear(); - // update log probs - - root->iterate_to_vec(prefixes); - - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } - } - } // end of loop over time - - return; -} - - -std::vector> get_decode_result( - std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size, - Scorer *ext_scorer) { - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id 
= -2; - } - // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score += score; - } - } - } - - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; - if (ext_scorer != nullptr) { - std::vector output; - prefixes[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= - (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; - } - prefixes[i]->approx_ctc = approx_ctc; - } - - std::vector> res = - get_beam_search_result(prefixes, vocabulary, beam_size); - - // pay back the last word of each prefix that doesn't end with space (for - // decoding by chunk) - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score -= score; - } - } - } - return res; -} - - -void free_storage(std::unique_ptr &storage) { - storage = nullptr; -} - - -CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {} - -CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch( - const std::vector &vocabulary, - size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) - : batch_size(batch_size), - beam_size(beam_size), - num_processes(num_processes), - cutoff_prob(cutoff_prob), - cutoff_top_n(cutoff_top_n), - ext_scorer(ext_scorer), - blank_id(blank_id) { - VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - this->vocabulary = vocabulary; - for (size_t i = 0; i < batch_size; i++) { - this->decoder_storage_vector.push_back( - std::unique_ptr( - new CtcBeamSearchDecoderStorage())); - ctc_beam_search_decode_chunk_begin( - this->decoder_storage_vector[i]->root, ext_scorer); - } -}; - -/** - * Input - * probs_split: shape [B, T, D] - */ -void CtcBeamSearchDecoderBatch::next( - const std::vector>> &probs_split, - const std::vector &has_value) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - size_t num_has_value = 0; - for (int i = 0; i < has_value.size(); i++) - if (has_value[i] == "true") num_has_value += 1; - ThreadPool pool(std::min(num_processes, num_has_value)); - // number of samples - size_t probs_num = probs_split.size(); - VALID_CHECK_EQ(this->batch_size, - probs_num, - "The batch 
size of the current input data should be same " - "with the input data before"); - - // enqueue the tasks of decoding - std::vector> res; - for (size_t i = 0; i < batch_size; ++i) { - if (has_value[i] == "true") { - res.emplace_back(pool.enqueue( - ctc_beam_search_decode_chunk, - std::ref(this->decoder_storage_vector[i]->root), - std::ref(this->decoder_storage_vector[i]->prefixes), - probs_split[i], - this->vocabulary, - this->beam_size, - this->cutoff_prob, - this->cutoff_top_n, - this->ext_scorer, - this->blank_id)); - } - } - - for (size_t i = 0; i < batch_size; ++i) { - res[i].get(); - } - return; -}; - -/** - * Return - * batch_result: shape[B, beam_size,(-approx_ctc score, string)] - */ -std::vector>> -CtcBeamSearchDecoderBatch::decode() { - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(this->num_processes); - // number of samples - // enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < this->batch_size; ++i) { - res.emplace_back( - pool.enqueue(get_decode_result, - std::ref(this->decoder_storage_vector[i]->prefixes), - this->vocabulary, - this->beam_size, - this->ext_scorer)); - } - // get decoding results - std::vector>> batch_results; - for (size_t i = 0; i < this->batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; -} - - -/** - * reset the state of ctcBeamSearchDecoderBatch - */ -void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n) { - this->batch_size = batch_size; - this->beam_size = beam_size; - this->num_processes = num_processes; - this->cutoff_prob = cutoff_prob; - this->cutoff_top_n = cutoff_top_n; - - VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(this->num_processes); - // number of samples - // enqueue the tasks of decoding - std::vector> res; - size_t storage_size = decoder_storage_vector.size(); - for (size_t i = 0; i < storage_size; i++) { - res.emplace_back(pool.enqueue( - free_storage, std::ref(this->decoder_storage_vector[i]))); - } - for (size_t i = 0; i < storage_size; ++i) { - res[i].get(); - } - std::vector>().swap( - decoder_storage_vector); - for (size_t i = 0; i < this->batch_size; i++) { - this->decoder_storage_vector.push_back( - std::unique_ptr( - new CtcBeamSearchDecoderStorage())); - ctc_beam_search_decode_chunk_begin( - this->decoder_storage_vector[i]->root, this->ext_scorer); - } -} \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h deleted file mode 100644 index 92d2b855f..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef CTC_BEAM_SEARCH_DECODER_H_ -#define CTC_BEAM_SEARCH_DECODER_H_ - -#include <string> -#include <utility> -#include <vector> - -#include "scorer.h" - -/* CTC Beam Search Decoder - - * Parameters: - * probs_seq: 2-D vector where each element is a vector of probabilities - * over the vocabulary for one time step. - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - * Return: - * A vector where each element is a pair of score and decoding result, - * in descending order. -*/ -std::vector<std::pair<double, std::string>> ctc_beam_search_decoding( - const std::vector<std::vector<double>> &probs_seq, - const std::vector<std::string> &vocabulary, - size_t beam_size, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr, - size_t blank_id = 0); - - -/* CTC Beam Search Decoder for batch data - - * Parameters: - * probs_seq: 3-D vector where each element is a 2-D vector that can be used - * by ctc_beam_search_decoding(). - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * num_processes: Number of threads for beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - * Return: - * A 2-D vector where each element is a vector of beam search decoding - * result for one audio sample.
-*/ -std::vector>> -ctc_beam_search_decoding_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t num_processes, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr, - size_t blank_id = 0); - -/** - * Store the root and prefixes for decoder - */ - -class CtcBeamSearchDecoderStorage { - public: - PathTrie *root = nullptr; - std::vector prefixes; - - CtcBeamSearchDecoderStorage() { - // init prefixes' root - this->root = new PathTrie(); - this->root->log_prob_b_prev = 0.0; - // The score of root is in log scale.Since the prob=1.0, the prob score - // in log scale is 0.0 - this->root->score = root->log_prob_b_prev; - // std::vector prefixes; - this->prefixes.push_back(root); - }; - - ~CtcBeamSearchDecoderStorage() { - if (root != nullptr) { - delete root; - root = nullptr; - } - }; -}; - -/** - * The ctc beam search decoder, support batchsize >= 1 - */ -class CtcBeamSearchDecoderBatch { - public: - CtcBeamSearchDecoderBatch(const std::vector &vocabulary, - size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id); - - ~CtcBeamSearchDecoderBatch(); - void next(const std::vector>> &probs_split, - const std::vector &has_value); - - std::vector>> decode(); - - void reset_state(size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n); - - private: - std::vector vocabulary; - size_t batch_size; - size_t beam_size; - size_t num_processes; - double cutoff_prob; - size_t cutoff_top_n; - Scorer *ext_scorer; - size_t blank_id; - std::vector> - decoder_storage_vector; -}; - -/** - * function for chunk decoding - */ -void ctc_beam_search_decode_chunk( - PathTrie *root, - std::vector &prefixes, - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id); - -std::vector> get_decode_result( - std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size, - Scorer *ext_scorer); - -/** - * free the CtcBeamSearchDecoderStorage - */ -void free_storage(std::unique_ptr &storage); - -/** - * initialize the root - */ -void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer); - -#endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp deleted file mode 100644 index 6aa3c9964..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
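The chunked API declared above implies a construct / next-per-chunk / decode / reset_state lifecycle. A hypothetical driver, assuming a vocabulary loader and a streaming posterior source that are not part of the original sources:

// Hypothetical driver for the deleted CtcBeamSearchDecoderBatch above;
// LoadVocab() and GetNextChunk() are assumed helpers, not original code.
std::vector<std::string> vocab = LoadVocab();
CtcBeamSearchDecoderBatch decoder(vocab,
                                  /*batch_size=*/2,
                                  /*beam_size=*/10,
                                  /*num_processes=*/2,
                                  /*cutoff_prob=*/0.99,
                                  /*cutoff_top_n=*/40,
                                  /*ext_scorer=*/nullptr,
                                  /*blank_id=*/0);
std::vector<std::vector<std::vector<double>>> chunk;  // [B, T, D] posteriors
std::vector<std::string> has_value = {"true", "true"};
while (GetNextChunk(&chunk)) {       // assumed streaming source
    decoder.next(chunk, has_value);  // advance prefixes by one chunk
}
auto nbest = decoder.decode();       // [B][beam] (-approx_ctc, text)
decoder.reset_state(2, 10, 2, 0.99, 40);  // reuse for the next batch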
- -#include "ctc_greedy_decoder.h" -#include "decoder_utils.h" - -std::string ctc_greedy_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - // size_t blank_id = vocabulary.size(); - - std::vector max_idx_vec(num_time_steps, 0); - std::vector idx_vec; - for (size_t i = 0; i < num_time_steps; ++i) { - double max_prob = 0.0; - size_t max_idx = 0; - const std::vector &probs_step = probs_seq[i]; - for (size_t j = 0; j < probs_step.size(); ++j) { - if (max_prob < probs_step[j]) { - max_idx = j; - max_prob = probs_step[j]; - } - } - // id with maximum probability in current time step - max_idx_vec[i] = max_idx; - // deduplicate - if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { - idx_vec.push_back(max_idx_vec[i]); - } - } - - std::string best_path_result; - for (size_t i = 0; i < idx_vec.size(); ++i) { - if (idx_vec[i] != blank_id) { - std::string ch = vocabulary[idx_vec[i]]; - best_path_result += (ch == kSPACE) ? tSPACE : ch; - } - } - return best_path_result; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h deleted file mode 100644 index 4451600d6..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef CTC_GREEDY_DECODER_H -#define CTC_GREEDY_DECODER_H - -#include -#include - -/* CTC Greedy (Best Path) Decoder - * - * Parameters: - * probs_seq: 2-D vector that each element is a vector of probabilities - * over vocabulary of one time step. - * vocabulary: A vector of vocabulary. - * Return: - * The decoding result in string - */ -std::string ctc_greedy_decoding( - const std::vector>& probs_seq, - const std::vector& vocabulary, - size_t blank_id); - -#endif // CTC_GREEDY_DECODER_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp deleted file mode 100644 index c7ef65428..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "decoder_utils.h" - -#include -#include -#include - -std::vector> get_pruned_log_probs( - const std::vector &prob_step, - double cutoff_prob, - size_t cutoff_top_n) { - std::vector> prob_idx; - for (size_t i = 0; i < prob_step.size(); ++i) { - prob_idx.push_back(std::pair(i, prob_step[i])); - } - // pruning of vocabulary - size_t cutoff_len = prob_step.size(); - if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { - std::sort(prob_idx.begin(), - prob_idx.end(), - pair_comp_second_rev); - if (cutoff_prob < 1.0) { - double cum_prob = 0.0; - cutoff_len = 0; - for (size_t i = 0; i < prob_idx.size(); ++i) { - cum_prob += prob_idx[i].second; - cutoff_len += 1; - if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) - break; - } - } - prob_idx = std::vector>( - prob_idx.begin(), prob_idx.begin() + cutoff_len); - } - std::vector> log_prob_idx; - for (size_t i = 0; i < cutoff_len; ++i) { - log_prob_idx.push_back(std::pair( - prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); - } - return log_prob_idx; -} - - -std::vector> get_beam_search_result( - const std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size) { - // allow for the post processing - std::vector space_prefixes; - if (space_prefixes.empty()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - space_prefixes.push_back(prefixes[i]); - } - } - - std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); - std::vector> output_vecs; - for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { - std::vector output; - space_prefixes[i]->get_path_vec(output); - // convert index to string - std::string output_str; - for (size_t j = 0; j < output.size(); j++) { - std::string ch = vocabulary[output[j]]; - output_str += (ch == kSPACE) ? 
tSPACE : ch; - } - std::pair output_pair( - -space_prefixes[i]->approx_ctc, output_str); - output_vecs.emplace_back(output_pair); - } - - return output_vecs; -} - -size_t get_utf8_str_len(const std::string &str) { - size_t str_len = 0; - for (char c : str) { - str_len += ((c & 0xc0) != 0x80); - } - return str_len; -} - -std::vector split_utf8_str(const std::string &str) { - std::vector result; - std::string out_str; - - for (char c : str) { - if ((c & 0xc0) != 0x80) // new UTF-8 character - { - if (!out_str.empty()) { - result.push_back(out_str); - out_str.clear(); - } - } - - out_str.append(1, c); - } - result.push_back(out_str); - return result; -} - -std::vector split_str(const std::string &s, - const std::string &delim) { - std::vector result; - std::size_t start = 0, delim_len = delim.size(); - while (true) { - std::size_t end = s.find(delim, start); - if (end == std::string::npos) { - if (start < s.size()) { - result.push_back(s.substr(start)); - } - break; - } - if (end > start) { - result.push_back(s.substr(start, end - start)); - } - start = end + delim_len; - } - return result; -} - -bool prefix_compare(const PathTrie *x, const PathTrie *y) { - if (x->score == y->score) { - if (x->character == y->character) { - return false; - } else { - return (x->character < y->character); - } - } else { - return x->score > y->score; - } -} - -void add_word_to_fst(const std::vector &word, - fst::StdVectorFst *dictionary) { - if (dictionary->NumStates() == 0) { - fst::StdVectorFst::StateId start = dictionary->AddState(); - assert(start == 0); - dictionary->SetStart(start); - } - fst::StdVectorFst::StateId src = dictionary->Start(); - fst::StdVectorFst::StateId dst; - for (auto c : word) { - dst = dictionary->AddState(); - dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); - src = dst; - } - dictionary->SetFinal(dst, fst::StdArc::Weight::One()); -} - -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map &char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst *dictionary) { - auto characters = split_utf8_str(word); - - std::vector int_word; - - for (auto &c : characters) { - if (c == " ") { - int_word.push_back(SPACE_ID); - } else { - auto int_c = char_map.find(c); - if (int_c != char_map.end()) { - int_word.push_back(int_c->second); - } else { - return false; // return without adding - } - } - } - - if (add_space) { - int_word.push_back(SPACE_ID); - } - - add_word_to_fst(int_word, dictionary); - return true; // return with successful adding -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h deleted file mode 100644 index 098741552..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
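The pruning routine implemented above keeps tokens, in descending probability order, until their cumulative mass reaches cutoff_prob (or cutoff_top_n tokens). A tiny check with illustrative values:

// Tiny check of the pruning rule in get_pruned_log_probs() above.
#include <cassert>
#include <vector>
#include "decoder_utils.h"

int main() {
    std::vector<double> prob_step = {0.50, 0.30, 0.15, 0.05};
    // Cumulative mass: 0.50 -> 0.80 -> 0.95 >= 0.9 stops the scan,
    // so token ids {0, 1, 2} survive.
    auto kept = get_pruned_log_probs(prob_step, /*cutoff_prob=*/0.9,
                                     /*cutoff_top_n=*/40);
    assert(kept.size() == 3);  // each entry is (token id, log prob)
    return 0;
}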
- -#ifndef DECODER_UTILS_H_ -#define DECODER_UTILS_H_ - -#include <limits> -#include <string> -#include "fst/log.h" -#include "path_trie.h" - -const std::string kSPACE = "<space>"; -const std::string tSPACE = " "; -const float NUM_FLT_INF = std::numeric_limits<float>::max(); -const float NUM_FLT_MIN = std::numeric_limits<float>::min(); - -// inline function for validation check -inline void check( - bool x, const char *expr, const char *file, int line, const char *err) { - if (!x) { - std::cout << "[" << file << ":" << line << "] "; - LOG(FATAL) << "\"" << expr << "\" check failed. " << err; - } -} - -#define VALID_CHECK(x, info) \ - check(static_cast<bool>(x), #x, __FILE__, __LINE__, info) -#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) -#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) -#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) - - -// Function template for comparing two pairs -template <typename T1, typename T2> -bool pair_comp_first_rev(const std::pair<T1, T2> &a, - const std::pair<T1, T2> &b) { - return a.first > b.first; -} - -// Function template for comparing two pairs -template <typename T1, typename T2> -bool pair_comp_second_rev(const std::pair<T1, T2> &a, - const std::pair<T1, T2> &b) { - return a.second > b.second; -} - -// Return the sum of two probabilities in log scale -template <typename T> -T log_sum_exp(const T &x, const T &y) { - static T num_min = -std::numeric_limits<T>::max(); - if (x <= num_min) return y; - if (y <= num_min) return x; - T xmax = std::max(x, y); - return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; -} - -// Get pruned probability vector for each time step's beam search -std::vector<std::pair<size_t, float>> get_pruned_log_probs( - const std::vector<double> &prob_step, - double cutoff_prob, - size_t cutoff_top_n); - -// Get beam search result from prefixes in trie tree -std::vector<std::pair<double, std::string>> get_beam_search_result( - const std::vector<PathTrie *> &prefixes, - const std::vector<std::string> &vocabulary, - size_t beam_size); - -// Functor for prefix comparison -bool prefix_compare(const PathTrie *x, const PathTrie *y); - -/* Get length of utf8 encoding string - * See: http://stackoverflow.com/a/4063229 - */ -size_t get_utf8_str_len(const std::string &str); - -/* Split a string into a list of strings on a given string - * delimiter. NB: delimiters on beginning / end of string are - * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. - */ -std::vector<std::string> split_str(const std::string &s, - const std::string &delim); - -/* Splits string into vector of strings representing - * UTF-8 characters (not same as chars) - */ -std::vector<std::string> split_utf8_str(const std::string &str); - -// Add a word in index to the dictionary of fst -void add_word_to_fst(const std::vector<int> &word, - fst::StdVectorFst *dictionary); - -// Add a word in string to dictionary -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map<std::string, int> &char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst *dictionary); -#endif // DECODER_UTILS_H_ diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp deleted file mode 100644 index 777ca0520..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "path_trie.h" - -#include -#include -#include -#include -#include - -#include "decoder_utils.h" - -PathTrie::PathTrie() { - log_prob_b_prev = -NUM_FLT_INF; - log_prob_nb_prev = -NUM_FLT_INF; - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - score = -NUM_FLT_INF; - - ROOT_ = -1; - character = ROOT_; - exists_ = true; - parent = nullptr; - - dictionary_ = nullptr; - dictionary_state_ = 0; - has_dictionary_ = false; - - matcher_ = nullptr; -} - -PathTrie::~PathTrie() { - for (auto child : children_) { - delete child.second; - child.second = nullptr; - } -} - -PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { - auto child = children_.begin(); - for (child = children_.begin(); child != children_.end(); ++child) { - if (child->first == new_char) { - break; - } - } - if (child != children_.end()) { - if (!child->second->exists_) { - child->second->exists_ = true; - child->second->log_prob_b_prev = -NUM_FLT_INF; - child->second->log_prob_nb_prev = -NUM_FLT_INF; - child->second->log_prob_b_cur = -NUM_FLT_INF; - child->second->log_prob_nb_cur = -NUM_FLT_INF; - } - return (child->second); - } else { - if (has_dictionary_) { - matcher_->SetState(dictionary_state_); - bool found = matcher_->Find(new_char + 1); - if (!found) { - // Adding this character causes word outside dictionary - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = dictionary_->Final(dictionary_state_); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - dictionary_state_ = dictionary_->Start(); - } - return nullptr; - } else { - PathTrie* new_path = new PathTrie; - new_path->character = new_char; - new_path->parent = this; - new_path->dictionary_ = dictionary_; - new_path->dictionary_state_ = matcher_->Value().nextstate; - new_path->has_dictionary_ = true; - new_path->matcher_ = matcher_; - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } else { - PathTrie* new_path = new PathTrie; - new_path->character = new_char; - new_path->parent = this; - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } -} - -PathTrie* PathTrie::get_path_vec(std::vector& output) { - return get_path_vec(output, ROOT_); -} - -PathTrie* PathTrie::get_path_vec(std::vector& output, - int stop, - size_t max_steps) { - if (character == stop || character == ROOT_ || output.size() == max_steps) { - std::reverse(output.begin(), output.end()); - return this; - } else { - output.push_back(character); - return parent->get_path_vec(output, stop, max_steps); - } -} - -void PathTrie::iterate_to_vec(std::vector& output) { - if (exists_) { - log_prob_b_prev = log_prob_b_cur; - log_prob_nb_prev = log_prob_nb_cur; - - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - - score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); - output.push_back(this); - } - for (auto child : children_) { - child.second->iterate_to_vec(output); - } -} - -void PathTrie::remove() { - exists_ = false; - if (children_.size() == 0) { - if (parent != nullptr) { - auto child = parent->children_.begin(); - for (child = 
parent->children_.begin(); - child != parent->children_.end(); - ++child) { - if (child->first == character) { - parent->children_.erase(child); - break; - } - } - if (parent->children_.size() == 0 && !parent->exists_) { - parent->remove(); - } - } - delete this; - } -} - - -void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { - dictionary_ = dictionary; - dictionary_state_ = dictionary->Start(); - has_dictionary_ = true; -} - -using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; -void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) { - matcher_ = matcher; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h deleted file mode 100644 index 5193e0a47..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef PATH_TRIE_H -#define PATH_TRIE_H - -#include <algorithm> -#include <limits> -#include <memory> -#include <utility> -#include <vector> - -#include "fst/fstlib.h" - -/* Trie tree for prefix storing and manipulating, with a dictionary in - * finite-state transducer for spelling correction. - */ -class PathTrie { - public: - PathTrie(); - ~PathTrie(); - - // get new prefix after appending new char - PathTrie* get_path_trie(int new_char, bool reset = true); - - // get the prefix in index from root to current node - PathTrie* get_path_vec(std::vector<int>& output); - - // get the prefix in index from some stop node to current node - PathTrie* get_path_vec( - std::vector<int>& output, - int stop, - size_t max_steps = std::numeric_limits<size_t>::max()); - - // update log probs - void iterate_to_vec(std::vector<PathTrie*>& output); - - // set dictionary for FST - void set_dictionary(fst::StdVectorFst* dictionary); - - void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>); - - bool is_empty() { return ROOT_ == character; } - - // remove current path from root - void remove(); - - float log_prob_b_prev; - float log_prob_nb_prev; - float log_prob_b_cur; - float log_prob_nb_cur; - float score; - float approx_ctc; - int character; - PathTrie* parent; - - private: - int ROOT_; - bool exists_; - bool has_dictionary_; - - std::vector<std::pair<int, PathTrie*>> children_; - - // pointer to dictionary of FST - fst::StdVectorFst* dictionary_; - fst::StdVectorFst::StateId dictionary_state_; - // matcher for finding arcs in FST - std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_; -}; - -#endif // PATH_TRIE_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp b/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp deleted file mode 100644 index 6e7f68cf6..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp +++ /dev/null @@ -1,232 +0,0 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the -// "COPYING.LESSER.3"); - -#include "scorer.h" - -#include <unistd.h> -#include <iostream> - -#include "lm/config.hh" -#include "lm/model.hh" -#include "lm/state.hh" - -#include "decoder_utils.h" - -using namespace lm::ngram; -// if your platform is Windows, you need to add this define -#define F_OK
0 -Scorer::Scorer(double alpha, - double beta, - const std::string& lm_path, - const std::vector& vocab_list) { - this->alpha = alpha; - this->beta = beta; - - dictionary = nullptr; - is_character_based_ = true; - language_model_ = nullptr; - - max_order_ = 0; - dict_size_ = 0; - SPACE_ID_ = -1; - - setup(lm_path, vocab_list); -} - -Scorer::~Scorer() { - if (language_model_ != nullptr) { - delete static_cast(language_model_); - } - if (dictionary != nullptr) { - delete static_cast(dictionary); - } -} - -void Scorer::setup(const std::string& lm_path, - const std::vector& vocab_list) { - // load language model - load_lm(lm_path); - // set char map for scorer - set_char_map(vocab_list); - // fill the dictionary for FST - if (!is_character_based()) { - fill_dictionary(true); - } -} - -void Scorer::load_lm(const std::string& lm_path) { - const char* filename = lm_path.c_str(); - VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); - - RetriveStrEnumerateVocab enumerate; - lm::ngram::Config config; - config.enumerate_vocab = &enumerate; - language_model_ = lm::ngram::LoadVirtual(filename, config); - max_order_ = static_cast(language_model_)->Order(); - vocabulary_ = enumerate.vocabulary; - for (size_t i = 0; i < vocabulary_.size(); ++i) { - if (is_character_based_ && vocabulary_[i] != UNK_TOKEN && - vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && - get_utf8_str_len(enumerate.vocabulary[i]) > 1) { - is_character_based_ = false; - } - } -} - -double Scorer::get_log_cond_prob(const std::vector& words) { - lm::base::Model* model = static_cast(language_model_); - double cond_prob; - lm::ngram::State state, tmp_state, out_state; - // avoid to inserting in begin - model->NullContextWrite(&state); - for (size_t i = 0; i < words.size(); ++i) { - lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); - // encounter OOV - if (word_index == 0) { - return OOV_SCORE; - } - cond_prob = model->BaseScore(&state, word_index, &out_state); - tmp_state = state; - state = out_state; - out_state = tmp_state; - } - // return log10 prob - return cond_prob; -} - -double Scorer::get_sent_log_prob(const std::vector& words) { - std::vector sentence; - if (words.size() == 0) { - for (size_t i = 0; i < max_order_; ++i) { - sentence.push_back(START_TOKEN); - } - } else { - for (size_t i = 0; i < max_order_ - 1; ++i) { - sentence.push_back(START_TOKEN); - } - sentence.insert(sentence.end(), words.begin(), words.end()); - } - sentence.push_back(END_TOKEN); - return get_log_prob(sentence); -} - -double Scorer::get_log_prob(const std::vector& words) { - assert(words.size() > max_order_); - double score = 0.0; - for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { - std::vector ngram(words.begin() + i, - words.begin() + i + max_order_); - score += get_log_cond_prob(ngram); - } - return score; -} - -void Scorer::reset_params(float alpha, float beta) { - this->alpha = alpha; - this->beta = beta; -} - -std::string Scorer::vec2str(const std::vector& input) { - std::string word; - for (auto ind : input) { - word += char_list_[ind]; - } - return word; -} - -std::vector Scorer::split_labels(const std::vector& labels) { - if (labels.empty()) return {}; - - std::string s = vec2str(labels); - std::vector words; - if (is_character_based_) { - words = split_utf8_str(s); - } else { - words = split_str(s, " "); - } - return words; -} - -void Scorer::set_char_map(const std::vector& char_list) { - char_list_ = char_list; - char_map_.clear(); - - // Set the char map for the FST for spelling 
correction - for (size_t i = 0; i < char_list_.size(); i++) { - if (char_list_[i] == kSPACE) { - SPACE_ID_ = i; - } - // The initial state of FST is state 0, hence the index of chars in - // the FST should start from 1 to avoid the conflict with the initial - // state, otherwise wrong decoding results would be given. - char_map_[char_list_[i]] = i + 1; - } -} - -std::vector Scorer::make_ngram(PathTrie* prefix) { - std::vector ngram; - PathTrie* current_node = prefix; - PathTrie* new_node = nullptr; - - for (int order = 0; order < max_order_; order++) { - std::vector prefix_vec; - - if (is_character_based_) { - new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1); - current_node = new_node; - } else { - new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_); - current_node = new_node->parent; // Skipping spaces - } - - // reconstruct word - std::string word = vec2str(prefix_vec); - ngram.push_back(word); - - if (new_node->character == -1) { - // No more spaces, but still need order - for (int i = 0; i < max_order_ - order - 1; i++) { - ngram.push_back(START_TOKEN); - } - break; - } - } - std::reverse(ngram.begin(), ngram.end()); - return ngram; -} - -void Scorer::fill_dictionary(bool add_space) { - fst::StdVectorFst dictionary; - // For each unigram convert to ints and put in trie - int dict_size = 0; - for (const auto& word : vocabulary_) { - bool added = add_word_to_dictionary( - word, char_map_, add_space, SPACE_ID_ + 1, &dictionary); - dict_size += added ? 1 : 0; - } - - dict_size_ = dict_size; - - /* Simplify FST - - * This gets rid of "epsilon" transitions in the FST. - * These are transitions that don't require a string input to be taken. - * Getting rid of them is necessary to make the FST deterministic, but - * can greatly increase the size of the FST - */ - fst::RmEpsilon(&dictionary); - fst::StdVectorFst* new_dict = new fst::StdVectorFst; - - /* This makes the FST deterministic, meaning for any string input there's - * only one possible state the FST could be in. It is assumed our - * dictionary is deterministic when using it. - * (lest we'd have to check for multiple transitions at each state) - */ - fst::Determinize(dictionary, new_dict); - - /* Finds the simplest equivalent fst. This is unnecessary but decreases - * memory usage of the dictionary - */ - fst::Minimize(new_dict); - this->dictionary = new_dict; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.h b/speechx/speechx/asr/decoder/ctc_decoders/scorer.h deleted file mode 100644 index 08e109b78..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/scorer.h +++ /dev/null @@ -1,114 +0,0 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the -// "COPYING.LESSER.3"); - -#ifndef SCORER_H_ -#define SCORER_H_ - -#include -#include -#include -#include - -#include "lm/enumerate_vocab.hh" -#include "lm/virtual_interface.hh" -#include "lm/word_index.hh" - -#include "path_trie.h" - -const double OOV_SCORE = -1000.0; -const std::string START_TOKEN = ""; -const std::string UNK_TOKEN = ""; -const std::string END_TOKEN = ""; - -// Implement a callback to retrive the dictionary of language model. -class RetriveStrEnumerateVocab : public lm::EnumerateVocab { - public: - RetriveStrEnumerateVocab() {} - - void Add(lm::WordIndex index, const StringPiece &str) { - vocabulary.push_back(std::string(str.data(), str.length())); - } - - std::vector vocabulary; -}; - -/* External scorer to query score for n-gram or sentence, including language - * model scoring and word insertion. 
- * - * Example: - * Scorer scorer(alpha, beta, "path_of_language_model"); - * scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); - * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); - */ -class Scorer { - public: - Scorer(double alpha, - double beta, - const std::string &lm_path, - const std::vector<std::string> &vocabulary); - ~Scorer(); - - double get_log_cond_prob(const std::vector<std::string> &words); - - double get_sent_log_prob(const std::vector<std::string> &words); - - // return the max order - size_t get_max_order() const { return max_order_; } - - // return the dictionary size of language model - size_t get_dict_size() const { return dict_size_; } - - // return true if the language model is character based - bool is_character_based() const { return is_character_based_; } - - // reset params alpha & beta - void reset_params(float alpha, float beta); - - // make ngram for a given prefix - std::vector<std::string> make_ngram(PathTrie *prefix); - - // transform the labels in index to the vector of words (word based lm) or - // the vector of characters (character based lm) - std::vector<std::string> split_labels(const std::vector<int> &labels); - - // language model weight - double alpha; - // word insertion weight - double beta; - - // pointer to the dictionary of FST - void *dictionary; - - protected: - // necessary setup: load language model, set char map, fill FST's dictionary - void setup(const std::string &lm_path, - const std::vector<std::string> &vocab_list); - - // load language model from given path - void load_lm(const std::string &lm_path); - - // fill dictionary for FST - void fill_dictionary(bool add_space); - - // set char map - void set_char_map(const std::vector<std::string> &char_list); - - double get_log_prob(const std::vector<std::string> &words); - - // translate the vector in index to string - std::string vec2str(const std::vector<int> &input); - - private: - void *language_model_; - bool is_character_based_; - size_t max_order_; - size_t dict_size_; - - int SPACE_ID_; - std::vector<std::string> char_list_; - std::unordered_map<std::string, int> char_map_; - - std::vector<std::string> vocabulary_; -}; - -#endif // SCORER_H_ diff --git a/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc b/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc deleted file mode 100644 index e0acbe77b..000000000 --- a/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
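A sketch expanding the Example comment in the Scorer header above; the LM path and vocabulary are placeholders, and alpha/beta are typical rather than tuned values:

// Usage sketch for the deleted Scorer class above (placeholder values).
#include <string>
#include <vector>
#include "scorer.h"

int main() {
    std::vector<std::string> vocab = {"<space>", "a", "b", "c"};  // placeholder
    Scorer scorer(/*alpha=*/2.0, /*beta=*/1.0, "lm.klm", vocab);
    // alpha scales the LM log probability and beta is the word-insertion
    // bonus; they enter the beam search exactly where the deleted
    // ctc_beam_search_decoder.cpp adds "score * alpha" and "+ beta".
    double cond = scorer.get_log_cond_prob({"HELLO", "WORLD"});
    double sent = scorer.get_sent_log_prob({"HELLO", "WORLD"});
    scorer.reset_params(1.9f, 0.3f);  // retune weights without reloading the LM
    (void)cond;
    (void)sent;
    return 0;
}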
- -// TODO: refactor, replace with gtest - -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" - -DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::string; -using std::vector; - -// test decoder by feeding nnet posterior probability -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - kaldi::SequentialBaseFloatMatrixReader likelihood_reader( - FLAGS_nnet_prob_respecifier); - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "dict path: " << dict_file; - LOG(INFO) << "lm path: " << lm_path; - - int32 num_done = 0, num_err = 0; - - ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; - ppspeech::CTCBeamSearch decoder(opts); - - std::shared_ptr<ppspeech::Decodable> decodable( - new ppspeech::Decodable(nullptr, nullptr)); - - decoder.InitDecoder(); - - for (; !likelihood_reader.Done(); likelihood_reader.Next()) { - string utt = likelihood_reader.Key(); - const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value(); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << likelihood.NumRows(); - LOG(INFO) << "cols: " << likelihood.NumCols(); - decodable->Acceptlikelihood(likelihood); - decoder.AdvanceDecode(decodable); - std::string result; - result = decoder.GetFinalBestPath(); - KALDI_LOG << " the result of " << utt << " is " << result; - decodable->Reset(); - decoder.Reset(); - ++num_done; - } - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ?
0 : 1); -} diff --git a/speechx/speechx/asr/decoder/param.h b/speechx/speechx/asr/decoder/param.h index ebdd71197..cad6dbd8d 100644 --- a/speechx/speechx/asr/decoder/param.h +++ b/speechx/speechx/asr/decoder/param.h @@ -15,8 +15,7 @@ #pragma once #include "base/common.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "decoder/ctc_tlg_decoder.h" +//#include "decoder/ctc_tlg_decoder.h" // feature DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 2846540ec..819cc2e89 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -1,30 +1,12 @@ set(srcs decodable.cc nnet_producer.cc) -if(USING_DS2) - list(APPEND srcs ds2_nnet.cc) -endif() - -if(USING_U2) - list(APPEND srcs u2_nnet.cc) -endif() +list(APPEND srcs u2_nnet.cc) add_library(nnet STATIC ${srcs}) target_link_libraries(nnet utils) -if(USING_U2) - target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) - target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -endif() - - -if(USING_DS2) - set(bin_name ds2_nnet_main) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) - - target_link_libraries(${bin_name} ${DEPS}) -endif() +target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) +target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) # test bin #if(USING_U2) diff --git a/speechx/speechx/asr/nnet/ds2_nnet.cc b/speechx/speechx/asr/nnet/ds2_nnet.cc deleted file mode 100644 index f77c0a603..000000000 --- a/speechx/speechx/asr/nnet/ds2_nnet.cc +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "nnet/ds2_nnet.h" - -#include "utils/strings.h" - -namespace ppspeech { - -using kaldi::Matrix; -using kaldi::Vector; -using std::shared_ptr; -using std::string; -using std::vector; - -void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { - std::vector cache_names; - cache_names = StrSplit(opts.cache_names, ","); - std::vector cache_shapes; - cache_shapes = StrSplit(opts.cache_shape, ","); - assert(cache_shapes.size() == cache_names.size()); - - cache_encouts_.clear(); - cache_names_idx_.clear(); - for (size_t i = 0; i < cache_shapes.size(); i++) { - std::vector tmp_shape; - tmp_shape = StrSplit(cache_shapes[i], "-"); - std::vector cur_shape; - std::transform(tmp_shape.begin(), - tmp_shape.end(), - std::back_inserter(cur_shape), - [](const std::string& s) { return atoi(s.c_str()); }); - cache_names_idx_[cache_names[i]] = i; - std::shared_ptr> cache_eout = - std::make_shared>(cur_shape); - cache_encouts_.push_back(cache_eout); - } -} - -PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { - subsampling_rate_ = opts.subsample_rate; - paddle_infer::Config config; - config.SetModel(opts.model_path, opts.param_path); - if (opts.use_gpu) { - config.EnableUseGpu(500, 0); - } - config.SwitchIrOptim(opts.switch_ir_optim); - if (opts.enable_fc_padding == false) { - config.DisableFCPadding(); - } - if (opts.enable_profile) { - config.EnableProfile(); - } - pool.reset( - new paddle_infer::services::PredictorPool(config, opts.thread_num)); - if (pool == nullptr) { - LOG(ERROR) << "create the predictor pool failed"; - } - pool_usages.resize(opts.thread_num); - std::fill(pool_usages.begin(), pool_usages.end(), false); - LOG(INFO) << "load paddle model success"; - - LOG(INFO) << "start to check the predictor input and output names"; - LOG(INFO) << "input names: " << opts.input_names; - LOG(INFO) << "output names: " << opts.output_names; - std::vector input_names_vec = StrSplit(opts.input_names, ","); - std::vector output_names_vec = StrSplit(opts.output_names, ","); - - paddle_infer::Predictor* predictor = GetPredictor(); - - std::vector model_input_names = predictor->GetInputNames(); - assert(input_names_vec.size() == model_input_names.size()); - for (size_t i = 0; i < model_input_names.size(); i++) { - assert(input_names_vec[i] == model_input_names[i]); - } - - std::vector model_output_names = predictor->GetOutputNames(); - assert(output_names_vec.size() == model_output_names.size()); - for (size_t i = 0; i < output_names_vec.size(); i++) { - assert(output_names_vec[i] == model_output_names[i]); - } - - ReleasePredictor(predictor); - InitCacheEncouts(opts); -} - -void PaddleNnet::Reset() { InitCacheEncouts(opts_); } - -paddle_infer::Predictor* PaddleNnet::GetPredictor() { - paddle_infer::Predictor* predictor = nullptr; - - std::lock_guard guard(pool_mutex); - int pred_id = 0; - - while (pred_id < pool_usages.size()) { - if (pool_usages[pred_id] == false) { - predictor = pool->Retrive(pred_id); - break; - } - ++pred_id; - } - - if (predictor) { - pool_usages[pred_id] = true; - predictor_to_thread_id[predictor] = pred_id; - } else { - LOG(INFO) << "Failed to get predictor from pool !!!"; - } - - return predictor; -} - -int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { - std::lock_guard guard(pool_mutex); - auto iter = predictor_to_thread_id.find(predictor); - - if (iter == predictor_to_thread_id.end()) { - LOG(INFO) << "there is no such predictor"; - return 0; - } - - pool_usages[iter->second] = false; - predictor_to_thread_id.erase(predictor); - return 
0; -} - -shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { - auto iter = cache_names_idx_.find(name); - if (iter == cache_names_idx_.end()) { - return nullptr; - } - assert(iter->second < cache_encouts_.size()); - return cache_encouts_[iter->second]; -} - -void PaddleNnet::FeedForward(const Vector& features, - const int32& feature_dim, - NnetOut* out) { - paddle_infer::Predictor* predictor = GetPredictor(); - - int feat_row = features.Dim() / feature_dim; - - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - - // feed inputs - std::unique_ptr input_tensor = - predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, feat_row, feature_dim}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(features.Data()); - - std::unique_ptr input_len = - predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(feat_row); - input_len->CopyFromCpu(audio_len.data()); - - std::unique_ptr state_h = - predictor->GetInputHandle(input_names[2]); - shared_ptr> h_cache = GetCacheEncoder(input_names[2]); - state_h->Reshape(h_cache->get_shape()); - state_h->CopyFromCpu(h_cache->get_data().data()); - - std::unique_ptr state_c = - predictor->GetInputHandle(input_names[3]); - shared_ptr> c_cache = GetCacheEncoder(input_names[3]); - state_c->Reshape(c_cache->get_shape()); - state_c->CopyFromCpu(c_cache->get_data().data()); - - // forward - bool success = predictor->Run(); - - if (success == false) { - LOG(INFO) << "predictor run occurs error"; - } - - // fetch outpus - std::unique_ptr h_out = - predictor->GetOutputHandle(output_names[2]); - assert(h_cache->get_shape() == h_out->shape()); - h_out->CopyToCpu(h_cache->get_data().data()); - - std::unique_ptr c_out = - predictor->GetOutputHandle(output_names[3]); - assert(c_cache->get_shape() == c_out->shape()); - c_out->CopyToCpu(c_cache->get_data().data()); - - std::unique_ptr output_tensor = - predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - int32 row = output_shape[1]; - int32 col = output_shape[2]; - - - // inferences->Resize(row * col); - // *inference_dim = col; - out->logprobs.Resize(row * col); - out->vocab_dim = col; - output_tensor->CopyToCpu(out->logprobs.Data()); - - ReleasePredictor(predictor); -} - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/nnet/ds2_nnet.h b/speechx/speechx/asr/nnet/ds2_nnet.h deleted file mode 100644 index 420fa1771..000000000 --- a/speechx/speechx/asr/nnet/ds2_nnet.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
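The implementation above hands predictors out of a paddle_infer::services::PredictorPool: a mutex-guarded usage bitmap marks which slot is checked out, and ReleasePredictor() clears the slot again. Before the header that declared PaddleNnet, here is a minimal standalone sketch of the same checkout/return pattern; Worker is a hypothetical stand-in for paddle_infer::Predictor, and all names are invented for illustration.

#include <cstdio>
#include <map>
#include <memory>
#include <mutex>
#include <vector>

// Hypothetical stand-in for paddle_infer::Predictor.
struct Worker {
    int id;
};

class WorkerPool {
  public:
    explicit WorkerPool(int n) : usages_(n, false) {
        for (int i = 0; i < n; ++i) {
            workers_.emplace_back(new Worker{i});
        }
    }

    // Mirrors PaddleNnet::GetPredictor(): scan for a free slot under the lock.
    Worker* Get() {
        std::lock_guard<std::mutex> guard(mutex_);
        for (size_t i = 0; i < usages_.size(); ++i) {
            if (!usages_[i]) {
                usages_[i] = true;
                checked_out_[workers_[i].get()] = i;
                return workers_[i].get();
            }
        }
        return nullptr;  // pool exhausted
    }

    // Mirrors PaddleNnet::ReleasePredictor(): mark the slot free again.
    void Release(Worker* w) {
        std::lock_guard<std::mutex> guard(mutex_);
        auto iter = checked_out_.find(w);
        if (iter == checked_out_.end()) return;  // not checked out here
        usages_[iter->second] = false;
        checked_out_.erase(iter);
    }

  private:
    std::vector<std::unique_ptr<Worker>> workers_;
    std::vector<bool> usages_;
    std::map<Worker*, size_t> checked_out_;
    std::mutex mutex_;
};

int main() {
    WorkerPool pool(2);
    Worker* a = pool.Get();
    Worker* b = pool.Get();
    std::printf("slots: %d %d, third: %p\n", a->id, b->id, (void*)pool.Get());
    pool.Release(b);
    pool.Release(a);
    return 0;
}

The removed header follows; it also defined the Tensor cache that backs the LSTM state handling above.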
-#pragma once
-#include <numeric>
-
-#include "base/common.h"
-#include "kaldi/matrix/kaldi-matrix.h"
-#include "nnet/nnet_itf.h"
-#include "paddle_inference_api.h"
-
-namespace ppspeech {
-
-
-template <typename T>
-class Tensor {
-  public:
-    Tensor() {}
-    explicit Tensor(const std::vector<int>& shape) : _shape(shape) {
-        int neml = std::accumulate(
-            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
-        LOG(INFO) << "Tensor neml: " << neml;
-        _data.resize(neml, 0);
-    }
-
-    void reshape(const std::vector<int>& shape) {
-        _shape = shape;
-        int neml = std::accumulate(
-            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
-        _data.resize(neml, 0);
-    }
-
-    const std::vector<int>& get_shape() const { return _shape; }
-    std::vector<T>& get_data() { return _data; }
-
-  private:
-    std::vector<int> _shape;
-    std::vector<T> _data;
-};
-
-class PaddleNnet : public NnetBase {
-  public:
-    explicit PaddleNnet(const ModelOptions& opts);
-
-    void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
-                     const int32& feature_dim,
-                     NnetOut* out) override;
-
-    void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
-                            float reverse_weight,
-                            std::vector<float>* rescoring_score) override {
-        VLOG(2) << "deepspeech2 not has AttentionRescoring.";
-    }
-
-    void Dim();
-
-    void Reset() override;
-
-    bool IsLogProb() override { return false; }
-
-
-    std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
-        const std::string& name);
-
-    void InitCacheEncouts(const ModelOptions& opts);
-
-    void EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out)
-        const override {}
-
-  private:
-    paddle_infer::Predictor* GetPredictor();
-    int ReleasePredictor(paddle_infer::Predictor* predictor);
-
-    std::unique_ptr<paddle_infer::services::PredictorPool> pool;
-    std::vector<bool> pool_usages;
-    std::mutex pool_mutex;
-    std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
-    std::map<std::string, int> cache_names_idx_;
-    std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;
-
-    ModelOptions opts_;
-
-  public:
-    DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
-};
-
-}  // namespace ppspeech
diff --git a/speechx/speechx/asr/nnet/ds2_nnet_main.cc b/speechx/speechx/asr/nnet/ds2_nnet_main.cc
deleted file mode 100644
index 6092b8a4c..000000000
--- a/speechx/speechx/asr/nnet/ds2_nnet_main.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
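The Tensor template in the header above is just the DS2 state cache: a shape vector plus one flat buffer sized to the product of the dimensions. A self-contained sketch of that container, under the assumption that only shape and data access matter (FlatTensor is an illustrative name, not the removed API):

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Flat buffer whose size is the product of the shape dims, like the
// removed ppspeech::Tensor used for the LSTM h/c caches.
template <typename T>
class FlatTensor {
  public:
    explicit FlatTensor(const std::vector<int>& shape) : shape_(shape) {
        int numel = std::accumulate(
            shape_.begin(), shape_.end(), 1, std::multiplies<int>());
        data_.resize(numel, T{});
    }

    const std::vector<int>& shape() const { return shape_; }
    std::vector<T>& data() { return data_; }

  private:
    std::vector<int> shape_;
    std::vector<T> data_;
};

int main() {
    // e.g. a "5-1-1024" entry from the comma-separated cache_shape flag
    FlatTensor<float> cache({5, 1, 1024});
    assert(cache.data().size() == 5u * 1u * 1024u);
    return 0;
}

The removed ds2_nnet_main.cc test driver follows.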
- -#include "base/common.h" -#include "decoder/param.h" -#include "frontend/audio/assembler.h" -#include "frontend/audio/data_cache.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); -DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); - kaldi::BaseFloatMatrixWriter nnet_writer(FLAGS_nnet_prob_wspecifier); - std::string model_graph = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - LOG(INFO) << "model path: " << model_graph; - LOG(INFO) << "model param: " << model_params; - - int32 num_done = 0, num_err = 0; - - ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); - std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; - kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - int32 frame_idx = 0; - std::vector> prob_vec; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - vector prob; - while (decodable->FrameLikelihood(frame_idx, &prob)) { - kaldi::Vector vec_tmp(prob.size()); - std::memcpy(vec_tmp.Data(), - prob.data(), - sizeof(kaldi::BaseFloat) * prob.size()); - 
prob_vec.push_back(vec_tmp); - frame_idx++; - } - } - decodable->Reset(); - if (prob_vec.size() == 0) { - // the TokenWriter can not write empty string. - ++num_err; - KALDI_LOG << " the nnet prob of " << utt << " is empty"; - continue; - } - kaldi::Matrix result(prob_vec.size(), - prob_vec[0].Dim()); - for (int row_idx = 0; row_idx < prob_vec.size(); ++row_idx) { - for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) { - result(row_idx, col_idx) = prob_vec[row_idx](col_idx); - } - } - - nnet_writer.Write(utt, result); - ++num_done; - } - - double elapsed = timer.Elapsed(); - KALDI_LOG << " cost:" << elapsed << " s"; - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 3a0c4f188..955075913 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -65,7 +65,6 @@ bool NnetProducer::Compute() { size_t nframes = logprobs.Dim() / vocab_dim; VLOG(2) << "Forward out " << nframes << " decoder frames."; std::vector logprob(vocab_dim); - // remove later. for (size_t idx = 0; idx < nframes; ++idx) { for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) { logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx); diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 53e2e58db..6d8db93c1 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -1,46 +1,22 @@ set(srcs) -if (USING_DS2) list(APPEND srcs -recognizer.cc + u2_recognizer.cc ) -endif() - -if (USING_U2) - list(APPEND srcs - u2_recognizer.cc - ) -endif() add_library(recognizer STATIC ${srcs}) target_link_libraries(recognizer PUBLIC decoder) -# test -if (USING_DS2) - set(BINS recognizer_main) - - foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - endforeach() -endif() - - -if (USING_U2) - set(TEST_BINS - u2_recognizer_main - u2_recognizer_thread_main - ) - - foreach(bin_name IN LISTS TEST_BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) - endforeach() - -endif() +set(TEST_BINS + u2_recognizer_main + u2_recognizer_thread_main +) +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} 
${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +endforeach() \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/recognizer.cc b/speechx/speechx/asr/recognizer/recognizer.cc deleted file mode 100644 index c66318131..000000000 --- a/speechx/speechx/asr/recognizer/recognizer.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "recognizer/recognizer.h" - - -namespace ppspeech { - -using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; -using std::unique_ptr; -using std::vector; - - -Recognizer::Recognizer(const RecognizerResource& resource) { - // resource_ = resource; - const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - feature_pipeline_.reset(new FeaturePipeline(feature_opts)); - - std::shared_ptr nnet(new PaddleNnet(resource.model_opts)); - - BaseFloat ac_scale = resource.acoustic_scale; - decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale)); - - decoder_.reset(new TLGDecoder(resource.tlg_opts)); - - input_finished_ = false; -} - -void Recognizer::Accept(const Vector& waves) { - feature_pipeline_->Accept(waves); -} - -void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); } - -std::string Recognizer::GetFinalResult() { - return decoder_->GetFinalBestPath(); -} - -std::string Recognizer::GetPartialResult() { - return decoder_->GetPartialResult(); -} - -void Recognizer::SetFinished() { - feature_pipeline_->SetFinished(); - input_finished_ = true; -} - -bool Recognizer::IsFinished() { return input_finished_; } - -void Recognizer::Reset() { - feature_pipeline_->Reset(); - decodable_->Reset(); - decoder_->Reset(); -} - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/recognizer.h b/speechx/speechx/asr/recognizer/recognizer.h deleted file mode 100644 index 57d5bb363..000000000 --- a/speechx/speechx/asr/recognizer/recognizer.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
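recognizer.h, whose body follows, declared a thin facade over the feature pipeline, the DS2 nnet, and the TLG decoder: construct from a RecognizerResource, then Accept() raw samples, Decode(), and read GetFinalResult(), with SetFinished() and Reset() bracketing each utterance. A toy sketch of that call order; ToyRecognizer and its members are stand-ins invented here, not the removed classes:

#include <iostream>
#include <string>
#include <vector>

// Stand-in for the removed Recognizer facade (FeaturePipeline,
// Decodable, and TLGDecoder are elided).
class ToyRecognizer {
  public:
    void Accept(const std::vector<float>& samples) { buffered_ += samples.size(); }
    void Decode() { /* would advance the decoder over newly buffered frames */ }
    void SetFinished() { finished_ = true; }
    bool IsFinished() const { return finished_; }
    std::string GetFinalResult() const {
        return "saw " + std::to_string(buffered_) + " samples";
    }
    void Reset() { buffered_ = 0; finished_ = false; }

  private:
    size_t buffered_ = 0;
    bool finished_ = false;
};

int main() {
    ToyRecognizer rec;
    std::vector<float> chunk(5760, 0.0f);  // one 0.36 s chunk at 16 kHz
    for (int i = 0; i < 3; ++i) {          // same Accept -> Decode loop as recognizer_main.cc
        rec.Accept(chunk);
        rec.Decode();
    }
    rec.SetFinished();
    std::cout << rec.GetFinalResult() << std::endl;
    rec.Reset();
    return 0;
}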
- -// todo refactor later (SGoat) - -#pragma once - -#include "decoder/ctc_beam_search_decoder.h" -#include "decoder/ctc_tlg_decoder.h" -#include "frontend/audio/feature_pipeline.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DECLARE_double(acoustic_scale); - -namespace ppspeech { - -struct RecognizerResource { - kaldi::BaseFloat acoustic_scale{1.0}; - FeaturePipelineOptions feature_pipeline_opts{}; - ModelOptions model_opts{}; - TLGDecoderOptions tlg_opts{}; - // CTCBeamSearchOptions beam_search_opts; - - static RecognizerResource InitFromFlags() { - RecognizerResource resource; - resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = - FeaturePipelineOptions::InitFromFlags(); - resource.feature_pipeline_opts.assembler_opts.fill_zero = true; - LOG(INFO) << "ds2 need fill zero be true: " - << resource.feature_pipeline_opts.assembler_opts.fill_zero; - resource.model_opts = ModelOptions::InitFromFlags(); - resource.tlg_opts = TLGDecoderOptions::InitFromFlags(); - return resource; - } -}; - -class Recognizer { - public: - explicit Recognizer(const RecognizerResource& resouce); - void Accept(const kaldi::Vector& waves); - void Decode(); - std::string GetFinalResult(); - std::string GetPartialResult(); - void SetFinished(); - bool IsFinished(); - void Reset(); - - private: - // std::shared_ptr resource_; - // RecognizerResource resource_; - std::shared_ptr feature_pipeline_; - std::shared_ptr decodable_; - std::unique_ptr decoder_; - bool input_finished_; -}; - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/recognizer_main.cc b/speechx/speechx/asr/recognizer/recognizer_main.cc deleted file mode 100644 index cb0de2d6a..000000000 --- a/speechx/speechx/asr/recognizer/recognizer_main.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
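recognizer_main.cc, removed below, drove that facade by slicing each wave into fixed chunks of streaming_chunk * sample_rate samples; at the default 0.36 s and 16 kHz that is 5760 samples per chunk, and only a final short chunk triggers SetFinished(). The slicing arithmetic in isolation (the 20000-sample length is made up):

#include <algorithm>
#include <cstdio>

int main() {
    const int sample_rate = 16000;
    const float streaming_chunk = 0.36f;  // default flag value, in seconds
    const int chunk_sample_size =
        static_cast<int>(streaming_chunk * sample_rate);  // 5760 samples
    const int tot_samples = 20000;                        // pretend wav length

    int sample_offset = 0;
    while (sample_offset < tot_samples) {
        int cur = std::min(chunk_sample_size, tot_samples - sample_offset);
        bool last = cur < chunk_sample_size;  // where SetFinished() fires
        std::printf("chunk at %5d: %4d samples%s\n",
                    sample_offset, cur, last ? " (final)" : "");
        sample_offset += cur;  // chunks do not overlap
    }
    return 0;
}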
- -#include "decoder/param.h" -#include "kaldi/feat/wave-reader.h" -#include "kaldi/util/table-types.h" -#include "recognizer/recognizer.h" - -DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); -DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); -DEFINE_int32(sample_rate, 16000, "sample rate"); - - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - ppspeech::RecognizerResource resource = - ppspeech::RecognizerResource::InitFromFlags(); - ppspeech::Recognizer recognizer(resource); - - kaldi::SequentialTableReader wav_reader( - FLAGS_wav_rspecifier); - kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - - int sample_rate = FLAGS_sample_rate; - float streaming_chunk = FLAGS_streaming_chunk; - int chunk_sample_size = streaming_chunk * sample_rate; - LOG(INFO) << "sr: " << sample_rate; - LOG(INFO) << "chunk size (s): " << streaming_chunk; - LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - - int32 num_done = 0, num_err = 0; - double tot_wav_duration = 0.0; - - kaldi::Timer timer; - - for (; !wav_reader.Done(); wav_reader.Next()) { - std::string utt = wav_reader.Key(); - const kaldi::WaveData& wave_data = wav_reader.Value(); - - int32 this_channel = 0; - kaldi::SubVector waveform(wave_data.Data(), - this_channel); - int tot_samples = waveform.Dim(); - tot_wav_duration += tot_samples * 1.0 / sample_rate; - LOG(INFO) << "wav len (sample): " << tot_samples; - - int sample_offset = 0; - std::vector> feats; - int feature_rows = 0; - while (sample_offset < tot_samples) { - int cur_chunk_size = - std::min(chunk_sample_size, tot_samples - sample_offset); - - kaldi::Vector wav_chunk(cur_chunk_size); - for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); - } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); - - recognizer.Accept(wav_chunk); - if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); - } - recognizer.Decode(); - - // no overlap - sample_offset += cur_chunk_size; - } - - std::string result; - result = recognizer.GetFinalResult(); - recognizer.Reset(); - if (result.empty()) { - // the TokenWriter can not write empty string. 
- ++num_err; - KALDI_LOG << " the result of " << utt << " is empty"; - continue; - } - KALDI_LOG << " the result of " << utt << " is " << result; - result_writer.Write(utt, result); - ++num_done; - } - double elapsed = timer.Elapsed(); - KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done); - KALDI_LOG << " cost:" << elapsed << " s"; - KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s"; - KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration; -} diff --git a/speechx/speechx/codelab/CMakeLists.txt b/speechx/speechx/codelab/CMakeLists.txt index 950432637..c8445fb82 100644 --- a/speechx/speechx/codelab/CMakeLists.txt +++ b/speechx/speechx/codelab/CMakeLists.txt @@ -1,4 +1,3 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_subdirectory(glog) -add_subdirectory(nnet) diff --git a/speechx/speechx/codelab/nnet/CMakeLists.txt b/speechx/speechx/codelab/nnet/CMakeLists.txt deleted file mode 100644 index dcad8a9c6..000000000 --- a/speechx/speechx/codelab/nnet/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -set(bin_name ds2_model_test_main) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS}) diff --git a/speechx/speechx/codelab/nnet/ds2_model_test_main.cc b/speechx/speechx/codelab/nnet/ds2_model_test_main.cc deleted file mode 100644 index ab7b2cb58..000000000 --- a/speechx/speechx/codelab/nnet/ds2_model_test_main.cc +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// deepspeech2 online model info - -#include -#include -#include -#include -#include -#include -#include - -#include "base/flags.h" -#include "base/log.h" -#include "paddle_inference_api.h" - -using std::cout; -using std::endl; - - -DEFINE_string(model_path, "", "xxx.pdmodel"); -DEFINE_string(param_path, "", "xxx.pdiparams"); -DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame"); -DEFINE_int32(feat_dim, 161, "feature dim"); - - -void produce_data(std::vector>* data); -void model_forward_test(); - -void produce_data(std::vector>* data) { - int chunk_size = FLAGS_chunk_size; // chunk_size in frame - int col_size = FLAGS_feat_dim; // feat dim - cout << "chunk size: " << chunk_size << endl; - cout << "feat dim: " << col_size << endl; - - data->reserve(chunk_size); - data->back().reserve(col_size); - for (int row = 0; row < chunk_size; ++row) { - data->push_back(std::vector()); - for (int col_idx = 0; col_idx < col_size; ++col_idx) { - data->back().push_back(0.201); - } - } -} - -void model_forward_test() { - std::cout << "1. read the data" << std::endl; - std::vector> feats; - produce_data(&feats); - - std::cout << "2. 
load the model" << std::endl; - ; - std::string model_graph = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - CHECK_NE(model_graph, ""); - CHECK_NE(model_params, ""); - cout << "model path: " << model_graph << endl; - cout << "model param path : " << model_params << endl; - - paddle_infer::Config config; - config.SetModel(model_graph, model_params); - config.SwitchIrOptim(false); - cout << "SwitchIrOptim: " << false << endl; - config.DisableFCPadding(); - cout << "DisableFCPadding: " << endl; - auto predictor = paddle_infer::CreatePredictor(config); - - std::cout << "3. feat shape, row=" << feats.size() - << ",col=" << feats[0].size() << std::endl; - std::vector pp_input_mat; - for (const auto& item : feats) { - pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end()); - } - - std::cout << "4. fead the data to model" << std::endl; - int row = feats.size(); - int col = feats[0].size(); - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - for (auto name : input_names) { - cout << "model input names: " << name << endl; - } - for (auto name : output_names) { - cout << "model output names: " << name << endl; - } - - // input - std::unique_ptr input_tensor = - predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, row, col}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(pp_input_mat.data()); - - // input length - std::unique_ptr input_len = - predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(row); - input_len->CopyFromCpu(audio_len.data()); - - // state_h - std::unique_ptr chunk_state_h_box = - predictor->GetInputHandle(input_names[2]); - std::vector chunk_state_h_box_shape = {5, 1, 1024}; - chunk_state_h_box->Reshape(chunk_state_h_box_shape); - int chunk_state_h_box_size = - std::accumulate(chunk_state_h_box_shape.begin(), - chunk_state_h_box_shape.end(), - 1, - std::multiplies()); - std::vector chunk_state_h_box_data(chunk_state_h_box_size, 0.0f); - chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data()); - - // state_c - std::unique_ptr chunk_state_c_box = - predictor->GetInputHandle(input_names[3]); - std::vector chunk_state_c_box_shape = {5, 1, 1024}; - chunk_state_c_box->Reshape(chunk_state_c_box_shape); - int chunk_state_c_box_size = - std::accumulate(chunk_state_c_box_shape.begin(), - chunk_state_c_box_shape.end(), - 1, - std::multiplies()); - std::vector chunk_state_c_box_data(chunk_state_c_box_size, 0.0f); - chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data()); - - // run - bool success = predictor->Run(); - - // state_h out - std::unique_ptr h_out = - predictor->GetOutputHandle(output_names[2]); - std::vector h_out_shape = h_out->shape(); - int h_out_size = std::accumulate( - h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies()); - std::vector h_out_data(h_out_size); - h_out->CopyToCpu(h_out_data.data()); - - // stage_c out - std::unique_ptr c_out = - predictor->GetOutputHandle(output_names[3]); - std::vector c_out_shape = c_out->shape(); - int c_out_size = std::accumulate( - c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies()); - std::vector c_out_data(c_out_size); - c_out->CopyToCpu(c_out_data.data()); - - // output tensor - std::unique_ptr output_tensor = - predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - std::vector output_probs; - int output_size = 
std::accumulate(
-        output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
-    output_probs.resize(output_size);
-    output_tensor->CopyToCpu(output_probs.data());
-    row = output_shape[1];
-    col = output_shape[2];
-
-    // probs
-    std::vector<std::vector<float>> probs;
-    probs.reserve(row);
-    for (int i = 0; i < row; i++) {
-        probs.push_back(std::vector<float>());
-        probs.back().reserve(col);
-
-        for (int j = 0; j < col; j++) {
-            probs.back().push_back(output_probs[i * col + j]);
-        }
-    }
-
-    std::vector<std::vector<float>> log_feat = probs;
-    std::cout << "probs, row: " << log_feat.size()
-              << " col: " << log_feat[0].size() << std::endl;
-    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
-        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
-             ++col_idx) {
-            std::cout << log_feat[row_idx][col_idx] << " ";
-        }
-        std::cout << std::endl;
-    }
-}
-
-int main(int argc, char* argv[]) {
-    gflags::SetUsageMessage("Usage:");
-    gflags::ParseCommandLineFlags(&argc, &argv, false);
-    google::InitGoogleLogging(argv[0]);
-    google::InstallFailureSignalHandler();
-    FLAGS_logtostderr = 1;
-
-    model_forward_test();
-    return 0;
-}
diff --git a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc
index 713c9ef1e..8c65b3465 100644
--- a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc
+++ b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc
@@ -20,15 +20,12 @@
 #include "kaldi/matrix/kaldi-matrix.h"
 #include "kaldi/util/kaldi-io.h"
 #include "utils/file_utils.h"
-// #include "boost/json.hpp"
-#include <boost/json.hpp>
+#include "utils/picojson.h"
 
 DEFINE_string(json_file, "", "cmvn json file");
 DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
 DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
 
-using namespace boost::json;  // from <boost/json.hpp>
-
 int main(int argc, char* argv[]) {
     gflags::SetUsageMessage("Usage:");
     gflags::ParseCommandLineFlags(&argc, &argv, false);
@@ -40,36 +37,49 @@ int main(int argc, char* argv[]) {
 
     auto ifs = std::ifstream(FLAGS_json_file);
     std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
-    auto value = boost::json::parse(json_str);
-    if (!value.is_object()) {
+    picojson::value value;
+    std::string err;
+    const char* json_end = picojson::parse(
+        value, json_str.c_str(), json_str.c_str() + json_str.size(), &err);
+    if (!value.is<picojson::object>()) {
         LOG(ERROR) << "Input json file format error.";
     }
 
-    for (auto obj : value.as_object()) {
-        if (obj.key() == "mean_stat") {
-            VLOG(2) << "mean_stat:" << obj.value();
+    const picojson::value::object& obj = value.get<picojson::object>();
+    for (picojson::value::object::const_iterator elem = obj.begin();
+         elem != obj.end();
+         ++elem) {
+        if (elem->first == "mean_stat") {
+            VLOG(2) << "mean_stat:" << elem->second;
+            // const picojson::value tmp =
+            //     elem->second.get(0);  //<double>();
+            double tmp =
+                elem->second.get(0).get<double>();  //<double>();
+            VLOG(2) << "tmp: " << tmp;
         }
-        if (obj.key() == "var_stat") {
-            VLOG(2) << "var_stat: " << obj.value();
+        if (elem->first == "var_stat") {
+            VLOG(2) << "var_stat: " << elem->second;
         }
-        if (obj.key() == "frame_num") {
-            VLOG(2) << "frame_num: " << obj.value();
+        if (elem->first == "frame_num") {
+            VLOG(2) << "frame_num: " << elem->second;
         }
     }
 
-    boost::json::array mean_stat = value.at("mean_stat").as_array();
+    const picojson::value::array& mean_stat =
+        value.get("mean_stat").get<picojson::array>();
     std::vector<kaldi::BaseFloat> mean_stat_vec;
     for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
-        mean_stat_vec.push_back(it->as_double());
+        mean_stat_vec.push_back((*it).get<double>());
    }
 
-    boost::json::array var_stat = value.at("var_stat").as_array();
+    const picojson::value::array& var_stat =
+        value.get("var_stat").get<picojson::array>();
     std::vector<kaldi::BaseFloat> var_stat_vec;
     for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
-        var_stat_vec.push_back(it->as_double());
+        var_stat_vec.push_back((*it).get<double>());
     }
 
-    kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64());
+    kaldi::int32 frame_num = value.get("frame_num").get<int64_t>();
     LOG(INFO) << "nframe: " << frame_num;
 
     size_t mean_size = mean_stat_vec.size();
diff --git a/speechx/speechx/common/utils/picojson.h b/speechx/speechx/common/utils/picojson.h
new file mode 100644
index 000000000..28c5b7fa8
--- /dev/null
+++ b/speechx/speechx/common/utils/picojson.h
@@ -0,0 +1,1202 @@
+/*
+ * Copyright 2009-2010 Cybozu Labs, Inc.
+ * Copyright 2011-2014 Kazuho Oku
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */ +#ifndef picojson_h +#define picojson_h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PICOJSON_USE_INT64 1 + +// for isnan/isinf +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#ifdef _MSC_VER +#include +#elif defined(__INTEL_COMPILER) +#include +#else +#include +#endif +} +#endif + +#ifndef PICOJSON_USE_RVALUE_REFERENCE +#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || (defined(_MSC_VER) && _MSC_VER >= 1600) +#define PICOJSON_USE_RVALUE_REFERENCE 1 +#else +#define PICOJSON_USE_RVALUE_REFERENCE 0 +#endif +#endif // PICOJSON_USE_RVALUE_REFERENCE + +#ifndef PICOJSON_NOEXCEPT +#if PICOJSON_USE_RVALUE_REFERENCE +#define PICOJSON_NOEXCEPT noexcept +#else +#define PICOJSON_NOEXCEPT throw() +#endif +#endif + +// experimental support for int64_t (see README.mkdn for detail) +#ifdef PICOJSON_USE_INT64 +#define __STDC_FORMAT_MACROS +#include +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#include +} +#endif +#endif + +// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 +#ifndef PICOJSON_USE_LOCALE +#define PICOJSON_USE_LOCALE 1 +#endif +#if PICOJSON_USE_LOCALE +extern "C" { +#include +} +#endif + +#ifndef PICOJSON_ASSERT +#define PICOJSON_ASSERT(e) \ + do { \ + if (!(e)) \ + throw std::runtime_error(#e); \ + } while (0) +#endif + +#ifdef _MSC_VER +#define SNPRINTF _snprintf_s +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from int to char +#pragma warning(disable : 4127) // conditional expression is constant +#pragma warning(disable : 4702) // unreachable code +#pragma warning(disable : 4706) // assignment within conditional expression +#else +#define SNPRINTF snprintf +#endif + +namespace picojson { + +enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type +#ifdef PICOJSON_USE_INT64 + , + int64_type +#endif +}; + +enum { INDENT_WIDTH = 2, DEFAULT_MAX_DEPTHS = 100 }; + +struct null {}; + +class value { +public: + typedef std::vector array; + typedef std::map object; + union _storage { + bool boolean_; + double number_; +#ifdef PICOJSON_USE_INT64 + int64_t int64_; +#endif + std::string *string_; + array *array_; + object *object_; + }; + +protected: + int type_; + _storage u_; + +public: + value(); + value(int type, bool); + explicit value(bool b); +#ifdef PICOJSON_USE_INT64 + explicit value(int64_t i); +#endif + explicit value(double n); + explicit value(const std::string &s); + explicit value(const array &a); + explicit value(const object &o); +#if PICOJSON_USE_RVALUE_REFERENCE + explicit value(std::string &&s); + explicit value(array &&a); + explicit value(object &&o); +#endif + explicit value(const char *s); + value(const char *s, size_t len); + ~value(); + value(const value &x); + value &operator=(const value &x); +#if PICOJSON_USE_RVALUE_REFERENCE + value(value &&x) PICOJSON_NOEXCEPT; + value &operator=(value &&x) PICOJSON_NOEXCEPT; +#endif + void swap(value &x) PICOJSON_NOEXCEPT; + template bool is() const; + template const T &get() const; + template T &get(); + template void set(const T &); +#if PICOJSON_USE_RVALUE_REFERENCE + template void set(T &&); +#endif + bool evaluate_as_boolean() const; + const value &get(const size_t idx) const; + const value &get(const std::string &key) const; + value &get(const size_t idx); + value &get(const std::string &key); + + bool contains(const size_t idx) const; + bool contains(const std::string &key) const; + std::string 
to_str() const; + template void serialize(Iter os, bool prettify = false) const; + std::string serialize(bool prettify = false) const; + +private: + template value(const T *); // intentionally defined to block implicit conversion of pointer to bool + template static void _indent(Iter os, int indent); + template void _serialize(Iter os, int indent) const; + std::string _serialize(int indent) const; + void clear(); +}; + +typedef value::array array; +typedef value::object object; + +inline value::value() : type_(null_type), u_() { +} + +inline value::value(int type, bool) : type_(type), u_() { + switch (type) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(boolean_, false); + INIT(number_, 0.0); +#ifdef PICOJSON_USE_INT64 + INIT(int64_, 0); +#endif + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: + break; + } +} + +inline value::value(bool b) : type_(boolean_type), u_() { + u_.boolean_ = b; +} + +#ifdef PICOJSON_USE_INT64 +inline value::value(int64_t i) : type_(int64_type), u_() { + u_.int64_ = i; +} +#endif + +inline value::value(double n) : type_(number_type), u_() { + if ( +#ifdef _MSC_VER + !_finite(n) +#elif __cplusplus >= 201103L + std::isnan(n) || std::isinf(n) +#else + isnan(n) || isinf(n) +#endif + ) { + throw std::overflow_error(""); + } + u_.number_ = n; +} + +inline value::value(const std::string &s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const array &a) : type_(array_type), u_() { + u_.array_ = new array(a); +} + +inline value::value(const object &o) : type_(object_type), u_() { + u_.object_ = new object(o); +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(std::string &&s) : type_(string_type), u_() { + u_.string_ = new std::string(std::move(s)); +} + +inline value::value(array &&a) : type_(array_type), u_() { + u_.array_ = new array(std::move(a)); +} + +inline value::value(object &&o) : type_(object_type), u_() { + u_.object_ = new object(std::move(o)); +} +#endif + +inline value::value(const char *s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const char *s, size_t len) : type_(string_type), u_() { + u_.string_ = new std::string(s, len); +} + +inline void value::clear() { + switch (type_) { +#define DEINIT(p) \ + case p##type: \ + delete u_.p; \ + break + DEINIT(string_); + DEINIT(array_); + DEINIT(object_); +#undef DEINIT + default: + break; + } +} + +inline value::~value() { + clear(); +} + +inline value::value(const value &x) : type_(x.type_), u_() { + switch (type_) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(string_, new std::string(*x.u_.string_)); + INIT(array_, new array(*x.u_.array_)); + INIT(object_, new object(*x.u_.object_)); +#undef INIT + default: + u_ = x.u_; + break; + } +} + +inline value &value::operator=(const value &x) { + if (this != &x) { + value t(x); + swap(t); + } + return *this; +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() { + swap(x); +} +inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT { + swap(x); + return *this; +} +#endif +inline void value::swap(value &x) PICOJSON_NOEXCEPT { + std::swap(type_, x.type_); + std::swap(u_, x.u_); +} + +#define IS(ctype, jtype) \ + template <> inline bool value::is() const { \ + return type_ == jtype##_type; \ + } +IS(null, null) +IS(bool, boolean) +#ifdef PICOJSON_USE_INT64 +IS(int64_t, int64) +#endif 
+IS(std::string, string) +IS(array, array) +IS(object, object) +#undef IS +template <> inline bool value::is() const { + return type_ == number_type +#ifdef PICOJSON_USE_INT64 + || type_ == int64_type +#endif + ; +} + +#define GET(ctype, var) \ + template <> inline const ctype &value::get() const { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && is()); \ + return var; \ + } \ + template <> inline ctype &value::get() { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && is()); \ + return var; \ + } +GET(bool, u_.boolean_) +GET(std::string, *u_.string_) +GET(array, *u_.array_) +GET(object, *u_.object_) +#ifdef PICOJSON_USE_INT64 +GET(double, + (type_ == int64_type && (const_cast(this)->type_ = number_type, (const_cast(this)->u_.number_ = u_.int64_)), + u_.number_)) +GET(int64_t, u_.int64_) +#else +GET(double, u_.number_) +#endif +#undef GET + +#define SET(ctype, jtype, setter) \ + template <> inline void value::set(const ctype &_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +SET(bool, boolean, u_.boolean_ = _val;) +SET(std::string, string, u_.string_ = new std::string(_val);) +SET(array, array, u_.array_ = new array(_val);) +SET(object, object, u_.object_ = new object(_val);) +SET(double, number, u_.number_ = _val;) +#ifdef PICOJSON_USE_INT64 +SET(int64_t, int64, u_.int64_ = _val;) +#endif +#undef SET + +#if PICOJSON_USE_RVALUE_REFERENCE +#define MOVESET(ctype, jtype, setter) \ + template <> inline void value::set(ctype && _val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) +MOVESET(array, array, u_.array_ = new array(std::move(_val));) +MOVESET(object, object, u_.object_ = new object(std::move(_val));) +#undef MOVESET +#endif + +inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return u_.boolean_; + case number_type: + return u_.number_ != 0; +#ifdef PICOJSON_USE_INT64 + case int64_type: + return u_.int64_ != 0; +#endif + case string_type: + return !u_.string_->empty(); + default: + return true; + } +} + +inline const value &value::get(const size_t idx) const { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline value &value::get(const size_t idx) { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline const value &value::get(const std::string &key) const { + static value s_null; + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline value &value::get(const std::string &key) { + static value s_null; + PICOJSON_ASSERT(is()); + object::iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline bool value::contains(const size_t idx) const { + PICOJSON_ASSERT(is()); + return idx < u_.array_->size(); +} + +inline bool value::contains(const std::string &key) const { + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end(); +} + +inline std::string value::to_str() const { + switch (type_) { + case null_type: + return "null"; + case boolean_type: + return u_.boolean_ ? 
"true" : "false"; +#ifdef PICOJSON_USE_INT64 + case int64_type: { + char buf[sizeof("-9223372036854775808")]; + SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); + return buf; + } +#endif + case number_type: { + char buf[256]; + double tmp; + SNPRINTF(buf, sizeof(buf), fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 ? "%.f" : "%.17g", u_.number_); +#if PICOJSON_USE_LOCALE + char *decimal_point = localeconv()->decimal_point; + if (strcmp(decimal_point, ".") != 0) { + size_t decimal_point_len = strlen(decimal_point); + for (char *p = buf; *p != '\0'; ++p) { + if (strncmp(p, decimal_point, decimal_point_len) == 0) { + return std::string(buf, p) + "." + (p + decimal_point_len); + } + } + } +#endif + return buf; + } + case string_type: + return *u_.string_; + case array_type: + return "array"; + case object_type: + return "object"; + default: + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + return std::string(); +} + +template void copy(const std::string &s, Iter oi) { + std::copy(s.begin(), s.end(), oi); +} + +template struct serialize_str_char { + Iter oi; + void operator()(char c) { + switch (c) { +#define MAP(val, sym) \ + case val: \ + copy(sym, oi); \ + break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if (static_cast(c) < 0x20 || c == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = c; + } + break; + } + } +}; + +template void serialize_str(const std::string &s, Iter oi) { + *oi++ = '"'; + serialize_str_char process_char = {oi}; + std::for_each(s.begin(), s.end(), process_char); + *oi++ = '"'; +} + +template void value::serialize(Iter oi, bool prettify) const { + return _serialize(oi, prettify ? 0 : -1); +} + +inline std::string value::serialize(bool prettify) const { + return _serialize(prettify ? 
0 : -1); +} + +template void value::_indent(Iter oi, int indent) { + *oi++ = '\n'; + for (int i = 0; i < indent * INDENT_WIDTH; ++i) { + *oi++ = ' '; + } +} + +template void value::_serialize(Iter oi, int indent) const { + switch (type_) { + case string_type: + serialize_str(*u_.string_, oi); + break; + case array_type: { + *oi++ = '['; + if (indent != -1) { + ++indent; + } + for (array::const_iterator i = u_.array_->begin(); i != u_.array_->end(); ++i) { + if (i != u_.array_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + i->_serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.array_->empty()) { + _indent(oi, indent); + } + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + if (indent != -1) { + ++indent; + } + for (object::const_iterator i = u_.object_->begin(); i != u_.object_->end(); ++i) { + if (i != u_.object_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + serialize_str(i->first, oi); + *oi++ = ':'; + if (indent != -1) { + *oi++ = ' '; + } + i->second._serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.object_->empty()) { + _indent(oi, indent); + } + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + if (indent == 0) { + *oi++ = '\n'; + } +} + +inline std::string value::_serialize(int indent) const { + std::string s; + _serialize(std::back_inserter(s), indent); + return s; +} + +template class input { +protected: + Iter cur_, end_; + bool consumed_; + int line_; + +public: + input(const Iter &first, const Iter &last) : cur_(first), end_(last), consumed_(false), line_(1) { + } + int getc() { + if (consumed_) { + if (*cur_ == '\n') { + ++line_; + } + ++cur_; + } + if (cur_ == end_) { + consumed_ = false; + return -1; + } + consumed_ = true; + return *cur_ & 0xff; + } + void ungetc() { + consumed_ = false; + } + Iter cur() const { + if (consumed_) { + input *self = const_cast *>(this); + self->consumed_ = false; + ++self->cur_; + } + return cur_; + } + int line() const { + return line_; + } + void skip_ws() { + while (1) { + int ch = getc(); + if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + bool expect(const int expected) { + skip_ws(); + if (getc() != expected) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string &pattern) { + for (std::string::const_iterator pi(pattern.begin()); pi != pattern.end(); ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } +}; + +template inline int _parse_quadhex(input &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; +} + +template inline bool _parse_codepoint(String &out, input &in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (!(0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 0xd800) << 10) 
| ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(static_cast(uni_ch)); + } else { + if (uni_ch < 0x800) { + out.push_back(static_cast(0xc0 | (uni_ch >> 6))); + } else { + if (uni_ch < 0x10000) { + out.push_back(static_cast(0xe0 | (uni_ch >> 12))); + } else { + out.push_back(static_cast(0xf0 | (uni_ch >> 18))); + out.push_back(static_cast(0x80 | ((uni_ch >> 12) & 0x3f))); + } + out.push_back(static_cast(0x80 | ((uni_ch >> 6) & 0x3f))); + } + out.push_back(static_cast(0x80 | (uni_ch & 0x3f))); + } + return true; +} + +template inline bool _parse_string(String &out, input &in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) \ + case sym: \ + out.push_back(val); \ + break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (!_parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(static_cast(ch)); + } + } + return false; +} + +template inline bool _parse_array(Context &ctx, input &in) { + if (!ctx.parse_array_start()) { + return false; + } + size_t idx = 0; + if (in.expect(']')) { + return ctx.parse_array_stop(idx); + } + do { + if (!ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']') && ctx.parse_array_stop(idx); +} + +template inline bool _parse_object(Context &ctx, input &in) { + if (!ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return ctx.parse_object_stop(); + } + do { + std::string key; + if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) { + return false; + } + if (!ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}') && ctx.parse_object_stop(); +} + +template inline std::string _parse_number(input &in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || ch == 'E') { + num_str.push_back(static_cast(ch)); + } else if (ch == '.') { +#if PICOJSON_USE_LOCALE + num_str += localeconv()->decimal_point; +#else + num_str.push_back('.'); +#endif + } else { + in.ungetc(); + break; + } + } + return num_str; +} + +template inline bool _parse(Context &ctx, input &in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) \ + case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + double f; + char *endp; + in.ungetc(); + std::string num_str(_parse_number(in)); + if (num_str.empty()) { + return false; + } +#ifdef PICOJSON_USE_INT64 + { + errno = 0; + intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); + if (errno == 0 && std::numeric_limits::min() <= ival && ival <= std::numeric_limits::max() && + endp == num_str.c_str() + num_str.size()) { + ctx.set_int64(ival); + return true; + } + } +#endif + f = strtod(num_str.c_str(), &endp); + if (endp == num_str.c_str() + 
num_str.size()) { + ctx.set_number(f); + return true; + } + return false; + } + break; + } + in.ungetc(); + return false; +} + +class deny_parse_context { +public: + bool set_null() { + return false; + } + bool set_bool(bool) { + return false; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { + return false; + } +#endif + bool set_number(double) { + return false; + } + template bool parse_string(input &) { + return false; + } + bool parse_array_start() { + return false; + } + template bool parse_array_item(input &, size_t) { + return false; + } + bool parse_array_stop(size_t) { + return false; + } + bool parse_object_start() { + return false; + } + template bool parse_object_item(input &, const std::string &) { + return false; + } +}; + +class default_parse_context { +protected: + value *out_; + size_t depths_; + +public: + default_parse_context(value *out, size_t depths = DEFAULT_MAX_DEPTHS) : out_(out), depths_(depths) { + } + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t i) { + *out_ = value(i); + return true; + } +#endif + bool set_number(double f) { + *out_ = value(f); + return true; + } + template bool parse_string(input &in) { + *out_ = value(string_type, false); + return _parse_string(out_->get(), in); + } + bool parse_array_start() { + if (depths_ == 0) + return false; + --depths_; + *out_ = value(array_type, false); + return true; + } + template bool parse_array_item(input &in, size_t) { + array &a = out_->get(); + a.push_back(value()); + default_parse_context ctx(&a.back(), depths_); + return _parse(ctx, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) + return false; + *out_ = value(object_type, false); + return true; + } + template bool parse_object_item(input &in, const std::string &key) { + object &o = out_->get(); + default_parse_context ctx(&o[key], depths_); + return _parse(ctx, in); + } + bool parse_object_stop() { + ++depths_; + return true; + } + +private: + default_parse_context(const default_parse_context &); + default_parse_context &operator=(const default_parse_context &); +}; + +class null_parse_context { +protected: + size_t depths_; + +public: + struct dummy_str { + void push_back(int) { + } + }; + +public: + null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) { + } + bool set_null() { + return true; + } + bool set_bool(bool) { + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { + return true; + } +#endif + bool set_number(double) { + return true; + } + template bool parse_string(input &in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { + if (depths_ == 0) + return false; + --depths_; + return true; + } + template bool parse_array_item(input &in, size_t) { + return _parse(*this, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) + return false; + --depths_; + return true; + } + template bool parse_object_item(input &in, const std::string &) { + ++depths_; + return _parse(*this, in); + } + bool parse_object_stop() { + return true; + } + +private: + null_parse_context(const null_parse_context &); + null_parse_context &operator=(const null_parse_context &); +}; + +// obsolete, use the version below +template inline std::string parse(value &out, Iter &pos, const Iter &last) { + std::string err; + pos = 
parse(out, pos, last, &err); + return err; +} + +template inline Iter _parse(Context &ctx, const Iter &first, const Iter &last, std::string *err) { + input in(first, last); + if (!_parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(static_cast(ch)); + } + } + } + return in.cur(); +} + +template inline Iter parse(value &out, const Iter &first, const Iter &last, std::string *err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); +} + +inline std::string parse(value &out, const std::string &s) { + std::string err; + parse(out, s.begin(), s.end(), &err); + return err; +} + +inline std::string parse(value &out, std::istream &is) { + std::string err; + parse(out, std::istreambuf_iterator(is.rdbuf()), std::istreambuf_iterator(), &err); + return err; +} + +template struct last_error_t { static std::string s; }; +template std::string last_error_t::s; + +inline void set_last_error(const std::string &s) { + last_error_t::s = s; +} + +inline const std::string &get_last_error() { + return last_error_t::s; +} + +inline bool operator==(const value &x, const value &y) { + if (x.is()) + return y.is(); +#define PICOJSON_CMP(type) \ + if (x.is()) \ + return y.is() && x.get() == y.get() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; +} + +inline bool operator!=(const value &x, const value &y) { + return !(x == y); +} +} + +#if !PICOJSON_USE_RVALUE_REFERENCE +namespace std { +template <> inline void swap(picojson::value &x, picojson::value &y) { + x.swap(y); +} +} +#endif + +inline std::istream &operator>>(std::istream &is, picojson::value &x) { + picojson::set_last_error(std::string()); + const std::string err(picojson::parse(x, is)); + if (!err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif \ No newline at end of file
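
Since cmvn_json2kaldi_main.cc is the first consumer of this vendored header, a minimal standalone usage sketch may help; the JSON literal is invented, while the calls (picojson::parse, is<T>(), get<T>(), serialize, and the stream operators) are the ones defined above:

#include <iostream>
#include <sstream>
#include <string>

#include "utils/picojson.h"

int main() {
    // Shape mirrors a cmvn stats file: arrays of numbers plus a frame count.
    const std::string json_str =
        "{\"mean_stat\": [1.0, 2.0], \"var_stat\": [0.5, 0.25], "
        "\"frame_num\": 10}";

    picojson::value value;
    std::string err = picojson::parse(value, json_str);  // empty err == success
    if (!err.empty() || !value.is<picojson::object>()) {
        std::cerr << "parse error: " << err << std::endl;
        return 1;
    }

    // Arrays and numbers come back through get<T>() after an is<T>() check.
    const picojson::array& mean_stat =
        value.get("mean_stat").get<picojson::array>();
    for (const picojson::value& v : mean_stat) {
        std::cout << "mean_stat: " << v.get<double>() << std::endl;
    }

    // Integral literals are held as int64_t because PICOJSON_USE_INT64 is
    // defined above, which is why frame_num reads back via get<int64_t>().
    std::cout << "frame_num: " << value.get("frame_num").get<int64_t>()
              << std::endl;

    // The operator>> / operator<< overloads allow stream-based round trips.
    std::istringstream in(value.serialize());
    picojson::value again;
    in >> again;
    std::cout << again << std::endl;
    return 0;
}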