commit bbf2401e3e
@ -0,0 +1,29 @@
|
||||
# This file is used by clang-format to autoformat paddle source code
|
||||
#
|
||||
# clang-format is part of the LLVM toolchain.
|
||||
# You need to install LLVM and Clang to format the source code.
|
||||
#
|
||||
# The basic usage is,
|
||||
# clang-format -i -style=file PATH/TO/SOURCE/CODE
|
||||
#
|
||||
# The -style=file option implicitly uses the ".clang-format" file located in one of the
|
||||
# parent directories.
|
||||
# The -i flag means in-place change.
|
||||
#
|
||||
# The documentation of clang-format is at:
|
||||
# http://clang.llvm.org/docs/ClangFormat.html
|
||||
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
|
||||
---
|
||||
Language: Cpp
|
||||
BasedOnStyle: Google
|
||||
IndentWidth: 4
|
||||
TabWidth: 4
|
||||
ContinuationIndentWidth: 4
|
||||
MaxEmptyLinesToKeep: 2
|
||||
AccessModifierOffset: -2 # private/protected/public access specifiers are not indented inside classes
|
||||
Standard: Cpp11
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
BinPackParameters: false
|
||||
BinPackArguments: false
|
||||
...
|
||||
|
@ -1 +1,2 @@
|
||||
tools/valgrind*
|
||||
*log
|
||||
|
@ -1,12 +0,0 @@
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
gflags
|
||||
URL https://github.com/gflags/gflags/archive/v2.2.1.zip
|
||||
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(gflags)
|
||||
|
||||
# openfst need
|
||||
include_directories(${gflags_BINARY_DIR}/include)
|
@ -0,0 +1,11 @@
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
gflags
|
||||
URL https://github.com/gflags/gflags/archive/v2.2.2.zip
|
||||
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
|
||||
)
|
||||
FetchContent_MakeAvailable(gflags)
|
||||
|
||||
# needed by openfst
|
||||
include_directories(${gflags_BINARY_DIR}/include)
|
@ -1,8 +1,8 @@
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
gtest
|
||||
URL https://github.com/google/googletest/archive/release-1.10.0.zip
|
||||
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
|
||||
URL https://github.com/google/googletest/archive/release-1.11.0.zip
|
||||
URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
|
||||
)
|
||||
FetchContent_MakeAvailable(gtest)
|
||||
|
@ -0,0 +1,49 @@
|
||||
set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
|
||||
set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
paddle
|
||||
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
|
||||
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
|
||||
PREFIX ${paddle_PREFIX_DIR}
|
||||
SOURCE_DIR ${paddle_SOURCE_DIR}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND ""
|
||||
)
|
||||
FetchContent_MakeAvailable(paddle)
|
||||
|
||||
set(PADDLE_LIB_THIRD_PARTY_PATH "${paddle_SOURCE_DIR}/third_party/install/")
|
||||
|
||||
include_directories("${paddle_SOURCE_DIR}/paddle/include")
|
||||
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
|
||||
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
|
||||
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
|
||||
|
||||
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
|
||||
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
|
||||
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
|
||||
link_directories("${paddle_SOURCE_DIR}/paddle/lib")
|
||||
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
|
||||
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn/lib")
|
||||
|
||||
## paddle with mkl
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
|
||||
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
|
||||
include_directories("${MATH_LIB_PATH}/include")
|
||||
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
|
||||
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
|
||||
|
||||
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
|
||||
include_directories("${MKLDNN_PATH}/include")
|
||||
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
|
||||
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
|
||||
|
||||
# global vars
|
||||
set(DEPS ${paddle_SOURCE_DIR}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX} CACHE INTERNAL "deps")
|
||||
set(DEPS ${DEPS}
|
||||
${MATH_LIB} ${MKLDNN_LIB}
|
||||
glog gflags protobuf xxhash cryptopp
|
||||
${EXTERNAL_LIB} CACHE INTERNAL "deps")
|
||||
message(STATUS "Deps libraries: ${DEPS}")
|
@ -1,8 +1,9 @@
|
||||
# Codelab
|
||||
|
||||
## Introduction
|
||||
> The below is for developing and offline testing.
|
||||
> Do not run it only if you know what it is.
|
||||
|
||||
> The content below is for development and offline testing. Do not run it unless you know what you are doing.
|
||||
* nnet
|
||||
* feat
|
||||
* decoder
|
||||
* u2
|
||||
|
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1 @@
|
||||
data
|
@ -0,0 +1 @@
|
||||
# u2/u2pp Streaming Test
|
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
mkdir -p $exp
|
||||
ckpt_dir=$data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
ctc_prefix_beam_search_decoder_main \
|
||||
--model_path=$model_dir/export.jit \
|
||||
--nnet_decoder_chunk=16 \
|
||||
--receptive_field_length=7 \
|
||||
--subsampling_rate=4 \
|
||||
--vocab_path=$model_dir/unit.txt \
|
||||
--feature_rspecifier=ark,t:$exp/fbank.ark \
|
||||
--result_wspecifier=ark,t:$exp/result.ark
|
||||
|
||||
echo "u2 ctc prefix beam search decode."
|
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
|
||||
cmvn_json2kaldi_main \
|
||||
--json_file $model_dir/mean_std.json \
|
||||
--cmvn_write_path $exp/cmvn.ark \
|
||||
--binary=false
|
||||
|
||||
echo "convert json cmvn to kaldi ark."
|
||||
|
||||
compute_fbank_main \
|
||||
--num_bins 80 \
|
||||
--wav_rspecifier=scp:$data/wav.scp \
|
||||
--cmvn_file=$exp/cmvn.ark \
|
||||
--feature_wspecifier=ark,t:$exp/fbank.ark
|
||||
|
||||
echo "compute fbank feature."
|
@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
u2_nnet_main \
|
||||
--model_path=$model_dir/export.jit \
|
||||
--feature_rspecifier=ark,t:$exp/fbank.ark \
|
||||
--nnet_decoder_chunk=16 \
|
||||
--receptive_field_length=7 \
|
||||
--subsampling_rate=4 \
|
||||
--acoustic_scale=1.0 \
|
||||
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
|
||||
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
|
||||
echo "u2 nnet decode."
|
||||
|
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
u2_recognizer_main \
|
||||
--use_fbank=true \
|
||||
--num_bins=80 \
|
||||
--cmvn_file=$exp/cmvn.ark \
|
||||
--model_path=$model_dir/export.jit \
|
||||
--nnet_decoder_chunk=16 \
|
||||
--receptive_field_length=7 \
|
||||
--subsampling_rate=4 \
|
||||
--vocab_path=$model_dir/unit.txt \
|
||||
--wav_rspecifier=scp:$data/wav.scp \
|
||||
--result_wspecifier=ark,t:$exp/result.ark
|
@ -0,0 +1,18 @@
|
||||
# This file contains the locations of the binaries built and required for running the examples.
|
||||
|
||||
unset GREP_OPTIONS
|
||||
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project has been built successfully"; }
|
||||
|
||||
export LC_ALL=C
|
||||
|
||||
export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
|
||||
|
||||
PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
|
||||
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH
|
@ -0,0 +1,43 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -f data/model/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
|
||||
mkdir -p data/model
|
||||
pushd data/model
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
|
||||
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
# produce wav scp
|
||||
if [ ! -f data/wav.scp ]; then
|
||||
mkdir -p data
|
||||
pushd data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||
echo "utt1 " $PWD/zh.wav > wav.scp
|
||||
popd
|
||||
fi
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
|
||||
./local/feat.sh
|
||||
|
||||
./local/nnet.sh
|
||||
|
||||
./local/decode.sh
|
@ -0,0 +1,5 @@
|
||||
# U2/U2++ Streaming ASR
|
||||
|
||||
## Examples
|
||||
|
||||
* `wenetspeech` - streaming decoding with the wenetspeech u2/u2++ model, using the aishell test data for testing.
|
@ -0,0 +1,3 @@
|
||||
data
|
||||
utils
|
||||
exp
|
@ -0,0 +1,28 @@
|
||||
# u2/u2pp Streaming ASR
|
||||
|
||||
## Testing with Aishell Test Data
|
||||
|
||||
### Download wav and model
|
||||
|
||||
```
|
||||
./run.sh --stop_stage 0
|
||||
```
|
||||
|
||||
### compute feature
|
||||
|
||||
```
|
||||
./run.sh --stage 1 --stop_stage 1
|
||||
```
|
||||
|
||||
### decoding using feature
|
||||
|
||||
```
|
||||
./run.sh --stage 2 --stop_stage 2
|
||||
```
|
||||
|
||||
### decoding using wav
|
||||
|
||||
|
||||
```
|
||||
./run.sh --stage 3 --stop_stage 3
|
||||
```
|
@ -0,0 +1,71 @@
|
||||
#!/bin/bash
|
||||
|
||||
# To be run from one directory above this script.
|
||||
. ./path.sh
|
||||
|
||||
nj=40
|
||||
text=data/local/lm/text
|
||||
lexicon=data/local/dict/lexicon.txt
|
||||
|
||||
for f in "$text" "$lexicon"; do
|
||||
[ ! -f $f ] && echo "$0: No such file $f" && exit 1;
|
||||
done
|
||||
|
||||
# Check SRILM tools
|
||||
if ! which ngram-count > /dev/null; then
|
||||
echo "srilm tools are not found, please download it and install it from: "
|
||||
echo "http://www.speech.sri.com/projects/srilm/download.html"
|
||||
echo "Then add the tools to your PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# This script takes no arguments. It assumes you have already run
|
||||
# aishell_data_prep.sh.
|
||||
# It takes as input the files
|
||||
# data/local/lm/text
|
||||
# data/local/dict/lexicon.txt
|
||||
dir=data/local/lm
|
||||
mkdir -p $dir
|
||||
|
||||
cleantext=$dir/text.no_oov
|
||||
|
||||
# oov to <SPOKEN_NOISE>
|
||||
# lexicon line: word char0 ... charn
|
||||
# text line: utt word0 ... wordn -> line: <SPOKEN_NOISE> word0 ... wordn
|
||||
text_dir=$(dirname $text)
|
||||
split_name=$(basename $text)
|
||||
./local/split_data.sh $text_dir $text $split_name $nj
|
||||
|
||||
utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
|
||||
cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
|
||||
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
|
||||
\> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
|
||||
cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext
|
||||
|
||||
# compute word counts, sort in descending order
|
||||
# line: count word
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
|
||||
sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;
|
||||
|
||||
# Get counts from acoustic training transcripts, and add one-count
|
||||
# for each word in the lexicon (but not silence, we don't want it
|
||||
# in the LM-- we'll add it optionally later).
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
|
||||
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
|
||||
sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;
|
||||
|
||||
# word with <s> </s>
|
||||
cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
|
||||
|
||||
# hold out to compute ppl
|
||||
heldout_sent=10000 # Don't change this if you want results to be comparable with kaldi_lm results
|
||||
|
||||
mkdir -p $dir
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
|
||||
head -$heldout_sent > $dir/heldout
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
|
||||
tail -n +$heldout_sent > $dir/train
|
||||
|
||||
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
|
||||
-map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
|
||||
ngram -lm $dir/lm.arpa -ppl $dir/heldout
|
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
nj=20
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.fbank.wolm.log \
|
||||
ctc_prefix_beam_search_decoder_main \
|
||||
--model_path=$model_dir/export.jit \
|
||||
--vocab_path=$model_dir/unit.txt \
|
||||
--nnet_decoder_chunk=16 \
|
||||
--receptive_field_length=7 \
|
||||
--subsampling_rate=4 \
|
||||
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank.scp \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_decode.ark
|
||||
|
||||
cat $data/split${nj}/*/result_decode.ark > $exp/${label_file}
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
|
||||
tail -n 7 $exp/${wer}
|
@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
nj=20
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
aishell_wav_scp=aishell_test.scp
|
||||
|
||||
cmvn_json2kaldi_main \
|
||||
--json_file $model_dir/mean_std.json \
|
||||
--cmvn_write_path $exp/cmvn.ark \
|
||||
--binary=false
|
||||
|
||||
echo "convert json cmvn to kaldi ark."
|
||||
|
||||
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
|
||||
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
|
||||
compute_fbank_main \
|
||||
--num_bins 80 \
|
||||
--cmvn_file=$exp/cmvn.ark \
|
||||
--streaming_chunk=36 \
|
||||
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
|
||||
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
|
||||
|
||||
echo "compute fbank feature."
|
@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
u2_nnet_main \
|
||||
--model_path=$model_dir/export.jit \
|
||||
--feature_rspecifier=ark,t:$exp/fbank.ark \
|
||||
--nnet_decoder_chunk=16 \
|
||||
--receptive_field_length=7 \
|
||||
--subsampling_rate=4 \
|
||||
--acoustic_scale=1.0 \
|
||||
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
|
||||
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
|
||||
echo "u2 nnet decode."
|
||||
|
@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
data=data
|
||||
exp=exp
|
||||
nj=20
|
||||
|
||||
|
||||
mkdir -p $exp
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
aishell_wav_scp=aishell_test.scp
|
||||
text=$data/test/text
|
||||
|
||||
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
|
||||
|
||||
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
|
||||
u2_recognizer_main \
|
||||
--use_fbank=true \
|
||||
--num_bins=80 \
|
||||
--cmvn_file=$exp/cmvn.ark \
|
||||
--model_path=$model_dir/export.jit \
|
||||
--vocab_path=$model_dir/unit.txt \
|
||||
--nnet_decoder_chunk=16 \
|
||||
--receptive_field_length=7 \
|
||||
--subsampling_rate=4 \
|
||||
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
|
||||
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer.ark
|
||||
|
||||
|
||||
cat $data/split${nj}/*/result_recognizer.ark > $exp/aishell_recognizer
|
||||
utils/compute-wer.py --char=1 --v=1 $text $exp/aishell_recognizer > $exp/aishell.recognizer.err
|
||||
echo "recognizer test have finished!!!"
|
||||
echo "please checkout in $exp/aishell.recognizer.err"
|
||||
tail -n 7 $exp/aishell.recognizer.err
|
@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eo pipefail
|
||||
|
||||
data=$1
|
||||
scp=$2
|
||||
split_name=$3
|
||||
numsplit=$4
|
||||
|
||||
# save in $data/split{n}
|
||||
# $scp to split
|
||||
#
|
||||
|
||||
if [[ ! $numsplit -gt 0 ]]; then
|
||||
echo "$0: Invalid num-split argument";
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
|
||||
scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
|
||||
|
||||
# if this mkdir fails due to argument-list being too long, iterate.
|
||||
if ! mkdir -p $directories >&/dev/null; then
|
||||
for n in `seq $numsplit`; do
|
||||
mkdir -p $data/split${numsplit}/$n
|
||||
done
|
||||
fi
|
||||
|
||||
echo "utils/split_scp.pl $scp $scp_splits"
|
||||
utils/split_scp.pl $scp $scp_splits
|
@ -0,0 +1,18 @@
|
||||
# This file contains the locations of the binaries built and required for running the examples.
|
||||
|
||||
unset GREP_OPTIONS
|
||||
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project has been built successfully"; }
|
||||
|
||||
export LC_ALL=C
|
||||
|
||||
export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
|
||||
|
||||
PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
|
||||
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH
|
@ -0,0 +1,76 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
nj=40
|
||||
stage=0
|
||||
stop_stage=5
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
# input
|
||||
data=data
|
||||
exp=exp
|
||||
mkdir -p $exp $data
|
||||
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_BUILD} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
|
||||
ckpt_dir=$data/model
|
||||
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
|
||||
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
|
||||
# download model
|
||||
if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
|
||||
mkdir -p $ckpt_dir
|
||||
pushd $ckpt_dir
|
||||
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
|
||||
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
|
||||
|
||||
popd
|
||||
fi
|
||||
|
||||
# test wav scp
|
||||
if [ ! -f data/wav.scp ]; then
|
||||
mkdir -p $data
|
||||
pushd $data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||
echo "utt1 " $PWD/zh.wav > wav.scp
|
||||
popd
|
||||
fi
|
||||
|
||||
# aishell wav scp
|
||||
if [ ! -d $data/test ]; then
|
||||
pushd $data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
|
||||
unzip aishell_test.zip
|
||||
popd
|
||||
|
||||
realpath $data/test/*/*.wav > $data/wavlist
|
||||
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
|
||||
paste $data/utt_id $data/wavlist > $data/aishell_test.scp
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
./local/feat.sh
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
./local/decode.sh
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
./local/recognizer.sh
|
||||
fi
|
@ -1,25 +1,55 @@
|
||||
project(decoder)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
|
||||
add_library(decoder STATIC
|
||||
ctc_beam_search_decoder.cc
|
||||
|
||||
set(srcs)
|
||||
|
||||
if (USING_DS2)
|
||||
list(APPEND srcs
|
||||
ctc_decoders/decoder_utils.cpp
|
||||
ctc_decoders/path_trie.cpp
|
||||
ctc_decoders/scorer.cpp
|
||||
ctc_beam_search_decoder.cc
|
||||
ctc_tlg_decoder.cc
|
||||
recognizer.cc
|
||||
)
|
||||
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)
|
||||
endif()
|
||||
|
||||
set(BINS
|
||||
ctc_prefix_beam_search_decoder_main
|
||||
if (USING_U2)
|
||||
list(APPEND srcs
|
||||
ctc_prefix_beam_search_decoder.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
add_library(decoder STATIC ${srcs})
|
||||
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
|
||||
|
||||
# test
|
||||
if (USING_DS2)
|
||||
set(BINS
|
||||
ctc_beam_search_decoder_main
|
||||
nnet_logprob_decoder_main
|
||||
recognizer_main
|
||||
tlg_decoder_main
|
||||
)
|
||||
ctc_tlg_decoder_main
|
||||
)
|
||||
|
||||
foreach(bin_name IN LISTS BINS)
|
||||
foreach(bin_name IN LISTS BINS)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
|
||||
if (USING_U2)
|
||||
set(TEST_BINS
|
||||
ctc_prefix_beam_search_decoder_main
|
||||
)
|
||||
|
||||
foreach(bin_name IN LISTS TEST_BINS)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
|
||||
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
|
||||
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
|
||||
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
|
||||
endforeach()
|
||||
|
||||
endif()
|
||||
|
||||
|
@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "util/parse-options.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
|
||||
struct CTCBeamSearchOptions {
|
||||
// common
|
||||
int blank;
|
||||
|
||||
// ds2
|
||||
std::string dict_file;
|
||||
std::string lm_path;
|
||||
int beam_size;
|
||||
BaseFloat alpha;
|
||||
BaseFloat beta;
|
||||
BaseFloat cutoff_prob;
|
||||
int cutoff_top_n;
|
||||
int num_proc_bsearch;
|
||||
|
||||
// u2
|
||||
int first_beam_size;
|
||||
int second_beam_size;
|
||||
CTCBeamSearchOptions()
|
||||
: blank(0),
|
||||
dict_file("vocab.txt"),
|
||||
lm_path(""),
|
||||
beam_size(300),
|
||||
alpha(1.9f),
|
||||
beta(5.0),
|
||||
cutoff_prob(0.99f),
|
||||
cutoff_top_n(40),
|
||||
num_proc_bsearch(10),
|
||||
first_beam_size(10),
|
||||
second_beam_size(10) {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
std::string module = "Ds2BeamSearchConfig: ";
|
||||
opts->Register("dict", &dict_file, module + "vocab file path.");
|
||||
opts->Register(
|
||||
"lm-path", &lm_path, module + "ngram language model path.");
|
||||
opts->Register("alpha", &alpha, module + "alpha");
|
||||
opts->Register("beta", &beta, module + "beta");
|
||||
opts->Register("beam-size",
|
||||
&beam_size,
|
||||
module + "beam size for beam search method");
|
||||
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
|
||||
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
|
||||
opts->Register(
|
||||
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
|
||||
|
||||
opts->Register("blank", &blank, "blank id, default is 0.");
|
||||
|
||||
module = "U2BeamSearchConfig: ";
|
||||
opts->Register(
|
||||
"first-beam-size", &first_beam_size, module + "first beam size.");
|
||||
opts->Register("second-beam-size",
|
||||
&second_beam_size,
|
||||
module + "second beam size.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
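As a quick illustration of how the options struct above is meant to be consumed, here is a minimal sketch (not part of this commit) that registers `CTCBeamSearchOptions` with kaldi's `ParseOptions` and fills it from the command line; the demo name and the `main` wrapper are assumptions made for the example.

```cpp
// Minimal sketch: parse CTCBeamSearchOptions from the command line.
#include <iostream>

#include "decoder/ctc_beam_search_opt.h"
#include "util/parse-options.h"

int main(int argc, char* argv[]) {
    kaldi::ParseOptions po("ctc_beam_search_options_demo [options]");

    ppspeech::CTCBeamSearchOptions opts;
    opts.Register(&po);   // exposes --dict, --lm-path, --beam-size, --blank, ...

    po.Read(argc, argv);  // fills the struct from argv

    std::cout << "blank id: " << opts.blank
              << ", first beam: " << opts.first_beam_size
              << ", second beam: " << opts.second_beam_size << std::endl;
    return 0;
}
```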
@ -0,0 +1,370 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
|
||||
// 2022 Binbin Zhang (binbzha@qq.com)
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "decoder/ctc_prefix_beam_search_decoder.h"
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "base/common.h"
|
||||
#include "decoder/ctc_beam_search_opt.h"
|
||||
#include "decoder/ctc_prefix_beam_search_score.h"
|
||||
#include "utils/math.h"
|
||||
|
||||
#ifdef USE_PROFILING
|
||||
#include "paddle/fluid/platform/profiler.h"
|
||||
using paddle::platform::RecordEvent;
|
||||
using paddle::platform::TracerEventType;
|
||||
#endif
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
|
||||
const CTCBeamSearchOptions& opts)
|
||||
: opts_(opts) {
|
||||
unit_table_ = std::shared_ptr<fst::SymbolTable>(
|
||||
fst::SymbolTable::ReadText(vocab_path));
|
||||
CHECK(unit_table_ != nullptr);
|
||||
|
||||
Reset();
|
||||
}
|
||||
|
||||
void CTCPrefixBeamSearch::Reset() {
|
||||
num_frame_decoded_ = 0;
|
||||
|
||||
cur_hyps_.clear();
|
||||
|
||||
hypotheses_.clear();
|
||||
likelihood_.clear();
|
||||
viterbi_likelihood_.clear();
|
||||
times_.clear();
|
||||
outputs_.clear();
|
||||
|
||||
// empty hyp with Score
|
||||
std::vector<int> empty;
|
||||
PrefixScore prefix_score;
|
||||
prefix_score.InitEmpty();
|
||||
cur_hyps_[empty] = prefix_score;
|
||||
|
||||
outputs_.emplace_back(empty);
|
||||
hypotheses_.emplace_back(empty);
|
||||
likelihood_.emplace_back(prefix_score.TotalScore());
|
||||
times_.emplace_back(empty);
|
||||
}
|
||||
|
||||
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
|
||||
|
||||
|
||||
void CTCPrefixBeamSearch::AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
|
||||
while (1) {
|
||||
// forward frame by frame
|
||||
std::vector<kaldi::BaseFloat> frame_prob;
|
||||
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
|
||||
if (flag == false) {
|
||||
VLOG(1) << "decoder advance decode exit." << frame_prob.size();
|
||||
break;
|
||||
}
|
||||
|
||||
std::vector<std::vector<kaldi::BaseFloat>> likelihood;
|
||||
likelihood.push_back(frame_prob);
|
||||
AdvanceDecoding(likelihood);
|
||||
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
|
||||
}
|
||||
}
|
||||
|
||||
static bool PrefixScoreCompare(
|
||||
const std::pair<std::vector<int>, PrefixScore>& a,
|
||||
const std::pair<std::vector<int>, PrefixScore>& b) {
|
||||
// log domain
|
||||
return a.second.TotalScore() > b.second.TotalScore();
|
||||
}
|
||||
|
||||
|
||||
void CTCPrefixBeamSearch::AdvanceDecoding(
|
||||
const std::vector<std::vector<kaldi::BaseFloat>>& logp) {
|
||||
#ifdef USE_PROFILING
|
||||
RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding",
|
||||
TracerEventType::UserDefined,
|
||||
1);
|
||||
#endif
|
||||
|
||||
if (logp.size() == 0) return;
|
||||
|
||||
int first_beam_size =
|
||||
std::min(static_cast<int>(logp[0].size()), opts_.first_beam_size);
|
||||
|
||||
for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_) {
|
||||
const std::vector<kaldi::BaseFloat>& logp_t = logp[t];
|
||||
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
|
||||
next_hyps;
|
||||
|
||||
// 1. first beam prune, only select topk candidates
|
||||
std::vector<kaldi::BaseFloat> topk_score;
|
||||
std::vector<int32_t> topk_index;
|
||||
TopK(logp_t, first_beam_size, &topk_score, &topk_index);
|
||||
VLOG(2) << "topk: " << num_frame_decoded_ << " "
|
||||
<< *std::max_element(logp_t.begin(), logp_t.end()) << " "
|
||||
<< topk_score[0];
|
||||
for (int i = 0; i < topk_score.size(); i++) {
|
||||
VLOG(2) << "topk: " << num_frame_decoded_ << " " << topk_score[i];
|
||||
}
|
||||
|
||||
// 2. token passing
|
||||
for (int i = 0; i < topk_index.size(); ++i) {
|
||||
int id = topk_index[i];
|
||||
auto prob = topk_score[i];
|
||||
|
||||
for (const auto& it : cur_hyps_) {
|
||||
const std::vector<int>& prefix = it.first;
|
||||
const PrefixScore& prefix_score = it.second;
|
||||
|
||||
// If prefix doesn't exist in next_hyps, next_hyps[prefix] will
|
||||
// insert
|
||||
// PrefixScore(-inf, -inf) by default, since the default
|
||||
// constructor
|
||||
// of PrefixScore will set fields b(blank ending Score) and
|
||||
// nb(none blank ending Score) to -inf, respectively.
|
||||
|
||||
if (id == opts_.blank) {
|
||||
// case 0: *a + <blank> => *a, *a<blank> + <blank> => *a,
|
||||
// prefix not
|
||||
// change
|
||||
PrefixScore& next_score = next_hyps[prefix];
|
||||
next_score.b =
|
||||
LogSumExp(next_score.b, prefix_score.Score() + prob);
|
||||
|
||||
// timestamp: blank is silence, it does not affect the timestamp
|
||||
next_score.v_b = prefix_score.ViterbiScore() + prob;
|
||||
next_score.times_b = prefix_score.Times();
|
||||
|
||||
// Prefix not changed, copy the context from prefix
|
||||
if (context_graph_ && !next_score.has_context) {
|
||||
next_score.CopyContext(prefix_score);
|
||||
next_score.has_context = true;
|
||||
}
|
||||
|
||||
} else if (!prefix.empty() && id == prefix.back()) {
|
||||
// case 1: *a + a => *a, prefix not changed
|
||||
PrefixScore& next_score1 = next_hyps[prefix];
|
||||
next_score1.nb =
|
||||
LogSumExp(next_score1.nb, prefix_score.nb + prob);
|
||||
|
||||
// timestamp, non-blank symbol affects the timestamp
|
||||
if (next_score1.v_nb < prefix_score.v_nb + prob) {
|
||||
// compute viterbi Score
|
||||
next_score1.v_nb = prefix_score.v_nb + prob;
|
||||
if (next_score1.cur_token_prob < prob) {
|
||||
// store max token prob
|
||||
next_score1.cur_token_prob = prob;
|
||||
// update this timestamp as token appeared here.
|
||||
next_score1.times_nb = prefix_score.times_nb;
|
||||
assert(next_score1.times_nb.size() > 0);
|
||||
next_score1.times_nb.back() = num_frame_decoded_;
|
||||
}
|
||||
}
|
||||
|
||||
// Prefix not changed, copy the context from prefix
|
||||
if (context_graph_ && !next_score1.has_context) {
|
||||
next_score1.CopyContext(prefix_score);
|
||||
next_score1.has_context = true;
|
||||
}
|
||||
|
||||
// case 2: *a<blank> + a => *aa, prefix changed.
|
||||
std::vector<int> new_prefix(prefix);
|
||||
new_prefix.emplace_back(id);
|
||||
PrefixScore& next_score2 = next_hyps[new_prefix];
|
||||
next_score2.nb =
|
||||
LogSumExp(next_score2.nb, prefix_score.b + prob);
|
||||
|
||||
// timestamp, non-blank symbol affects the timestamp
|
||||
if (next_score2.v_nb < prefix_score.v_b + prob) {
|
||||
// compute viterbi Score
|
||||
next_score2.v_nb = prefix_score.v_b + prob;
|
||||
// new token added
|
||||
next_score2.cur_token_prob = prob;
|
||||
next_score2.times_nb = prefix_score.times_b;
|
||||
next_score2.times_nb.emplace_back(num_frame_decoded_);
|
||||
}
|
||||
|
||||
// Prefix changed, calculate the context Score.
|
||||
if (context_graph_ && !next_score2.has_context) {
|
||||
next_score2.UpdateContext(
|
||||
context_graph_, prefix_score, id, prefix.size());
|
||||
next_score2.has_context = true;
|
||||
}
|
||||
|
||||
} else {
|
||||
// id != prefix.back()
|
||||
// case 3: *a + b => *ab, *a<blank> +b => *ab
|
||||
std::vector<int> new_prefix(prefix);
|
||||
new_prefix.emplace_back(id);
|
||||
PrefixScore& next_score = next_hyps[new_prefix];
|
||||
next_score.nb =
|
||||
LogSumExp(next_score.nb, prefix_score.Score() + prob);
|
||||
|
||||
// timestamp, non-blank symbol affects the timestamp
|
||||
if (next_score.v_nb < prefix_score.ViterbiScore() + prob) {
|
||||
next_score.v_nb = prefix_score.ViterbiScore() + prob;
|
||||
|
||||
next_score.cur_token_prob = prob;
|
||||
next_score.times_nb = prefix_score.Times();
|
||||
next_score.times_nb.emplace_back(num_frame_decoded_);
|
||||
}
|
||||
|
||||
// Prefix changed, calculate the context Score.
|
||||
if (context_graph_ && !next_score.has_context) {
|
||||
next_score.UpdateContext(
|
||||
context_graph_, prefix_score, id, prefix.size());
|
||||
next_score.has_context = true;
|
||||
}
|
||||
}
|
||||
} // end for (const auto& it : cur_hyps_)
|
||||
} // end for (int i = 0; i < topk_index.size(); ++i)
|
||||
|
||||
// 3. second beam prune, only keep top n best paths
|
||||
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(
|
||||
next_hyps.begin(), next_hyps.end());
|
||||
int second_beam_size =
|
||||
std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
|
||||
std::nth_element(arr.begin(),
|
||||
arr.begin() + second_beam_size,
|
||||
arr.end(),
|
||||
PrefixScoreCompare);
|
||||
arr.resize(second_beam_size);
|
||||
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
|
||||
|
||||
// 4. update cur_hyps by next_hyps, and get new result
|
||||
UpdateHypotheses(arr);
|
||||
} // end for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_)
|
||||
}
|
||||
|
||||
|
||||
void CTCPrefixBeamSearch::UpdateHypotheses(
|
||||
const std::vector<std::pair<std::vector<int>, PrefixScore>>& hyps) {
|
||||
cur_hyps_.clear();
|
||||
|
||||
outputs_.clear();
|
||||
hypotheses_.clear();
|
||||
likelihood_.clear();
|
||||
viterbi_likelihood_.clear();
|
||||
times_.clear();
|
||||
|
||||
for (auto& item : hyps) {
|
||||
cur_hyps_[item.first] = item.second;
|
||||
|
||||
UpdateOutputs(item);
|
||||
hypotheses_.emplace_back(std::move(item.first));
|
||||
likelihood_.emplace_back(item.second.TotalScore());
|
||||
viterbi_likelihood_.emplace_back(item.second.ViterbiScore());
|
||||
times_.emplace_back(item.second.Times());
|
||||
}
|
||||
}
|
||||
|
||||
void CTCPrefixBeamSearch::UpdateOutputs(
|
||||
const std::pair<std::vector<int>, PrefixScore>& prefix) {
|
||||
const std::vector<int>& input = prefix.first;
|
||||
const std::vector<int>& start_boundaries = prefix.second.start_boundaries;
|
||||
const std::vector<int>& end_boundaries = prefix.second.end_boundaries;
|
||||
|
||||
// add <context> </context> tag
|
||||
std::vector<int> output;
|
||||
int s = 0;
|
||||
int e = 0;
|
||||
for (int i = 0; i < input.size(); ++i) {
|
||||
output.emplace_back(input[i]);
|
||||
}
|
||||
|
||||
outputs_.emplace_back(output);
|
||||
}
|
||||
|
||||
void CTCPrefixBeamSearch::FinalizeSearch() {
|
||||
UpdateFinalContext();
|
||||
|
||||
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
|
||||
int cnt = 0;
|
||||
for (int i = 0; i < hypotheses_.size(); i++) {
|
||||
VLOG(2) << "hyp " << cnt << " len: " << hypotheses_[i].size()
|
||||
<< " ctc score: " << likelihood_[i];
|
||||
for (int j = 0; j < hypotheses_[i].size(); j++) {
|
||||
VLOG(2) << hypotheses_[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CTCPrefixBeamSearch::UpdateFinalContext() {
|
||||
if (context_graph_ == nullptr) return;
|
||||
|
||||
CHECK(hypotheses_.size() == cur_hyps_.size());
|
||||
CHECK(hypotheses_.size() == likelihood_.size());
|
||||
|
||||
// We should backoff the context Score/state when the context is
|
||||
// not fully matched at the last time.
|
||||
for (const auto& prefix : hypotheses_) {
|
||||
PrefixScore& prefix_score = cur_hyps_[prefix];
|
||||
if (prefix_score.context_score != 0) {
|
||||
prefix_score.UpdateContext(
|
||||
context_graph_, prefix_score, 0, prefix.size());
|
||||
}
|
||||
}
|
||||
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
|
||||
cur_hyps_.end());
|
||||
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
|
||||
|
||||
// Update cur_hyps_ and get new result
|
||||
UpdateHypotheses(arr);
|
||||
}
|
||||
|
||||
std::string CTCPrefixBeamSearch::GetBestPath(int index) {
|
||||
int n_hyps = Outputs().size();
|
||||
CHECK_GT(n_hyps, 0);
|
||||
CHECK_LT(index, n_hyps);
|
||||
std::vector<int> one = Outputs()[index];
|
||||
std::string sentence;
|
||||
for (int i = 0; i < one.size(); i++) {
|
||||
sentence += unit_table_->Find(one[i]);
|
||||
}
|
||||
return sentence;
|
||||
}
|
||||
|
||||
std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); }
|
||||
|
||||
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(
|
||||
int n) {
|
||||
int hyps_size = hypotheses_.size();
|
||||
CHECK_GT(hyps_size, 0);
|
||||
|
||||
int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size);
|
||||
|
||||
std::vector<std::pair<double, std::string>> n_best;
|
||||
n_best.reserve(min_n);
|
||||
|
||||
for (int i = 0; i < min_n; i++) {
|
||||
n_best.emplace_back(Likelihood()[i], GetBestPath(i));
|
||||
}
|
||||
return n_best;
|
||||
}
|
||||
|
||||
std::vector<std::pair<double, std::string>>
|
||||
CTCPrefixBeamSearch::GetNBestPath() {
|
||||
return GetNBestPath(-1);
|
||||
}
|
||||
|
||||
std::string CTCPrefixBeamSearch::GetFinalBestPath() { return GetBestPath(); }
|
||||
|
||||
std::string CTCPrefixBeamSearch::GetPartialResult() { return GetBestPath(); }
|
||||
|
||||
|
||||
} // namespace ppspeech
|
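To summarize the decoder above: each prefix ℓ keeps a blank-ending log score b(ℓ) and a non-blank-ending log score nb(ℓ), with s(ℓ) = logadd(b(ℓ), nb(ℓ)). For a candidate token c with frame log-probability p, the per-frame updates implemented in `AdvanceDecoding` (cases 0-3 above) can be restated as:

$$
\begin{aligned}
c = \langle\text{blank}\rangle:\quad & b'(\ell) \leftarrow \operatorname{logadd}\bigl(b'(\ell),\, s(\ell) + p\bigr)\\
c = \text{last}(\ell):\quad & nb'(\ell) \leftarrow \operatorname{logadd}\bigl(nb'(\ell),\, nb(\ell) + p\bigr),\qquad nb'(\ell c) \leftarrow \operatorname{logadd}\bigl(nb'(\ell c),\, b(\ell) + p\bigr)\\
\text{otherwise}:\quad & nb'(\ell c) \leftarrow \operatorname{logadd}\bigl(nb'(\ell c),\, s(\ell) + p\bigr)
\end{aligned}
$$

After each frame only the top `second_beam_size` prefixes by total score (plus context score when a context graph is used) are kept; the viterbi scores v_b/v_nb and the times vectors are carried along only for timestamps and do not influence pruning.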
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// modified from
|
||||
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/ctc_prefix_beam_search.cc
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "decoder/ctc_beam_search_opt.h"
|
||||
#include "decoder/ctc_prefix_beam_search_score.h"
|
||||
#include "decoder/decoder_itf.h"
|
||||
#include "fst/symbol-table.h"
|
||||
|
||||
namespace ppspeech {
|
||||
class ContextGraph;
|
||||
class CTCPrefixBeamSearch : public DecoderBase {
|
||||
public:
|
||||
CTCPrefixBeamSearch(const std::string& vocab_path,
|
||||
const CTCBeamSearchOptions& opts);
|
||||
~CTCPrefixBeamSearch() {}
|
||||
|
||||
SearchType Type() const { return SearchType::kPrefixBeamSearch; }
|
||||
|
||||
void InitDecoder() override;
|
||||
|
||||
void Reset() override;
|
||||
|
||||
void AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;
|
||||
|
||||
std::string GetFinalBestPath() override;
|
||||
std::string GetPartialResult() override;
|
||||
|
||||
void FinalizeSearch();
|
||||
|
||||
const std::shared_ptr<fst::SymbolTable> VocabTable() const {
|
||||
return unit_table_;
|
||||
}
|
||||
|
||||
const std::vector<std::vector<int>>& Inputs() const { return hypotheses_; }
|
||||
const std::vector<std::vector<int>>& Outputs() const { return outputs_; }
|
||||
const std::vector<float>& Likelihood() const { return likelihood_; }
|
||||
const std::vector<float>& ViterbiLikelihood() const {
|
||||
return viterbi_likelihood_;
|
||||
}
|
||||
const std::vector<std::vector<int>>& Times() const { return times_; }
|
||||
|
||||
|
||||
protected:
|
||||
std::string GetBestPath() override;
|
||||
std::vector<std::pair<double, std::string>> GetNBestPath() override;
|
||||
std::vector<std::pair<double, std::string>> GetNBestPath(int n) override;
|
||||
|
||||
private:
|
||||
std::string GetBestPath(int index);
|
||||
|
||||
void AdvanceDecoding(
|
||||
const std::vector<std::vector<kaldi::BaseFloat>>& logp);
|
||||
|
||||
void UpdateOutputs(const std::pair<std::vector<int>, PrefixScore>& prefix);
|
||||
void UpdateHypotheses(
|
||||
const std::vector<std::pair<std::vector<int>, PrefixScore>>& prefix);
|
||||
void UpdateFinalContext();
|
||||
|
||||
|
||||
private:
|
||||
CTCBeamSearchOptions opts_;
|
||||
std::shared_ptr<fst::SymbolTable> unit_table_{nullptr};
|
||||
|
||||
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
|
||||
cur_hyps_;
|
||||
|
||||
// n-best list and corresponding likelihood, in sorted order
|
||||
std::vector<std::vector<int>> hypotheses_;
|
||||
std::vector<float> likelihood_;
|
||||
|
||||
std::vector<std::vector<int>> times_;
|
||||
std::vector<float> viterbi_likelihood_;
|
||||
|
||||
// Outputs contain the hypotheses_ and tags like: <context> and </context>
|
||||
std::vector<std::vector<int>> outputs_;
|
||||
|
||||
std::shared_ptr<ContextGraph> context_graph_{nullptr};
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
|
||||
};
|
||||
|
||||
|
||||
} // namespace ppspeech
|
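A minimal usage sketch for the class declared above, assuming a `decodable` that already wraps the nnet output (the helper name and variable names are illustrative, not from this commit):

```cpp
// Sketch: the intended call sequence for CTCPrefixBeamSearch.
#include "decoder/ctc_prefix_beam_search_decoder.h"

std::string DecodeOneUtt(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable,
    const std::string& vocab_path) {
    ppspeech::CTCBeamSearchOptions opts;   // defaults: blank=0, beams 10/10
    ppspeech::CTCPrefixBeamSearch decoder(vocab_path, opts);

    decoder.InitDecoder();                 // same as Reset()
    decoder.AdvanceDecode(decodable);      // consume frames until none are left
    decoder.FinalizeSearch();              // sort n-best, apply context backoff

    return decoder.GetFinalBestPath();     // best hypothesis as a string
}
```

The n-best token sequences and scores remain available afterwards through `Outputs()`, `Likelihood()` and `Times()`.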
@ -0,0 +1,98 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// modified from
|
||||
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/ctc_prefix_beam_search.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "utils/math.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class ContextGraph;
|
||||
|
||||
struct PrefixScore {
|
||||
// decoding, unit in log scale
|
||||
float b = -kBaseFloatMax; // blank ending score
|
||||
float nb = -kBaseFloatMax;  // non-blank ending score
|
||||
|
||||
// decoding score, sum
|
||||
float Score() const { return LogSumExp(b, nb); }
|
||||
|
||||
// timestamp, unit in log scale
|
||||
float v_b = -kBaseFloatMax; // viterbi blank ending score
|
||||
float v_nb = -kBaseFloatMax;  // viterbi non-blank ending score
|
||||
float cur_token_prob = -kBaseFloatMax; // prob of current token
|
||||
std::vector<int> times_b; // times of viterbi blank path
|
||||
std::vector<int> times_nb; // times of viterbi non-blank path
|
||||
|
||||
|
||||
// timestamp score, max
|
||||
float ViterbiScore() const { return std::max(v_b, v_nb); }
|
||||
|
||||
// get timestamp
|
||||
const std::vector<int>& Times() const {
|
||||
return v_b > v_nb ? times_b : times_nb;
|
||||
}
|
||||
|
||||
// context state
|
||||
bool has_context = false;
|
||||
int context_state = 0;
|
||||
float context_score = 0;
|
||||
std::vector<int> start_boundaries;
|
||||
std::vector<int> end_boundaries;
|
||||
|
||||
|
||||
// decoding score with context bias
|
||||
float TotalScore() const { return Score() + context_score; }
|
||||
|
||||
void CopyContext(const PrefixScore& prefix_score) {
|
||||
context_state = prefix_score.context_state;
|
||||
context_score = prefix_score.context_score;
|
||||
start_boundaries = prefix_score.start_boundaries;
|
||||
end_boundaries = prefix_score.end_boundaries;
|
||||
}
|
||||
|
||||
void UpdateContext(const std::shared_ptr<ContextGraph>& context_graph,
|
||||
const PrefixScore& prefix_score,
|
||||
int word_id,
|
||||
int prefix_len) {
|
||||
CHECK(false);
|
||||
}
|
||||
|
||||
void InitEmpty() {
|
||||
b = 0.0f; // log(1)
|
||||
nb = -kBaseFloatMax; // log(0)
|
||||
v_b = 0.0f; // log(1)
|
||||
v_nb = 0.0f; // log(1)
|
||||
}
|
||||
};
|
||||
|
||||
struct PrefixScoreHash {
|
||||
// https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
|
||||
std::size_t operator()(const std::vector<int>& prefix) const {
|
||||
std::size_t seed = prefix.size();
|
||||
for (auto& i : prefix) {
|
||||
seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
using PrefixWithScoreType = std::pair<std::vector<int>, PrefixScoreHash>;
|
||||
|
||||
} // namespace ppspeech
|
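Restating the fields above as equations (no new behaviour): the decoding score of a prefix is Score = log(e^b + e^nb), the timestamp score is ViterbiScore = max(v_b, v_nb), and pruning uses TotalScore = Score + context_score. `InitEmpty()` sets b = log 1 = 0 and nb = log 0 = -inf, i.e. the empty prefix can initially only end in blank.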
@ -0,0 +1,137 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// TODO: refactor, replace with gtest
|
||||
|
||||
#include "base/common.h"
|
||||
#include "decoder/ctc_tlg_decoder.h"
|
||||
#include "decoder/param.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/ds2_nnet.h"
|
||||
|
||||
|
||||
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test TLG decoder by feeding speech feature.
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(
|
||||
FLAGS_feature_rspecifier);
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
ppspeech::TLGDecoderOptions opts =
|
||||
ppspeech::TLGDecoderOptions::InitFromFlags();
|
||||
opts.opts.beam = 15.0;
|
||||
opts.opts.lattice_beam = 7.5;
|
||||
ppspeech::TLGDecoder decoder(opts);
|
||||
|
||||
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
|
||||
|
||||
std::shared_ptr<ppspeech::PaddleNnet> nnet(
|
||||
new ppspeech::PaddleNnet(model_opts));
|
||||
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
|
||||
std::shared_ptr<ppspeech::Decodable> decodable(
|
||||
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
|
||||
|
||||
int32 chunk_size = FLAGS_receptive_field_length +
|
||||
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
|
||||
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
|
||||
int32 receptive_field_length = FLAGS_receptive_field_length;
|
||||
LOG(INFO) << "chunk size (frame): " << chunk_size;
|
||||
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
|
||||
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
|
||||
|
||||
decoder.InitDecoder();
|
||||
kaldi::Timer timer;
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
raw_data->SetDim(feature.NumCols());
|
||||
LOG(INFO) << "process utt: " << utt;
|
||||
LOG(INFO) << "rows: " << feature.NumRows();
|
||||
LOG(INFO) << "cols: " << feature.NumCols();
|
||||
|
||||
int32 row_idx = 0;
|
||||
int32 padding_len = 0;
|
||||
int32 ori_feature_len = feature.NumRows();
|
||||
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
|
||||
padding_len =
|
||||
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
|
||||
feature.Resize(feature.NumRows() + padding_len,
|
||||
feature.NumCols(),
|
||||
kaldi::kCopyData);
|
||||
}
|
||||
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
|
||||
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
|
||||
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
|
||||
feature.NumCols());
|
||||
int32 feature_chunk_size = 0;
|
||||
if (ori_feature_len > chunk_idx * chunk_stride) {
|
||||
feature_chunk_size = std::min(
|
||||
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
|
||||
}
|
||||
if (feature_chunk_size < receptive_field_length) break;
|
||||
|
||||
int32 start = chunk_idx * chunk_stride;
|
||||
for (int row_id = 0; row_id < chunk_size; ++row_id) {
|
||||
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
|
||||
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
|
||||
feature_chunk.Data() + row_id * feature.NumCols(),
|
||||
feature.NumCols());
|
||||
f_chunk_tmp.CopyFromVec(tmp);
|
||||
++start;
|
||||
}
|
||||
raw_data->Accept(feature_chunk);
|
||||
if (chunk_idx == num_chunks - 1) {
|
||||
raw_data->SetFinished();
|
||||
}
|
||||
decoder.AdvanceDecode(decodable);
|
||||
}
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
decodable->Reset();
|
||||
decoder.Reset();
|
||||
if (result.empty()) {
|
||||
// the TokenWriter cannot write an empty string.
|
||||
++num_err;
|
||||
KALDI_LOG << " the result of " << utt << " is empty";
|
||||
continue;
|
||||
}
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
result_writer.Write(utt, result);
|
||||
++num_done;
|
||||
}
|
||||
|
||||
double elapsed = timer.Elapsed();
|
||||
KALDI_LOG << " cost:" << elapsed << " s";
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
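A worked example of the chunking arithmetic in the test binary above, using the flag values passed throughout the example scripts in this commit (`--receptive_field_length=7`, `--nnet_decoder_chunk=16`, `--subsampling_rate=4`): chunk_size = 7 + (16 - 1) * 4 = 67 input frames and chunk_stride = 4 * 16 = 64 frames, so consecutive chunks overlap by 3 frames and each chunk produces 16 subsampled decoder steps.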
@ -0,0 +1,66 @@
|
||||
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "kaldi/decoder/decodable-itf.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
enum SearchType {
|
||||
kPrefixBeamSearch = 0,
|
||||
kWfstBeamSearch = 1,
|
||||
};
|
||||
class DecoderInterface {
|
||||
public:
|
||||
virtual ~DecoderInterface() {}
|
||||
|
||||
virtual void InitDecoder() = 0;
|
||||
|
||||
virtual void Reset() = 0;
|
||||
|
||||
// call AdvanceDecoding
|
||||
virtual void AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;
|
||||
|
||||
// call GetBestPath
|
||||
virtual std::string GetFinalBestPath() = 0;
|
||||
|
||||
virtual std::string GetPartialResult() = 0;
|
||||
|
||||
protected:
|
||||
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
|
||||
|
||||
// virtual void Decode() = 0;
|
||||
|
||||
virtual std::string GetBestPath() = 0;
|
||||
|
||||
virtual std::vector<std::pair<double, std::string>> GetNBestPath() = 0;
|
||||
|
||||
virtual std::vector<std::pair<double, std::string>> GetNBestPath(int n) = 0;
|
||||
};
|
||||
|
||||
class DecoderBase : public DecoderInterface {
|
||||
protected:
|
||||
// start from one
|
||||
int NumFrameDecoded() { return num_frame_decoded_ + 1; }
|
||||
|
||||
protected:
|
||||
// current decoding frame number, abs_time_step_
|
||||
int32 num_frame_decoded_;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -1,14 +1,39 @@
|
||||
project(nnet)
|
||||
set(srcs decodable.cc)
|
||||
|
||||
add_library(nnet STATIC
|
||||
decodable.cc
|
||||
paddle_nnet.cc
|
||||
)
|
||||
if(USING_DS2)
|
||||
list(APPEND srcs ds2_nnet.cc)
|
||||
endif()
|
||||
|
||||
if(USING_U2)
|
||||
list(APPEND srcs u2_nnet.cc)
|
||||
endif()
|
||||
|
||||
add_library(nnet STATIC ${srcs})
|
||||
target_link_libraries(nnet absl::strings)
|
||||
|
||||
set(bin_name nnet_forward_main)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS})
|
||||
if(USING_U2)
|
||||
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
|
||||
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
|
||||
endif()
|
||||
|
||||
|
||||
if(USING_DS2)
|
||||
set(bin_name ds2_nnet_main)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
|
||||
|
||||
target_link_libraries(${bin_name} ${DEPS})
|
||||
endif()
|
||||
|
||||
# test bin
|
||||
if(USING_U2)
|
||||
set(bin_name u2_nnet_main)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
|
||||
|
||||
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
|
||||
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
|
||||
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
|
||||
endif()
|
||||
|
Some files were not shown because too many files have changed in this diff.