From daadec0c63975ee7bdb0893411325d57c4534935 Mon Sep 17 00:00:00 2001
From: Yang Zhou
Date: Tue, 24 May 2022 16:23:52 +0800
Subject: [PATCH] add custom asr script

---
 demos/custom_streaming_asr/README.md          |   4 +-
 demos/custom_streaming_asr/README_cn.md       |   2 +
 speechx/CMakeLists.txt                        |  12 +-
 speechx/examples/custom_asr/README.md         |  32 ++
 .../local/compile_lexicon_token_fst.sh        |  89 +++++
 .../custom_asr/local/mk_slot_graph.sh         |  74 ++++
 .../custom_asr/local/mk_tlg_with_slot.sh      |  61 +++
 .../custom_asr/local/train_lm_with_slot.sh    |  55 +++
 speechx/examples/custom_asr/path.sh           |  17 +
 speechx/examples/custom_asr/run.sh            |  88 ++++
 speechx/examples/custom_asr/utils             |   1 +
 speechx/speechx/kaldi/CMakeLists.txt          |   4 +
 speechx/speechx/kaldi/fstbin/CMakeLists.txt   |  15 +
 .../kaldi}/fstbin/fstaddselfloops.cc          |   0
 .../kaldi}/fstbin/fstdeterminizestar.cc       |   0
 .../kaldi}/fstbin/fstisstochastic.cc          |   0
 .../kaldi}/fstbin/fstminimizeencoded.cc       |   0
 .../kaldi}/fstbin/fsttablecompose.cc          |   0
 speechx/speechx/kaldi/fstext/CMakeLists.txt   |   2 +-
 speechx/speechx/kaldi/lm/CMakeLists.txt       |   6 +
 speechx/speechx/kaldi/lm/arpa-file-parser.cc  | 281 +++++++++++++
 speechx/speechx/kaldi/lm/arpa-file-parser.h   | 146 +++++++
 speechx/speechx/kaldi/lm/arpa-lm-compiler.cc  | 377 ++++++++++++++++++
 speechx/speechx/kaldi/lm/arpa-lm-compiler.h   |  65 +++
 .../kaldi}/lmbin/CMakeLists.txt               |   3 +-
 .../kaldi}/lmbin/arpa2fst.cc                  |   0
 26 files changed, 1328 insertions(+), 6 deletions(-)
 create mode 100644 speechx/examples/custom_asr/README.md
 create mode 100755 speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh
 create mode 100755 speechx/examples/custom_asr/local/mk_slot_graph.sh
 create mode 100755 speechx/examples/custom_asr/local/mk_tlg_with_slot.sh
 create mode 100755 speechx/examples/custom_asr/local/train_lm_with_slot.sh
 create mode 100644 speechx/examples/custom_asr/path.sh
 create mode 100644 speechx/examples/custom_asr/run.sh
 create mode 120000 speechx/examples/custom_asr/utils
 create mode 100644 speechx/speechx/kaldi/fstbin/CMakeLists.txt
 rename speechx/{tools => speechx/kaldi}/fstbin/fstaddselfloops.cc (100%)
 rename speechx/{tools => speechx/kaldi}/fstbin/fstdeterminizestar.cc (100%)
 rename speechx/{tools => speechx/kaldi}/fstbin/fstisstochastic.cc (100%)
 rename speechx/{tools => speechx/kaldi}/fstbin/fstminimizeencoded.cc (100%)
 rename speechx/{tools => speechx/kaldi}/fstbin/fsttablecompose.cc (100%)
 create mode 100644 speechx/speechx/kaldi/lm/CMakeLists.txt
 create mode 100644 speechx/speechx/kaldi/lm/arpa-file-parser.cc
 create mode 100644 speechx/speechx/kaldi/lm/arpa-file-parser.h
 create mode 100644 speechx/speechx/kaldi/lm/arpa-lm-compiler.cc
 create mode 100644 speechx/speechx/kaldi/lm/arpa-lm-compiler.h
 rename speechx/{tools => speechx/kaldi}/lmbin/CMakeLists.txt (64%)
 rename speechx/{tools => speechx/kaldi}/lmbin/arpa2fst.cc (100%)

diff --git a/demos/custom_streaming_asr/README.md b/demos/custom_streaming_asr/README.md
index aa28d502..74af59a7 100644
--- a/demos/custom_streaming_asr/README.md
+++ b/demos/custom_streaming_asr/README.md
@@ -7,6 +7,8 @@ In some cases, we need to recognize the specific rare words with high accuracy.
 
 this demo is customized for expense account, which need to recognize rare address.
 
+the scripts are in PaddleSpeech/speechx/examples/custom_asr.
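+
+the core operation is `fstreplace`, which splices each slot FST into the main grammar at its slot label; a minimal sketch (assuming G_with_slot.fst and address_slot.fst are already compiled, with the ROOT and <ADDRESS_SLOT> label ids taken from words.txt):
+
+```
+root=$(grep ROOT words.txt | awk '{print $2}')
+slot=$(grep '<ADDRESS_SLOT>' words.txt | awk '{print $2}')
+fstreplace --epsilon_on_replace G_with_slot.fst $root address_slot.fst $slot G.fst
+```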
+
 * G with slot: 打车到 "address_slot"。
 ![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4)
@@ -62,4 +64,4 @@ I0513 10:58:13.884493 41768 feature_cache.h:52] set finished
 I0513 10:58:24.247171 41768 paddle_nnet.h:76] Tensor neml: 10240
 I0513 10:58:24.247249 41768 paddle_nnet.h:76] Tensor neml: 10240
 LOG ([5.5.544~2-f21d7]:main():decoder/recognizer_test_main.cc:90) the result of case_10 is 五月十二日二十二点三十六分加班打车回家四十一元
-```
\ No newline at end of file
+```
diff --git a/demos/custom_streaming_asr/README_cn.md b/demos/custom_streaming_asr/README_cn.md
index ffbf682f..5c0f7e89 100644
--- a/demos/custom_streaming_asr/README_cn.md
+++ b/demos/custom_streaming_asr/README_cn.md
@@ -6,6 +6,8 @@
 这个 demo 是打车报销单的场景识别，需要识别一些稀有的地名，可以通过如下操作实现。
 
+相关脚本:PaddleSpeech/speechx/examples/custom_asr
+
 * G with slot: 打车到 "address_slot"。
 ![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4)
diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt
index 98d9e637..db5c3cc6 100644
--- a/speechx/CMakeLists.txt
+++ b/speechx/CMakeLists.txt
@@ -57,7 +57,7 @@ include(gtest)
 include(absl)
 
 # libsndfile
-include(libsndfile)
+#include(libsndfile)
 
 # boost
 # include(boost) # not work
@@ -73,9 +73,17 @@ find_package(Eigen3 REQUIRED)
 # Kenlm
 include(kenlm)
 add_dependencies(kenlm eigen boost)
+#set(kenlm_install_dir ${fc_patch}/kenlm-build)
+#link_directories(${kenlm_install_dir}/lib)
+#include_directories(${fc_patch}/kenlm-src)
 
 #openblas
-include(openblas)
+#include(openblas)
+set(OpenBLAS_INSTALL_PREFIX ${fc_patch}/openblas-install)
+link_directories(${OpenBLAS_INSTALL_PREFIX}/lib)
+include_directories(${OpenBLAS_INSTALL_PREFIX}/include)
+
+
 # openfst
 include(openfst)
diff --git a/speechx/examples/custom_asr/README.md b/speechx/examples/custom_asr/README.md
new file mode 100644
index 00000000..bfc071cb
--- /dev/null
+++ b/speechx/examples/custom_asr/README.md
@@ -0,0 +1,32 @@
+# Customized Auto Speech Recognition
+
+## Introduction
+These scripts are a tutorial showing how to make your own decoding graph.
+
+e.g.:
+* G with slot: 打车到 "address_slot"。
+![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4)
+
+* This is the address-slot WFST; you can add the addresses you want to recognize.
+![](https://ai-studio-static-online.cdn.bcebos.com/47c89100ef8c465bac733605ffc53d76abefba33d62f4d818d351f8cea3c8fe2)
+
+* After the replace operation, G = fstreplace(G_with_slot, address_slot), we get the customized graph.
+![](https://ai-studio-static-online.cdn.bcebos.com/60a3095293044f10b73039ab10c7950d139a6717580a44a3ba878c6e74de402b)
+
+These operations are all in the scripts; please check them out. We will release more detailed scripts later.
+
+## How to run
+
+```
+bash run.sh
+```
+
+## Results
+
+### CTC WFST
+
+```
+Overall -> 1.23 % N=1134 C=1126 S=6 D=2 I=6
+Mandarin -> 1.24 % N=1132 C=1124 S=6 D=2 I=6
+English -> 0.00 % N=2 C=2 S=0 D=0 I=0
+```
\ No newline at end of file
diff --git a/speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh b/speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh
new file mode 100755
index 00000000..8411f7ed
--- /dev/null
+++ b/speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh
@@ -0,0 +1,89 @@
+#!/bin/bash
+# Copyright 2015 Yajie Miao (Carnegie Mellon University)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# This script compiles the lexicon and CTC tokens into FSTs. FST compiling slightly differs between the
+# phoneme and character-based lexicons.
+set -eo pipefail
+. utils/parse_options.sh
+
+if [ $# -ne 3 ]; then
+  echo "usage: utils/fst/compile_lexicon_token_fst.sh <dict-src-dir> <tmp-dir> <lang-dir>"
+  echo "e.g.: utils/fst/compile_lexicon_token_fst.sh data/local/dict data/local/lang_tmp data/lang"
+  echo "<dict-src-dir> should contain the following files:"
+  echo "lexicon.txt lexicon_numbers.txt units.txt"
+  echo "options: "
+  exit 1;
+fi
+
+srcdir=$1
+tmpdir=$2
+dir=$3
+mkdir -p $dir $tmpdir
+
+[ -f path.sh ] && . ./path.sh
+
+cp $srcdir/units.txt $dir
+
+# Add probabilities to lexicon entries. There is in fact no point of doing this here since all the entries have 1.0.
+# But utils/make_lexicon_fst.pl requires a probabilistic version, so we just leave it as it is.
+perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $tmpdir/lexiconp.txt || exit 1;
+
+# Add disambiguation symbols to the lexicon. This is necessary for determinizing the composition of L.fst and G.fst.
+# Without these symbols, determinization will fail.
+# default first disambiguation is #1
+ndisambig=`utils/fst/add_lex_disambig.pl $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt`
+# add #0 (#0 is reserved for the backoff symbol in the grammar).
+ndisambig=$[$ndisambig+1];
+
+( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) > $tmpdir/disambig.list
+
+# Get the full list of CTC tokens used in FST. These tokens include <eps>, the blank <blank>,
+# the actual model unit, and the disambiguation symbols.
+cat $srcdir/units.txt | awk '{print $1}' > $tmpdir/units.list
+(echo '<eps>';) | cat - $tmpdir/units.list $tmpdir/disambig.list | awk '{print $1 " " (NR-1)}' > $dir/tokens.txt
+
+# ctc_token_fst_corrected is too big and too slow for character based chinese modeling,
+# so here just use simple ctc_token_fst
+utils/fst/ctc_token_fst.py --token_file $dir/tokens.txt | \
+  fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/tokens.txt --keep_isymbols=false --keep_osymbols=false | \
+  fstarcsort --sort_type=olabel > $dir/T.fst || exit 1;
+
+# Encode the words with indices. Will be used in lexicon and language model FST compiling.
+cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | awk '
+  BEGIN {
+    print "<eps> 0";
+  }
+  {
+    printf("%s %d\n", $1, NR);
+  }
+  END {
+    printf("#0 %d\n", NR+1);
+    printf("<s> %d\n", NR+2);
+    printf("</s> %d\n", NR+3);
+    printf("ROOT %d\n", NR+4);
+  }' > $dir/words.txt || exit 1;
+
+# Now compile the lexicon FST. Depending on the size of your lexicon, it may take some time.
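+# For illustration (an assumed lexicon entry, not part of this recipe): a line
+# "北京 北 京" in lexiconp_disambig.txt becomes a path in L.fst that accepts
+# the token sequence "北 京" and emits the word "北京" on its first arc, with
+# optional "sil" self-loops allowed between words.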
+token_disambig_symbol=`grep \#0 $dir/tokens.txt | awk '{print $2}'`
+word_disambig_symbol=`grep \#0 $dir/words.txt | awk '{print $2}'`
+
+utils/fst/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt 0 "sil" '#'$ndisambig | \
+  fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/words.txt \
+  --keep_isymbols=false --keep_osymbols=false | \
+  fstaddselfloops "echo $token_disambig_symbol |" "echo $word_disambig_symbol |" | \
+  fstarcsort --sort_type=olabel > $dir/L.fst || exit 1;
+
+echo "Lexicon and Token FSTs compiling succeeded"
diff --git a/speechx/examples/custom_asr/local/mk_slot_graph.sh b/speechx/examples/custom_asr/local/mk_slot_graph.sh
new file mode 100755
index 00000000..8298a5d0
--- /dev/null
+++ b/speechx/examples/custom_asr/local/mk_slot_graph.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+graph_slot=$1
+dir=$2
+
+[ -f path.sh ] && . ./path.sh
+
+sym=$dir/../lang/words.txt
+cat > $dir/address_slot.txt <<EOF
+0 5 上海 上海
+0 5 北京 北京
+0 5 合肥 合肥
+5 1 南站 南站
+0 6 立水 立水
+6 1 桥 桥
+0 7 青岛 青岛
+7 1 站 站
+1
+EOF
+
+fstcompile --isymbols=$sym --osymbols=$sym $dir/address_slot.txt $dir/address_slot.fst
+fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/time_slot.txt $dir/time_slot.fst
+fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/date_slot.txt $dir/date_slot.fst
+fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/money_slot.txt $dir/money_slot.fst
+fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/year_slot.txt $dir/year_slot.fst
diff --git a/speechx/examples/custom_asr/local/mk_tlg_with_slot.sh b/speechx/examples/custom_asr/local/mk_tlg_with_slot.sh
new file mode 100755
index 00000000..a5569f40
--- /dev/null
+++ b/speechx/examples/custom_asr/local/mk_tlg_with_slot.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+lm=$1
+lang=$2
+tgt_lang=$3
+
+unset GREP_OPTIONS
+
+sym=$lang/words.txt
+arpa_lm=$lm/lm.arpa
+# Compose the language model to FST
+cat $arpa_lm | \
+    grep -v '<s> <s>' | \
+    grep -v '</s> <s>' | \
+    grep -v '</s> </s>' | \
+    grep -v -i '<unk>' | \
+    grep -v -i '<spoken_noise>' | \
+    arpa2fst --read-symbol-table=$sym --keep-symbols=true - | fstprint | \
+    utils/fst/eps2disambig.pl | utils/fst/s2eps.pl | fstcompile --isymbols=$sym \
+    --osymbols=$sym --keep_isymbols=false --keep_osymbols=false | \
+    fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G_with_slot.fst
+
+root_label=`grep ROOT $sym | awk '{print $2}'`
+address_slot_label=`grep '<ADDRESS_SLOT>' $sym | awk '{print $2}'`
+time_slot_label=`grep '<TIME_SLOT>' $sym | awk '{print $2}'`
+date_slot_label=`grep '<DATE_SLOT>' $sym | awk '{print $2}'`
+money_slot_label=`grep '<MONEY_SLOT>' $sym | awk '{print $2}'`
+year_slot_label=`grep '<YEAR_SLOT>' $sym | awk '{print $2}'`
+
+fstisstochastic $tgt_lang/G_with_slot.fst
+
+fstreplace --epsilon_on_replace $tgt_lang/G_with_slot.fst \
+    $root_label $tgt_lang/address_slot.fst $address_slot_label \
+    $tgt_lang/date_slot.fst $date_slot_label \
+    $tgt_lang/money_slot.fst $money_slot_label \
+    $tgt_lang/time_slot.fst $time_slot_label \
+    $tgt_lang/year_slot.fst $year_slot_label $tgt_lang/G.fst
+
+fstisstochastic $tgt_lang/G.fst
+
+# Compose the token, lexicon and language-model FST into the final decoding graph
+fsttablecompose $lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \
+    fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1;
+fsttablecompose $lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1;
+rm $tgt_lang/LG.fst
+
+echo "Composing decoding graph TLG.fst succeeded"
\ No newline at end of file
diff --git a/speechx/examples/custom_asr/local/train_lm_with_slot.sh b/speechx/examples/custom_asr/local/train_lm_with_slot.sh
new file mode 100755
index 00000000..3f557ec3
--- /dev/null
+++ b/speechx/examples/custom_asr/local/train_lm_with_slot.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# To be run from one directory above this script.
+. ./path.sh
+src=ds2_graph_with_slot
+text=$src/train_text
+lexicon=$src/local/dict/lexicon.txt
+
+dir=$src/local/lm
+mkdir -p $dir
+
+for f in "$text" "$lexicon"; do
+  [ ! -f $f ] && echo "$0: No such file $f" && exit 1;
+done
+
+# Check SRILM tools
+if ! which ngram-count > /dev/null; then
+  pushd $MAIN_ROOT/tools
+  make srilm.done
+  popd
+fi
+
+# This script takes no arguments. It takes as input the files
+#   data/local/lm/text
+#   data/local/dict/lexicon.txt
+
+cleantext=$dir/text.no_oov
+
+cat $text | awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
+  {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<UNK> ");} } printf("\n");}' \
+  > $cleantext || exit 1;
+
+cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \
+  sort -nr > $dir/word.counts || exit 1;
+# Get counts from acoustic training transcripts, and add one-count
+# for each word in the lexicon (but not silence, we don't want it
+# in the LM-- we'll add it optionally later).
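+# e.g. a word occurring 3 times in the transcripts that is also in the
+# lexicon gets count 4 below; a word found only in the lexicon gets count 1,
+# which the "$1>1" filter further down then drops from the final wordlist.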
+cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
+  cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
+  sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1;
+
+# filter the words which are not in the text
+cat $dir/unigram.counts | awk '$1>1{print $0}' | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
+
+# kaldi_lm results
+mkdir -p $dir
+cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' > $dir/train
+
+ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
+  -map-unk "<UNK>" -gt3max 0 -gt2max 0 -gt1max 0 -lm $dir/lm.arpa
+
+#ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
+#  -map-unk "<UNK>" -lm $dir/lm2.arpa
\ No newline at end of file
diff --git a/speechx/examples/custom_asr/path.sh b/speechx/examples/custom_asr/path.sh
new file mode 100644
index 00000000..1907c79f
--- /dev/null
+++ b/speechx/examples/custom_asr/path.sh
@@ -0,0 +1,17 @@
+# This contains the locations of the binaries built for running the examples.
+
+MAIN_ROOT=`realpath $PWD/../../../`
+SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx`
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+export LC_ALL=C
+
+# srilm
+export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
+export SRILM=${MAIN_ROOT}/tools/srilm
+
+# kaldi lm
+KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
+OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src
+export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin:$SPEECHX_EXAMPLES/ds2_ol/decoder
diff --git a/speechx/examples/custom_asr/run.sh b/speechx/examples/custom_asr/run.sh
new file mode 100644
index 00000000..8d88000d
--- /dev/null
+++ b/speechx/examples/custom_asr/run.sh
@@ -0,0 +1,88 @@
+#!/bin/bash
+set +x
+set -e
+
+export GLOG_logtostderr=1
+
+. ./path.sh || exit 1;
+
+# ds2 means deepspeech2 (acoustic model type)
+dir=$PWD/ds2_graph_with_slot
+data=$PWD/data
+stage=0
+stop_stage=10
+
+mkdir -p $dir
+
+model_dir=$PWD/resource/model
+vocab=$model_dir/vocab.txt
+cmvn=$data/cmvn.ark
+text_with_slot=$data/text_with_slot
+resource=$PWD/resource
+# download resource
+if [ ! -f $cmvn ]; then
+    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/resource.tar.gz
+    tar xzvf resource.tar.gz
+    ln -s ./resource/data .
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # make dict
+    unit_file=$vocab
+    mkdir -p $dir/local/dict
+    cp $unit_file $dir/local/dict/units.txt
+    cp $text_with_slot $dir/train_text
+    utils/fst/prepare_dict.py --unit_file $unit_file --in_lexicon $data/lexicon.txt \
+        --out_lexicon $dir/local/dict/lexicon.txt
+    # add the slot tags to the lexicon, just in case the lm training script filters them out.
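+    # each slot tag gets the placeholder pronunciation "一" so that it shows up
+    # in words.txt and survives LM vocabulary filtering; the tag itself is
+    # replaced by its slot FST (see local/mk_tlg_with_slot.sh) before the graph
+    # is composed, so the placeholder pronunciation is never actually decoded.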
+    echo "<ADDRESS_SLOT> 一" >> $dir/local/dict/lexicon.txt
+    echo "<TIME_SLOT> 一" >> $dir/local/dict/lexicon.txt
+    echo "<DATE_SLOT> 一" >> $dir/local/dict/lexicon.txt
+    echo "<MONEY_SLOT> 一" >> $dir/local/dict/lexicon.txt
+    echo "<YEAR_SLOT> 一" >> $dir/local/dict/lexicon.txt
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # train lm
+    lm=$dir/local/lm
+    mkdir -p $lm
+    # this script differs from the common lm training script
+    local/train_lm_with_slot.sh
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # make T & L
+    local/compile_lexicon_token_fst.sh $dir/local/dict $dir/local/tmp $dir/local/lang
+    mkdir -p $dir/local/lang_test
+    # make slot graph
+    local/mk_slot_graph.sh $resource/graph $dir/local/lang_test
+    # make TLG
+    local/mk_tlg_with_slot.sh $dir/local/lm $dir/local/lang $dir/local/lang_test || exit 1;
+    mv $dir/local/lang_test/TLG.fst $dir/local/lang/
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # test TLG
+    model_dir=$PWD/resource/model
+    cmvn=$data/cmvn.ark
+    wav_scp=$data/wav.scp
+    graph=$dir/local/lang
+
+    recognizer_test_main \
+        --wav_rspecifier=scp:$wav_scp \
+        --cmvn_file=$cmvn \
+        --streaming_chunk=30 \
+        --use_fbank=true \
+        --model_path=$model_dir/avg_10.jit.pdmodel \
+        --param_path=$model_dir/avg_10.jit.pdiparams \
+        --model_cache_shapes="5-1-2048,5-1-2048" \
+        --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
+        --word_symbol_table=$graph/words.txt \
+        --graph_path=$graph/TLG.fst --max_active=7500 \
+        --acoustic_scale=12 \
+        --result_wspecifier=ark,t:./result_run.txt
+
+    # data/wav.trans is the reference label.
+    utils/compute-wer.py --char=1 --v=1 data/wav.trans result_run.txt > wer_run
+    tail -n 7 wer_run
+fi
diff --git a/speechx/examples/custom_asr/utils b/speechx/examples/custom_asr/utils
new file mode 120000
index 00000000..973afe67
--- /dev/null
+++ b/speechx/examples/custom_asr/utils
@@ -0,0 +1 @@
+../../../utils
\ No newline at end of file
diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt
index 6f7398cd..ce6b43f6 100644
--- a/speechx/speechx/kaldi/CMakeLists.txt
+++ b/speechx/speechx/kaldi/CMakeLists.txt
@@ -7,3 +7,7 @@ add_subdirectory(matrix)
 add_subdirectory(lat)
 add_subdirectory(fstext)
 add_subdirectory(decoder)
+add_subdirectory(lm)
+
+add_subdirectory(fstbin)
+add_subdirectory(lmbin)
\ No newline at end of file
diff --git a/speechx/speechx/kaldi/fstbin/CMakeLists.txt b/speechx/speechx/kaldi/fstbin/CMakeLists.txt
new file mode 100644
index 00000000..05d0501f
--- /dev/null
+++ b/speechx/speechx/kaldi/fstbin/CMakeLists.txt
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+set(BINS
+fstaddselfloops
+fstisstochastic
+fstminimizeencoded
+fstdeterminizestar
+fsttablecompose
+)
+
+foreach(binary IN LISTS BINS)
+  add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc)
+  target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+  target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl)
+endforeach()
diff --git a/speechx/tools/fstbin/fstaddselfloops.cc b/speechx/speechx/kaldi/fstbin/fstaddselfloops.cc
similarity index 100%
rename from speechx/tools/fstbin/fstaddselfloops.cc
rename to speechx/speechx/kaldi/fstbin/fstaddselfloops.cc
diff --git a/speechx/tools/fstbin/fstdeterminizestar.cc b/speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc
similarity index 100%
rename from speechx/tools/fstbin/fstdeterminizestar.cc
rename to speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc
diff --git a/speechx/tools/fstbin/fstisstochastic.cc b/speechx/speechx/kaldi/fstbin/fstisstochastic.cc
similarity index 100%
rename from speechx/tools/fstbin/fstisstochastic.cc
rename to speechx/speechx/kaldi/fstbin/fstisstochastic.cc
diff --git a/speechx/tools/fstbin/fstminimizeencoded.cc b/speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc
similarity index 100%
rename from speechx/tools/fstbin/fstminimizeencoded.cc
rename to speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc
diff --git a/speechx/tools/fstbin/fsttablecompose.cc b/speechx/speechx/kaldi/fstbin/fsttablecompose.cc
similarity index 100%
rename from speechx/tools/fstbin/fsttablecompose.cc
rename to speechx/speechx/kaldi/fstbin/fsttablecompose.cc
diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/speechx/speechx/kaldi/fstext/CMakeLists.txt
index af91fd98..465d9dba 100644
--- a/speechx/speechx/kaldi/fstext/CMakeLists.txt
+++ b/speechx/speechx/kaldi/fstext/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_library(kaldi-fstext
-kaldi-fst-io.cc
+  kaldi-fst-io.cc
 )
 
 target_link_libraries(kaldi-fstext PUBLIC kaldi-util)
diff --git a/speechx/speechx/kaldi/lm/CMakeLists.txt b/speechx/speechx/kaldi/lm/CMakeLists.txt
new file mode 100644
index 00000000..75c1567e
--- /dev/null
+++ b/speechx/speechx/kaldi/lm/CMakeLists.txt
@@ -0,0 +1,6 @@
+
+add_library(kaldi-lm
+  arpa-file-parser.cc
+  arpa-lm-compiler.cc
+)
+target_link_libraries(kaldi-lm PUBLIC kaldi-util)
\ No newline at end of file
diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.cc b/speechx/speechx/kaldi/lm/arpa-file-parser.cc
new file mode 100644
index 00000000..81b63ed1
--- /dev/null
+++ b/speechx/speechx/kaldi/lm/arpa-file-parser.cc
@@ -0,0 +1,281 @@
+// lm/arpa-file-parser.cc
+
+// Copyright 2014  Guoguo Chen
+// Copyright 2016  Smart Action Company LLC (kkm)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+
+#include <sstream>
+
+#include "base/kaldi-error.h"
+#include "base/kaldi-math.h"
+#include "lm/arpa-file-parser.h"
+#include "util/text-utils.h"
+
+namespace kaldi {
+
+ArpaFileParser::ArpaFileParser(const ArpaParseOptions& options,
+                               fst::SymbolTable* symbols)
+    : options_(options), symbols_(symbols),
+      line_number_(0), warning_count_(0) {
+}
+
+ArpaFileParser::~ArpaFileParser() {
+}
+
+void TrimTrailingWhitespace(std::string *str) {
+  str->erase(str->find_last_not_of(" \n\r\t") + 1);
+}
+
+void ArpaFileParser::Read(std::istream &is) {
+  // Argument sanity checks.
+  if (options_.bos_symbol <= 0 || options_.eos_symbol <= 0 ||
+      options_.bos_symbol == options_.eos_symbol)
+    KALDI_ERR << "BOS and EOS symbols are required, must not be epsilons, and "
+              << "differ from each other. Given:"
+              << " BOS=" << options_.bos_symbol
+              << " EOS=" << options_.eos_symbol;
+  if (symbols_ != NULL &&
+      options_.oov_handling == ArpaParseOptions::kReplaceWithUnk &&
+      (options_.unk_symbol <= 0 ||
+       options_.unk_symbol == options_.bos_symbol ||
+       options_.unk_symbol == options_.eos_symbol))
+    KALDI_ERR << "When symbol table is given and OOV mode is kReplaceWithUnk, "
+              << "UNK symbol is required, must not be epsilon, and "
+              << "differ from both BOS and EOS symbols. Given:"
+              << " UNK=" << options_.unk_symbol
+              << " BOS=" << options_.bos_symbol
+              << " EOS=" << options_.eos_symbol;
+  if (symbols_ != NULL && symbols_->Find(options_.bos_symbol).empty())
+    KALDI_ERR << "BOS symbol must exist in symbol table";
+  if (symbols_ != NULL && symbols_->Find(options_.eos_symbol).empty())
+    KALDI_ERR << "EOS symbol must exist in symbol table";
+  if (symbols_ != NULL && options_.unk_symbol > 0 &&
+      symbols_->Find(options_.unk_symbol).empty())
+    KALDI_ERR << "UNK symbol must exist in symbol table";
+
+  ngram_counts_.clear();
+  line_number_ = 0;
+  warning_count_ = 0;
+  current_line_.clear();
+
+#define PARSE_ERR KALDI_ERR << LineReference() << ": "
+
+  // Give derived class an opportunity to prepare its state.
+  ReadStarted();
+
+  // Processes "\data\" section.
+  bool keyword_found = false;
+  while (++line_number_, getline(is, current_line_) && !is.eof()) {
+    if (current_line_.find_first_not_of(" \t\n\r") == std::string::npos) {
+      continue;
+    }
+
+    TrimTrailingWhitespace(&current_line_);
+
+    // Continue skipping lines until the \data\ marker alone on a line is found.
+    if (!keyword_found) {
+      if (current_line_ == "\\data\\") {
+        KALDI_LOG << "Reading \\data\\ section.";
+        keyword_found = true;
+      }
+      continue;
+    }
+
+    if (current_line_[0] == '\\') break;
+
+    // Enters "\data\" section, and looks for patterns like "ngram 1=1000",
+    // which means there are 1000 unigrams.
+    std::size_t equal_symbol_pos = current_line_.find("=");
+    if (equal_symbol_pos != std::string::npos)
+      // Guaranteed spaces around the "=".
+      current_line_.replace(equal_symbol_pos, 1, " = ");
+    std::vector<std::string> col;
+    SplitStringToVector(current_line_, " \t", true, &col);
+    if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") {
+      int32 order, ngram_count = 0;
+      if (!ConvertStringToInteger(col[1], &order) ||
+          !ConvertStringToInteger(col[3], &ngram_count)) {
+        PARSE_ERR << "cannot parse ngram count";
+      }
+      if (ngram_counts_.size() <= order) {
+        ngram_counts_.resize(order);
+      }
+      ngram_counts_[order - 1] = ngram_count;
+    } else {
+      KALDI_WARN << LineReference()
+                 << ": uninterpretable line in \\data\\ section";
+    }
+  }
+
+  if (ngram_counts_.size() == 0)
+    PARSE_ERR << "\\data\\ section missing or empty.";
+
+  // Signal that grammar order and n-gram counts are known.
+  HeaderAvailable();
+
+  NGram ngram;
+  ngram.words.reserve(ngram_counts_.size());
+
+  // Processes "\N-grams:" section.
+  for (int32 cur_order = 1; cur_order <= ngram_counts_.size(); ++cur_order) {
+    // Skips n-grams with zero count.
+    if (ngram_counts_[cur_order - 1] == 0)
+      KALDI_WARN << "Zero ngram count in ngram order " << cur_order
+                 << "(look for 'ngram " << cur_order << "=0' in the \\data\\ "
+                 << " section). There is possibly a problem with the file.";
+
+    // Must be looking at a \k-grams: directive at this point.
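+    // For orientation, the body being parsed here looks like (a typical ARPA
+    // excerpt, illustration only, not tied to any particular file):
+    //   \2-grams:
+    //   -1.452  in the  -0.3
+    // i.e. a base-10 logprob, the n words, and an optional backoff weight.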
+    std::ostringstream keyword;
+    keyword << "\\" << cur_order << "-grams:";
+    if (current_line_ != keyword.str()) {
+      PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'";
+    }
+    KALDI_LOG << "Reading " << current_line_ << " section.";
+
+    int32 ngram_count = 0;
+    while (++line_number_, getline(is, current_line_) && !is.eof()) {
+      if (current_line_.find_first_not_of(" \n\t\r") == std::string::npos) {
+        continue;
+      }
+      if (current_line_[0] == '\\') {
+        TrimTrailingWhitespace(&current_line_);
+        std::ostringstream next_keyword;
+        next_keyword << "\\" << cur_order + 1 << "-grams:";
+        if ((current_line_ != next_keyword.str()) &&
+            (current_line_ != "\\end\\")) {
+          if (ShouldWarn()) {
+            KALDI_WARN << "ignoring possible directive '" << current_line_
+                       << "' expecting '" << next_keyword.str() << "'";
+
+            if (warning_count_ > 0 &&
+                warning_count_ > static_cast<uint32>(options_.max_warnings)) {
+              KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
+                         << options_.max_warnings << " were reported. "
+                         << "Run program with --max-arpa-warnings=-1 "
+                         << "to see all warnings";
+            }
+          }
+        } else {
+          break;
+        }
+      }
+
+      std::vector<std::string> col;
+      SplitStringToVector(current_line_, " \t", true, &col);
+
+      if (col.size() < 1 + cur_order ||
+          col.size() > 2 + cur_order ||
+          (cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) {
+        PARSE_ERR << "Invalid n-gram data line";
+      }
+      ++ngram_count;
+
+      // Parse out n-gram logprob and, if present, backoff weight.
+      if (!ConvertStringToReal(col[0], &ngram.logprob)) {
+        PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'";
+      }
+      ngram.backoff = 0.0;
+      if (col.size() > cur_order + 1) {
+        if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff))
+          PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'";
+      }
+      // Convert to natural log.
+      ngram.logprob *= M_LN10;
+      ngram.backoff *= M_LN10;
+
+      ngram.words.resize(cur_order);
+      bool skip_ngram = false;
+      for (int32 index = 0; !skip_ngram && index < cur_order; ++index) {
+        int32 word;
+        if (symbols_) {
+          // Symbol table provided, so symbol labels are expected.
+          if (options_.oov_handling == ArpaParseOptions::kAddToSymbols) {
+            word = symbols_->AddSymbol(col[1 + index]);
+          } else {
+            word = symbols_->Find(col[1 + index]);
+            if (word == -1) {  // fst::kNoSymbol
+              switch (options_.oov_handling) {
+                case ArpaParseOptions::kReplaceWithUnk:
+                  word = options_.unk_symbol;
+                  break;
+                case ArpaParseOptions::kSkipNGram:
+                  if (ShouldWarn())
+                    KALDI_WARN << LineReference() << " skipped: word '"
+                               << col[1 + index] << "' not in symbol table";
+                  skip_ngram = true;
+                  break;
+                default:
+                  PARSE_ERR << "word '" << col[1 + index]
+                            << "' not in symbol table";
+              }
+            }
+          }
+        } else {
+          // Symbols not provided, LM file should contain integers.
+          if (!ConvertStringToInteger(col[1 + index], &word) || word < 0) {
+            PARSE_ERR << "invalid symbol '" << col[1 + index] << "'";
+          }
+        }
+        // Whichever way we got it, an epsilon is invalid.
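+        // (In OpenFst, label 0 is reserved for <eps>, so a word that mapped
+        // to 0 would silently turn into an epsilon transition downstream.)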
+        if (word == 0) {
+          PARSE_ERR << "epsilon symbol '" << col[1 + index]
+                    << "' is illegal in ARPA LM";
+        }
+        ngram.words[index] = word;
+      }
+      if (!skip_ngram) {
+        ConsumeNGram(ngram);
+      }
+    }
+    if (ngram_count > ngram_counts_[cur_order - 1]) {
+      PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1]
+                << " n-grams of order " << cur_order
+                << ", but we saw more already.";
+    }
+  }
+
+  if (current_line_ != "\\end\\") {
+    PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\";
+  }
+
+  if (warning_count_ > 0 &&
+      warning_count_ > static_cast<uint32>(options_.max_warnings)) {
+    KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
+               << options_.max_warnings << " were reported. Run program with "
+               << "--max-arpa-warnings=-1 to see all warnings";
+  }
+
+  current_line_.clear();
+  ReadComplete();
+
+#undef PARSE_ERR
+}
+
+std::string ArpaFileParser::LineReference() const {
+  std::ostringstream ss;
+  ss << "line " << line_number_ << " [" << current_line_ << "]";
+  return ss.str();
+}
+
+bool ArpaFileParser::ShouldWarn() {
+  return (warning_count_ != -1) &&
+         (++warning_count_ <= static_cast<uint32>(options_.max_warnings));
+}
+
+}  // namespace kaldi
diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.h b/speechx/speechx/kaldi/lm/arpa-file-parser.h
new file mode 100644
index 00000000..99ffba02
--- /dev/null
+++ b/speechx/speechx/kaldi/lm/arpa-file-parser.h
@@ -0,0 +1,146 @@
+// lm/arpa-file-parser.h
+
+// Copyright 2014  Guoguo Chen
+// Copyright 2016  Smart Action Company LLC (kkm)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_LM_ARPA_FILE_PARSER_H_
+#define KALDI_LM_ARPA_FILE_PARSER_H_
+
+#include <fst/fst-decl.h>
+
+#include <string>
+#include <vector>
+
+#include "base/kaldi-types.h"
+#include "util/options-itf.h"
+
+namespace kaldi {
+
+/**
+  Options that control ArpaFileParser
+*/
+struct ArpaParseOptions {
+  enum OovHandling {
+    kRaiseError,      ///< Abort on OOV words
+    kAddToSymbols,    ///< Add novel words to the symbol table.
+    kReplaceWithUnk,  ///< Replace OOV words with <unk>.
+    kSkipNGram        ///< Skip n-gram with OOV word and continue.
+  };
+
+  ArpaParseOptions():
+      bos_symbol(-1), eos_symbol(-1), unk_symbol(-1),
+      oov_handling(kRaiseError), max_warnings(30) { }
+
+  void Register(OptionsItf *opts) {
+    // Registering only the max_warnings count, since other options are
+    // treated differently by client programs: some want integer symbols,
+    // while others are passed words in their command line.
+    opts->Register("max-arpa-warnings", &max_warnings,
+                   "Maximum warnings to report on ARPA parsing, "
+                   "0 to disable, -1 to show all");
+  }
+
+  int32 bos_symbol;  ///< Symbol for <s>, Required non-epsilon.
+  int32 eos_symbol;  ///< Symbol for </s>, Required non-epsilon.
+  int32 unk_symbol;  ///< Symbol for <unk>, Required for kReplaceWithUnk.
+  OovHandling oov_handling;  ///< How to handle OOV words in the file.
+  int32 max_warnings;  ///< Maximum warnings to report, <0 unlimited.
+};
+
+/**
+  A parsed n-gram from ARPA LM file.
+*/
+struct NGram {
+  NGram() : logprob(0.0), backoff(0.0) { }
+  std::vector<int32> words;  ///< Symbols in left to right order.
+  float logprob;             ///< Log-prob of the n-gram.
+  float backoff;             ///< log-backoff weight of the n-gram.
+                             ///< Defaults to zero if not specified.
+};
+
+/**
+  ArpaFileParser is an abstract base class for ARPA LM file conversion.
+
+  See ConstArpaLmBuilder and ArpaLmCompiler for usage examples.
+*/
+class ArpaFileParser {
+ public:
+  /// Constructs the parser with the given options and optional symbol table.
+  /// If symbol table is provided, then the file should contain text n-grams,
+  /// and the words are mapped to symbols through it. bos_symbol and
+  /// eos_symbol in the options structure must be valid symbols in the table,
+  /// and so must be unk_symbol if provided. The table is not owned by the
+  /// parser, but may be augmented, if oov_handling is set to kAddToSymbols.
+  /// If symbol table is a null pointer, the file should contain integer
+  /// symbol values, and oov_handling has no effect. bos_symbol and eos_symbol
+  /// must be valid symbols still.
+  ArpaFileParser(const ArpaParseOptions& options, fst::SymbolTable* symbols);
+  virtual ~ArpaFileParser();
+
+  /// Read ARPA LM file from a stream.
+  void Read(std::istream &is);
+
+  /// Parser options.
+  const ArpaParseOptions& Options() const { return options_; }
+
+ protected:
+  /// Override called before reading starts. This is the point to prepare
+  /// any state in the derived class.
+  virtual void ReadStarted() { }
+
+  /// Override function called to signal that ARPA header with the expected
+  /// number of n-grams has been read, and ngram_counts() is now valid.
+  virtual void HeaderAvailable() { }
+
+  /// Pure override that must be implemented to process current n-gram. The
+  /// n-grams are sent in the file order, which guarantees that all
+  /// (k-1)-grams are processed before the first k-gram is.
+  virtual void ConsumeNGram(const NGram&) = 0;
+
+  /// Override function called after the last n-gram has been consumed.
+  virtual void ReadComplete() { }
+
+  /// Read-only access to symbol table. Not owned, do not make public.
+  const fst::SymbolTable* Symbols() const { return symbols_; }
+
+  /// Inside ConsumeNGram(), provides the current line number.
+  int32 LineNumber() const { return line_number_; }
+
+  /// Inside ConsumeNGram(), returns a formatted reference to the line being
+  /// compiled, to print out as part of diagnostics.
+  std::string LineReference() const;
+
+  /// Increments warning count, and returns true if a warning should be
+  /// printed or false if the count has exceeded the set maximum.
+  bool ShouldWarn();
+
+  /// N-gram counts. Valid from the point when HeaderAvailable() is called.
+  const std::vector<int32>& NgramCounts() const { return ngram_counts_; }
+
+ private:
+  ArpaParseOptions options_;
+  fst::SymbolTable* symbols_;  // the pointer is not owned here.
+  int32 line_number_;
+  uint32 warning_count_;
+  std::string current_line_;
+  std::vector<int32> ngram_counts_;
+};
+
+}  // namespace kaldi
+
+#endif  // KALDI_LM_ARPA_FILE_PARSER_H_
diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc b/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc
new file mode 100644
index 00000000..47bd20d4
--- /dev/null
+++ b/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc
@@ -0,0 +1,377 @@
+// lm/arpa-lm-compiler.cc
+
+// Copyright 2009-2011 Gilles Boulianne
+// Copyright 2016 Smart Action LLC (kkm)
+// Copyright 2017 Xiaohui Zhang
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <limits>
+#include <sstream>
+#include <utility>
+
+#include "base/kaldi-math.h"
+#include "lm/arpa-lm-compiler.h"
+#include "util/stl-utils.h"
+#include "util/text-utils.h"
+#include "fstext/remove-eps-local.h"
+
+namespace kaldi {
+
+class ArpaLmCompilerImplInterface {
+ public:
+  virtual ~ArpaLmCompilerImplInterface() { }
+  virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;
+};
+
+namespace {
+
+typedef int32 StateId;
+typedef int32 Symbol;
+
+// GeneralHistKey can represent state history in an arbitrarily large n
+// n-gram model with symbol ids fitting int32.
+class GeneralHistKey {
+ public:
+  // Construct key from begin and end iterators.
+  template<class InputIt>
+  GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) { }
+  // Construct empty history key.
+  GeneralHistKey() : vector_() { }
+  // Return tails of the key as a GeneralHistKey. The tails of an n-gram
+  // w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
+  // key class does not need this operation).
+  GeneralHistKey Tails() const {
+    return GeneralHistKey(vector_.begin() + 1, vector_.end());
+  }
+  // Keys are equal if they represent the same state.
+  friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
+    return a.vector_ == b.vector_;
+  }
+  // Public typename HashType for hashing.
+  struct HashType : public std::unary_function<GeneralHistKey, size_t> {
+    size_t operator()(const GeneralHistKey& key) const {
+      return VectorHasher<Symbol>().operator()(key.vector_);
+    }
+  };
+
+ private:
+  std::vector<Symbol> vector_;
+};
+
+// OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
+// machine word, allowing significant memory reduction and some runtime
+// benefit over GeneralHistKey. Since 3 symbols are enough to track history
+// in a 4-gram model, this optimized key is used for smaller models with up
+// to 4-gram and symbol values up to 2^21-1.
+//
+// See GeneralHistKey for interface requirements of a key class.
+class OptimizedHistKey {
+ public:
+  enum {
+    kShift = 21,  // 21 * 3 = 63 bits for data.
+    kMaxData = (1 << kShift) - 1
+  };
+  template<class InputIt>
+  OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
+    for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
+      data_ |= static_cast<uint64>(*begin) << shift;
+    }
+  }
+  OptimizedHistKey() : data_(0) { }
+  OptimizedHistKey Tails() const {
+    return OptimizedHistKey(data_ >> kShift);
+  }
+  friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
+    return a.data_ == b.data_;
+  }
+  struct HashType : public std::unary_function<OptimizedHistKey, size_t> {
+    size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
+  };
+
+ private:
+  explicit OptimizedHistKey(uint64 data) : data_(data) { }
+  uint64 data_;
+};
+
+}  // namespace
+
+template <class HistKey>
+class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
+ public:
+  ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
+                     Symbol sub_eps);
+
+  virtual void ConsumeNGram(const NGram &ngram, bool is_highest);
+
+ private:
+  StateId AddStateWithBackoff(HistKey key, float backoff);
+  void CreateBackoff(HistKey key, StateId state, float weight);
+
+  ArpaLmCompiler *parent_;  // Not owned.
+  fst::StdVectorFst* fst_;  // Not owned.
+  Symbol bos_symbol_;
+  Symbol eos_symbol_;
+  Symbol sub_eps_;
+
+  StateId eos_state_;
+  typedef unordered_map<HistKey, StateId,
+                        typename HistKey::HashType> HistoryMap;
+  HistoryMap history_;
+};
+
+template <class HistKey>
+ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
+    ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
+    : parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
+      eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) {
+  // The algorithm maintains state per history. The 0-gram is a special state
+  // for empty history. All unigrams (including BOS) backoff into this state.
+  StateId zerogram = fst_->AddState();
+  history_[HistKey()] = zerogram;
+
+  // Also, if </s> is not treated as epsilon, create a common end state for
+  // all transitions accepting the </s>, since they do not back off. This small
+  // optimization saves about 2% states in an average grammar.
+  if (sub_eps_ == 0) {
+    eos_state_ = fst_->AddState();
+    fst_->SetFinal(eos_state_, 0);
+  }
+}
+
+template <class HistKey>
+void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
+                                               bool is_highest) {
+  // Generally, we do the following. Suppose we are adding an n-gram "A B
+  // C". Then find the node for "A B", add a new node for "A B C", and connect
+  // them with the arc accepting "C" with the specified weight. Also, add a
+  // backoff arc from the new "A B C" node to its backoff state "B C".
+  //
+  // Two notable exceptions are the highest order n-grams, and final n-grams.
+  //
+  // When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
+  // the following optimization is performed. There is no point adding a node
+  // for "A B C" with a "C" arc from "A B", since there will be no other
+  // arcs ingoing to this node, and an epsilon backoff arc into the backoff
+  // model "B C", with the weight of \bar{1}. To save a node, create an arc
+  // accepting "C" directly from "A B" to "B C". This saves as many nodes
+  // as there are the highest order n-grams, which is typically about half
+  // the size of a large 3-gram model.
+  //
+  // Indeed, this does not apply to n-grams ending in EOS, since they do not
+  // back off. These are special, as they do not have a back-off state, and
+  // the node for "(..anything..) </s>" is always final. These are handled
+  // in one of the two possible ways. If symbols <s> and </s> are being
+  // replaced by epsilons, neither node nor arc is created, and the logprob
+  // of the n-gram is applied to its source node as final weight. If <s> and
+  // </s> are preserved, then a special final node for </s> is allocated and
+  // used as the destination of the "</s>" acceptor arc.
+  HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
+  typename HistoryMap::iterator source_it = history_.find(heads);
+  if (source_it == history_.end()) {
+    // There was no "A B", therefore the probability of "A B C" is zero.
+    // Print a warning and discard current n-gram.
+    if (parent_->ShouldWarn())
+      KALDI_WARN << parent_->LineReference()
+                 << " skipped: no parent (n-1)-gram exists";
+    return;
+  }
+
+  StateId source = source_it->second;
+  StateId dest;
+  Symbol sym = ngram.words.back();
+  float weight = -ngram.logprob;
+  if (sym == sub_eps_ || sym == 0) {
+    KALDI_ERR << "<eps> or disambiguation symbol " << sym
+              << " found in the ARPA file.";
+  }
+  if (sym == eos_symbol_) {
+    if (sub_eps_ == 0) {
+      // Keep </s> as a real symbol when not substituting.
+      dest = eos_state_;
+    } else {
+      // Treat </s> as if it was epsilon: mark source final, with the weight
+      // of the n-gram.
+      fst_->SetFinal(source, weight);
+      return;
+    }
+  } else {
+    // For the highest order n-gram, this may find an existing state, for
+    // non-highest, will create one (unless there are duplicate n-grams
+    // in the grammar, which cannot be reliably detected if highest order,
+    // so we better do not do that at all).
+    dest = AddStateWithBackoff(
+        HistKey(ngram.words.begin() + (is_highest ? 1 : 0),
+                ngram.words.end()),
+        -ngram.backoff);
+  }
+
+  if (sym == bos_symbol_) {
+    weight = 0;  // Accepting <s> is always free.
+    if (sub_eps_ == 0) {
+      // <s> is as a real symbol, only accepted in the start state.
+      source = fst_->AddState();
+      fst_->SetStart(source);
+    } else {
+      // The new state for <s> unigram history *is* the start state.
+      fst_->SetStart(dest);
+      return;
+    }
+  }
+
+  // Add arc from source to dest, whichever way it was found.
+  fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
+  return;
+}
+
+// Find or create a new state for n-gram defined by key, and ensure it has a
+// backoff transition. The key is either the current n-gram for all but
+// highest orders, or the tails of the n-gram for the highest order. The
+// latter arises from the chain-collapsing optimization described above.
+template <class HistKey>
+StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
+                                                         float backoff) {
+  typename HistoryMap::iterator dest_it = history_.find(key);
+  if (dest_it != history_.end()) {
+    // Found an existing state in the history map. Invariant: if the state in
+    // the map, then its backoff arc is in the FST. We are done.
+    return dest_it->second;
+  }
+  // Otherwise create a new state and its backoff arc, and register in the map.
+  StateId dest = fst_->AddState();
+  history_[key] = dest;
+  CreateBackoff(key.Tails(), dest, backoff);
+  return dest;
+}
+
+// Create a backoff arc for a state. Key is a backoff destination that may or
+// may not exist. When the destination is not found, naturally fall back to
+// the lower order model, and all the way down until one is found (since the
+// 0-gram model is always present, the search is guaranteed to terminate).
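+// For example (illustration only): if the bigram history "B" was never added
+// as an n-gram of its own, the loop below retries with the empty history,
+// whose 0-gram state was created in the constructor, so a destination is
+// always found.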
+template <class HistKey>
+inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(
+    HistKey key, StateId state, float weight) {
+  typename HistoryMap::iterator dest_it = history_.find(key);
+  while (dest_it == history_.end()) {
+    key = key.Tails();
+    dest_it = history_.find(key);
+  }
+
+  // The arc should transduce either <eps> or #0 to <eps>, depending on the
+  // epsilon substitution mode. This is the only case when input and output
+  // label may differ.
+  fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
+}
+
+ArpaLmCompiler::~ArpaLmCompiler() {
+  if (impl_ != NULL)
+    delete impl_;
+}
+
+void ArpaLmCompiler::HeaderAvailable() {
+  KALDI_ASSERT(impl_ == NULL);
+  // Use optimized implementation if the grammar is 4-gram or less, and the
+  // maximum attained symbol id will fit into the optimized range.
+  int64 max_symbol = 0;
+  if (Symbols() != NULL)
+    max_symbol = Symbols()->AvailableKey() - 1;
+  // If augmenting the symbol table, assume the worst case when all words in
+  // the model being read are novel.
+  if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
+    max_symbol += NgramCounts()[0];
+
+  if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
+    impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
+  } else {
+    impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
+    KALDI_LOG << "Reverting to slower state tracking because model is large: "
+              << NgramCounts().size() << "-gram with symbols up to "
+              << max_symbol;
+  }
+}
+
+void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
+  // <s> is invalid in tails, </s> in heads of an n-gram.
+  for (int i = 0; i < ngram.words.size(); ++i) {
+    if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
+        (i + 1 < ngram.words.size()
+         && ngram.words[i] == Options().eos_symbol)) {
+      if (ShouldWarn())
+        KALDI_WARN << LineReference()
+                   << " skipped: n-gram has invalid BOS/EOS placement";
+      return;
+    }
+  }
+
+  bool is_highest = ngram.words.size() == NgramCounts().size();
+  impl_->ConsumeNGram(ngram, is_highest);
+}
+
+void ArpaLmCompiler::RemoveRedundantStates() {
+  fst::StdArc::Label backoff_symbol = sub_eps_;
+  if (backoff_symbol == 0) {
+    // The method of removing redundant states implemented in this function
+    // leads to slow determinization of L o G when people use the older style of
+    // usage of arpa2fst where the --disambig-symbol option was not specified.
+    // The issue seems to be that it creates a non-deterministic FST, while G is
+    // supposed to be deterministic. By 'return'ing below, we just disable this
+    // method if people were using an older script. This method isn't really
+    // that consequential anyway, and people will move to the newer-style
+    // scripts (see current utils/format_lm.sh), so this isn't much of a
+    // problem.
+    return;
+  }
+
+  fst::StdArc::StateId num_states = fst_.NumStates();
+
+  // replace the #0 symbols on the input of arcs out of redundant states (states
+  // that are not final and have only a backoff arc leaving them), with <eps>.
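+  // e.g. a non-final state whose only arc is "#0:<eps>/w -> t" is rewritten
+  // below to "<eps>:<eps>/w -> t", which RemoveEpsLocal() can then splice out.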
+  for (fst::StdArc::StateId state = 0; state < num_states; state++) {
+    if (fst_.NumArcs(state) == 1 && fst_.Final(state) == fst::TropicalWeight::Zero()) {
+      fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
+      fst::StdArc arc = iter.Value();
+      if (arc.ilabel == backoff_symbol) {
+        arc.ilabel = 0;
+        iter.SetValue(arc);
+      }
+    }
+  }
+
+  // we could call fst::RemoveEps, and it would have the same effect in normal
+  // cases, where backoff_symbol != 0 and there are no epsilons in unexpected
+  // places, but RemoveEpsLocal is a bit safer in case something weird is going
+  // on; it guarantees not to blow up the FST.
+  fst::RemoveEpsLocal(&fst_);
+  KALDI_LOG << "Reduced num-states from " << num_states << " to "
+            << fst_.NumStates();
+}
+
+void ArpaLmCompiler::Check() const {
+  if (fst_.Start() == fst::kNoStateId) {
+    KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol "
+              << Symbols()->Find(Options().bos_symbol) << ".";
+  }
+}
+
+void ArpaLmCompiler::ReadComplete() {
+  fst_.SetInputSymbols(Symbols());
+  fst_.SetOutputSymbols(Symbols());
+  RemoveRedundantStates();
+  Check();
+}
+
+}  // namespace kaldi
diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.h b/speechx/speechx/kaldi/lm/arpa-lm-compiler.h
new file mode 100644
index 00000000..67a18273
--- /dev/null
+++ b/speechx/speechx/kaldi/lm/arpa-lm-compiler.h
@@ -0,0 +1,65 @@
+// lm/arpa-lm-compiler.h
+
+// Copyright 2009-2011 Gilles Boulianne
+// Copyright 2016 Smart Action LLC (kkm)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_LM_ARPA_LM_COMPILER_H_
+#define KALDI_LM_ARPA_LM_COMPILER_H_
+
+#include <fst/fstlib.h>
+
+#include "lm/arpa-file-parser.h"
+
+namespace kaldi {
+
+class ArpaLmCompilerImplInterface;
+
+class ArpaLmCompiler : public ArpaFileParser {
+ public:
+  ArpaLmCompiler(const ArpaParseOptions& options, int sub_eps,
+                 fst::SymbolTable* symbols)
+      : ArpaFileParser(options, symbols),
+        sub_eps_(sub_eps), impl_(NULL) {
+  }
+  ~ArpaLmCompiler();
+
+  const fst::StdVectorFst& Fst() const { return fst_; }
+  fst::StdVectorFst* MutableFst() { return &fst_; }
+
+ protected:
+  // ArpaFileParser overrides.
+  virtual void HeaderAvailable();
+  virtual void ConsumeNGram(const NGram& ngram);
+  virtual void ReadComplete();
+
+ private:
+  // this function removes states that only have a backoff arc coming
+  // out of them.
+  void RemoveRedundantStates();
+  void Check() const;
+
+  int sub_eps_;
+  ArpaLmCompilerImplInterface* impl_;  // Owned.
+  fst::StdVectorFst fst_;
+  template <class HistKey> friend class ArpaLmCompilerImpl;
+};
+
+}  // namespace kaldi
+
+#endif  // KALDI_LM_ARPA_LM_COMPILER_H_
diff --git a/speechx/tools/lmbin/CMakeLists.txt b/speechx/speechx/kaldi/lmbin/CMakeLists.txt
similarity index 64%
rename from speechx/tools/lmbin/CMakeLists.txt
rename to speechx/speechx/kaldi/lmbin/CMakeLists.txt
index 277e2077..2b0932f7 100644
--- a/speechx/tools/lmbin/CMakeLists.txt
+++ b/speechx/speechx/kaldi/lmbin/CMakeLists.txt
@@ -1,5 +1,4 @@
-cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
 
 add_executable(arpa2fst ${CMAKE_CURRENT_SOURCE_DIR}/arpa2fst.cc)
 target_include_directories(arpa2fst PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(arpa2fst )
+target_link_libraries(arpa2fst PUBLIC kaldi-lm glog gflags fst)
diff --git a/speechx/tools/lmbin/arpa2fst.cc b/speechx/speechx/kaldi/lmbin/arpa2fst.cc
similarity index 100%
rename from speechx/tools/lmbin/arpa2fst.cc
rename to speechx/speechx/kaldi/lmbin/arpa2fst.cc