parent
2fbb7ec569
commit
daadec0c63
@ -0,0 +1,32 @@
|
||||
# Customized Automatic Speech Recognition
|
||||
|
||||
## Introduction
|
||||
These scripts are tutorials that show you how to build your own decoding graph.
|
||||
|
||||
For example:
|
||||
* G with slot: 打车到 "address_slot"。
|
||||
![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4)
|
||||
|
||||
* This is the address slot WFST; you can add the addresses that you want to recognize.
|
||||
![](https://ai-studio-static-online.cdn.bcebos.com/47c89100ef8c465bac733605ffc53d76abefba33d62f4d818d351f8cea3c8fe2)
|
||||
|
||||
* After the replace operation, G = fstreplace(G_with_slot, address_slot), we get the customized graph.
|
||||
![](https://ai-studio-static-online.cdn.bcebos.com/60a3095293044f10b73039ab10c7950d139a6717580a44a3ba878c6e74de402b)
|
||||
|
||||
These operations are implemented in the scripts; please check them out. We will launch more detailed scripts later.
|
||||
|
||||
## How to run
|
||||
|
||||
```
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### CTC WFST
|
||||
|
||||
```
|
||||
Overall -> 1.23 % N=1134 C=1126 S=6 D=2 I=6
|
||||
Mandarin -> 1.24 % N=1132 C=1124 S=6 D=2 I=6
|
||||
English -> 0.00 % N=2 C=2 S=0 D=0 I=0
|
||||
```
|
@ -0,0 +1,89 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2015 Yajie Miao (Carnegie Mellon University)
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script compiles the lexicon and CTC tokens into FSTs. FST compiling slightly differs between the
|
||||
# phoneme and character-based lexicons.
|
||||
set -eo pipefail
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# -ne 3 ]; then
|
||||
echo "usage: utils/fst/compile_lexicon_token_fst.sh <dict-src-dir> <tmp-dir> <lang-dir>"
|
||||
echo "e.g.: utils/fst/compile_lexicon_token_fst.sh data/local/dict data/local/lang_tmp data/lang"
|
||||
echo "<dict-src-dir> should contain the following files:"
|
||||
echo "lexicon.txt lexicon_numbers.txt units.txt"
|
||||
echo "options: "
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
srcdir=$1
|
||||
tmpdir=$2
|
||||
dir=$3
|
||||
mkdir -p $dir $tmpdir
|
||||
|
||||
[ -f path.sh ] && . ./path.sh
|
||||
|
||||
cp $srcdir/units.txt $dir
|
||||
|
||||
# Add probabilities to lexicon entries. There is in fact no point of doing this here since all the entries have 1.0.
|
||||
# But utils/make_lexicon_fst.pl requires a probabilistic version, so we just leave it as it is.
|
||||
perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $tmpdir/lexiconp.txt || exit 1;
|
||||
|
||||
# Add disambiguation symbols to the lexicon. This is necessary for determinizing the composition of L.fst and G.fst.
|
||||
# Without these symbols, determinization will fail.
|
||||
# default first disambiguation is #1
|
||||
ndisambig=`utils/fst/add_lex_disambig.pl $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt`
|
||||
# add #0 (#0 reserved for symbol in grammar).
|
||||
ndisambig=$((ndisambig + 1))
|
||||
|
||||
( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) > $tmpdir/disambig.list
|
||||
|
||||
# Get the full list of CTC tokens used in FST. These tokens include <eps>, the blank <blk>,
|
||||
# the actual model unit, and the disambiguation symbols.
|
||||
cat $srcdir/units.txt | awk '{print $1}' > $tmpdir/units.list
|
||||
(echo '<eps>';) | cat - $tmpdir/units.list $tmpdir/disambig.list | awk '{print $1 " " (NR-1)}' > $dir/tokens.txt
|
||||
|
||||
# ctc_token_fst_corrected is too big and too slow for character based chinese modeling,
|
||||
# so here just use simple ctc_token_fst
|
||||
utils/fst/ctc_token_fst.py --token_file $dir/tokens.txt | \
|
||||
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/tokens.txt --keep_isymbols=false --keep_osymbols=false | \
|
||||
fstarcsort --sort_type=olabel > $dir/T.fst || exit 1;
|
||||
|
||||
# Encode the words with indices. Will be used in lexicon and language model FST compiling.
|
||||
cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | awk '
|
||||
BEGIN {
|
||||
print "<eps> 0";
|
||||
}
|
||||
{
|
||||
printf("%s %d\n", $1, NR);
|
||||
}
|
||||
END {
|
||||
printf("#0 %d\n", NR+1);
|
||||
printf("<s> %d\n", NR+2);
|
||||
printf("</s> %d\n", NR+3);
|
||||
printf("ROOT %d\n", NR+4);
|
||||
}' > $dir/words.txt || exit 1;
|
||||
|
||||
# Now compile the lexicon FST. Depending on the size of your lexicon, it may take some time.
|
||||
token_disambig_symbol=`grep \#0 $dir/tokens.txt | awk '{print $2}'`
|
||||
word_disambig_symbol=`grep \#0 $dir/words.txt | awk '{print $2}'`
|
||||
|
||||
utils/fst/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt 0 "sil" '#'$ndisambig | \
|
||||
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/words.txt \
|
||||
--keep_isymbols=false --keep_osymbols=false | \
|
||||
fstaddselfloops "echo $token_disambig_symbol |" "echo $word_disambig_symbol |" | \
|
||||
fstarcsort --sort_type=olabel > $dir/L.fst || exit 1;
|
||||
|
||||
echo "Lexicon and Token FSTs compiling succeeded"
|
@ -0,0 +1,74 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
graph_slot=$1
|
||||
dir=$2
|
||||
|
||||
[ -f path.sh ] && . ./path.sh
|
||||
|
||||
sym=$dir/../lang/words.txt
|
||||
cat > $dir/address_slot.txt <<EOF
|
||||
0 1 南山 南山
|
||||
0 1 南京 南京
|
||||
0 1 光明 光明
|
||||
0 1 龙岗 龙岗
|
||||
0 1 北苑 北苑
|
||||
0 1 北京 北京
|
||||
0 1 酒店 酒店
|
||||
0 1 合肥 合肥
|
||||
0 1 望京搜后 望京搜后
|
||||
0 1 地铁站 地铁站
|
||||
0 1 海淀黄庄 海淀黄庄
|
||||
0 1 佛山 佛山
|
||||
0 1 广州 广州
|
||||
0 1 苏州 苏州
|
||||
0 1 百度大厦 百度大厦
|
||||
0 1 龙泽苑东区 龙泽苑东区
|
||||
0 1 首都机场 首都机场
|
||||
0 1 朝来家园 朝来家园
|
||||
0 1 深大 深大
|
||||
0 1 双龙 双龙
|
||||
0 1 公司 公司
|
||||
0 1 上海 上海
|
||||
0 1 家 家
|
||||
0 1 机场 机场
|
||||
0 1 华祝 华祝
|
||||
0 1 上海虹桥 上海虹桥
|
||||
0 2 检验 检验
|
||||
2 1 中心 中心
|
||||
0 3 苏州 苏州
|
||||
3 1 街 街
|
||||
3 8 高铁 高铁
|
||||
8 1 站 站
|
||||
0 4 杭州 杭州
|
||||
4 1 东站 东站
|
||||
4 1 <eps> <eps>
|
||||
0 5 上海 上海
|
||||
0 5 北京 北京
|
||||
0 5 合肥 合肥
|
||||
5 1 南站 南站
|
||||
0 6 立水 立水
|
||||
6 1 桥 桥
|
||||
0 7 青岛 青岛
|
||||
7 1 站 站
|
||||
1
|
||||
EOF
|
||||
|
||||
fstcompile --isymbols=$sym --osymbols=$sym $dir/address_slot.txt $dir/address_slot.fst
|
||||
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/time_slot.txt $dir/time_slot.fst
|
||||
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/date_slot.txt $dir/date_slot.fst
|
||||
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/money_slot.txt $dir/money_slot.fst
|
||||
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/year_slot.txt $dir/year_slot.fst
|
@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
lm=$1
|
||||
lang=$2
|
||||
tgt_lang=$3
|
||||
|
||||
unset GREP_OPTIONS
|
||||
|
||||
sym=$lang/words.txt
|
||||
arpa_lm=$lm/lm.arpa
|
||||
# Compose the language model to FST
|
||||
cat $arpa_lm | \
|
||||
grep -v '<s> <s>' | \
|
||||
grep -v '</s> <s>' | \
|
||||
grep -v '</s> </s>' | \
|
||||
grep -v -i '<unk>' | \
|
||||
grep -v -i '<spoken_noise>' | \
|
||||
arpa2fst --read-symbol-table=$sym --keep-symbols=true - | fstprint | \
|
||||
utils/fst/eps2disambig.pl | utils/fst/s2eps.pl | fstcompile --isymbols=$sym \
|
||||
--osymbols=$sym --keep_isymbols=false --keep_osymbols=false | \
|
||||
fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G_with_slot.fst
|
||||
|
||||
# Look up the ROOT symbol id by exact match on the first field, so that words
# which merely contain the substring "ROOT" are not picked up as well.
root_label=$(awk '$1 == "ROOT" {print $2}' "$sym")
|
||||
address_slot_label=`grep \<ADDRESS_SLOT\> $sym | awk '{print $2}'`
|
||||
time_slot_label=`grep \<TIME_SLOT\> $sym | awk '{print $2}'`
|
||||
date_slot_label=`grep \<DATE_SLOT\> $sym | awk '{print $2}'`
|
||||
money_slot_label=`grep \<MONEY_SLOT\> $sym | awk '{print $2}'`
|
||||
year_slot_label=`grep \<YEAR_SLOT\> $sym | awk '{print $2}'`
|
||||
|
||||
fstisstochastic $tgt_lang/G_with_slot.fst
|
||||
|
||||
fstreplace --epsilon_on_replace $tgt_lang/G_with_slot.fst \
|
||||
$root_label $tgt_lang/address_slot.fst $address_slot_label \
|
||||
$tgt_lang/date_slot.fst $date_slot_label \
|
||||
$tgt_lang/money_slot.fst $money_slot_label \
|
||||
$tgt_lang/time_slot.fst $time_slot_label \
|
||||
$tgt_lang/year_slot.fst $year_slot_label $tgt_lang/G.fst
|
||||
|
||||
fstisstochastic $tgt_lang/G.fst
|
||||
|
||||
# Compose the token, lexicon and language-model FST into the final decoding graph
|
||||
fsttablecompose $lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \
|
||||
fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1;
|
||||
fsttablecompose $lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1;
|
||||
rm $tgt_lang/LG.fst
|
||||
|
||||
echo "Composing decoding graph TLG.fst succeeded"
|
@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
# To be run from one directory above this script.
|
||||
. ./path.sh
|
||||
src=ds2_graph_with_slot
|
||||
text=$src/train_text
|
||||
lexicon=$src/local/dict/lexicon.txt
|
||||
|
||||
dir=$src/local/lm
|
||||
mkdir -p $dir
|
||||
|
||||
for f in "$text" "$lexicon"; do
|
||||
# Abort early when a required input file is missing (checks the loop variable f).
[ ! -f "$f" ] && echo "$0: No such file $f" && exit 1;
|
||||
done
|
||||
|
||||
# Check SRILM tools
|
||||
if ! which ngram-count > /dev/null; then
|
||||
pushd $MAIN_ROOT/tools
|
||||
make srilm.done
|
||||
popd
|
||||
fi
|
||||
|
||||
# This script takes no arguments. It assumes the training text and lexicon have already been prepared.
|
||||
# It takes as input the files
|
||||
# $src/train_text
|
||||
# $src/local/dict/lexicon.txt
|
||||
|
||||
|
||||
cleantext=$dir/text.no_oov
|
||||
|
||||
cat $text | awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
|
||||
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
|
||||
> $cleantext || exit 1;
|
||||
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \
|
||||
sort -nr > $dir/word.counts || exit 1;
|
||||
# Get counts from acoustic training transcripts, and add one-count
|
||||
# for each word in the lexicon (but not silence, we don't want it
|
||||
# in the LM-- we'll add it optionally later).
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
|
||||
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
|
||||
sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1;
|
||||
|
||||
# filter the words which are not in the text
|
||||
cat $dir/unigram.counts | awk '$1>1{print $0}' | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
|
||||
|
||||
# kaldi_lm results
|
||||
mkdir -p $dir
|
||||
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' > $dir/train
|
||||
|
||||
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
|
||||
-map-unk "<UNK>" -gt3max 0 -gt2max 0 -gt1max 0 -lm $dir/lm.arpa
|
||||
|
||||
#ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
|
||||
# -map-unk "<UNK>" -lm $dir/lm2.arpa
|
@ -0,0 +1,17 @@
|
||||
# This contains the locations of the built binaries required for running the examples.
|
||||
|
||||
MAIN_ROOT=`realpath $PWD/../../../`
|
||||
SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx`
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
|
||||
# Use the C locale for every category so that tools such as sort order bytes consistently.
export LC_ALL=C
|
||||
|
||||
# srilm
|
||||
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
|
||||
export SRILM=${MAIN_ROOT}/tools/srilm
|
||||
|
||||
# kaldi lm
|
||||
KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
|
||||
OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src
|
||||
export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin:$SPEECHX_EXAMPLES/ds2_ol/decoder
|
@ -0,0 +1,88 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
|
||||
# ds2 means deepspeech2 (acoustic model type)
|
||||
dir=$PWD/ds2_graph_with_slot
|
||||
data=$PWD/data
|
||||
stage=0
|
||||
stop_stage=10
|
||||
|
||||
mkdir -p $dir
|
||||
|
||||
model_dir=$PWD/resource/model
|
||||
vocab=$model_dir/vocab.txt
|
||||
cmvn=$data/cmvn.ark
|
||||
text_with_slot=$data/text_with_slot
|
||||
resource=$PWD/resource
|
||||
# download resource
|
||||
if [ ! -f $cmvn ]; then
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/resource.tar.gz
|
||||
tar xzfv resource.tar.gz
|
||||
ln -s ./resource/data .
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# make dict
|
||||
unit_file=$vocab
|
||||
mkdir -p $dir/local/dict
|
||||
cp $unit_file $dir/local/dict/units.txt
|
||||
cp $text_with_slot $dir/train_text
|
||||
utils/fst/prepare_dict.py --unit_file $unit_file --in_lexicon $data/lexicon.txt \
|
||||
--out_lexicon $dir/local/dict/lexicon.txt
|
||||
# add the slots to the lexicon, in case the LM training script filters them out.
|
||||
echo "<MONEY_SLOT> 一" >> $dir/local/dict/lexicon.txt
|
||||
echo "<DATE_SLOT> 一" >> $dir/local/dict/lexicon.txt
|
||||
echo "<ADDRESS_SLOT> 一" >> $dir/local/dict/lexicon.txt
|
||||
echo "<YEAR_SLOT> 一" >> $dir/local/dict/lexicon.txt
|
||||
echo "<TIME_SLOT> 一" >> $dir/local/dict/lexicon.txt
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# train lm
|
||||
lm=$dir/local/lm
|
||||
mkdir -p $lm
|
||||
# this script differs from the common LM training script
|
||||
local/train_lm_with_slot.sh
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# make T & L
|
||||
local/compile_lexicon_token_fst.sh $dir/local/dict $dir/local/tmp $dir/local/lang
|
||||
mkdir -p $dir/local/lang_test
|
||||
# make slot graph
|
||||
local/mk_slot_graph.sh $resource/graph $dir/local/lang_test
|
||||
# make TLG
|
||||
local/mk_tlg_with_slot.sh $dir/local/lm $dir/local/lang $dir/local/lang_test || exit 1;
|
||||
mv $dir/local/lang_test/TLG.fst $dir/local/lang/
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# test TLG
|
||||
model_dir=$PWD/resource/model
|
||||
cmvn=$data/cmvn.ark
|
||||
wav_scp=$data/wav.scp
|
||||
graph=$dir/local/lang
|
||||
|
||||
recognizer_test_main \
|
||||
--wav_rspecifier=scp:$wav_scp \
|
||||
--cmvn_file=$cmvn \
|
||||
--streaming_chunk=30 \
|
||||
--use_fbank=true \
|
||||
--model_path=$model_dir/avg_10.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_10.jit.pdiparams \
|
||||
--model_cache_shapes="5-1-2048,5-1-2048" \
|
||||
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
|
||||
--word_symbol_table=$graph/words.txt \
|
||||
--graph_path=$graph/TLG.fst --max_active=7500 \
|
||||
--acoustic_scale=12 \
|
||||
--result_wspecifier=ark,t:./result_run.txt
|
||||
|
||||
# the data/wav.trans is the label.
|
||||
utils/compute-wer.py --char=1 --v=1 data/wav.trans result_run.txt > wer_run
|
||||
tail -n 7 wer_run
|
||||
fi
|
@ -0,0 +1 @@
|
||||
../../../utils
|
@ -0,0 +1,15 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
set(BINS
|
||||
fstaddselfloops
|
||||
fstisstochastic
|
||||
fstminimizeencoded
|
||||
fstdeterminizestar
|
||||
fsttablecompose
|
||||
)
|
||||
|
||||
foreach(binary IN LISTS BINS)
|
||||
add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc)
|
||||
target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
# Executables have no consumers, so link PRIVATE; CMAKE_DL_LIBS is the portable name for libdl.
target_link_libraries(${binary} PRIVATE kaldi-fstext glog gflags fst ${CMAKE_DL_LIBS})
|
||||
endforeach()
|
@ -1,5 +1,5 @@
|
||||
|
||||
add_library(kaldi-fstext
|
||||
kaldi-fst-io.cc
|
||||
kaldi-fst-io.cc
|
||||
)
|
||||
target_link_libraries(kaldi-fstext PUBLIC kaldi-util)
|
||||
|
@ -0,0 +1,6 @@
|
||||
|
||||
add_library(kaldi-lm
|
||||
arpa-file-parser.cc
|
||||
arpa-lm-compiler.cc
|
||||
)
|
||||
target_link_libraries(kaldi-lm PUBLIC kaldi-util)
|
@ -0,0 +1,281 @@
|
||||
// lm/arpa-file-parser.cc
|
||||
|
||||
// Copyright 2014 Guoguo Chen
|
||||
// Copyright 2016 Smart Action Company LLC (kkm)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <fst/fstlib.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "base/kaldi-error.h"
|
||||
#include "base/kaldi-math.h"
|
||||
#include "lm/arpa-file-parser.h"
|
||||
#include "util/text-utils.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
ArpaFileParser::ArpaFileParser(const ArpaParseOptions& options,
|
||||
fst::SymbolTable* symbols)
|
||||
: options_(options), symbols_(symbols),
|
||||
line_number_(0), warning_count_(0) {
|
||||
}
|
||||
|
||||
ArpaFileParser::~ArpaFileParser() {
|
||||
}
|
||||
|
||||
void TrimTrailingWhitespace(std::string *str) {
|
||||
str->erase(str->find_last_not_of(" \n\r\t") + 1);
|
||||
}
|
||||
|
||||
void ArpaFileParser::Read(std::istream &is) {
|
||||
// Argument sanity checks.
|
||||
if (options_.bos_symbol <= 0 || options_.eos_symbol <= 0 ||
|
||||
options_.bos_symbol == options_.eos_symbol)
|
||||
KALDI_ERR << "BOS and EOS symbols are required, must not be epsilons, and "
|
||||
<< "differ from each other. Given:"
|
||||
<< " BOS=" << options_.bos_symbol
|
||||
<< " EOS=" << options_.eos_symbol;
|
||||
if (symbols_ != NULL &&
|
||||
options_.oov_handling == ArpaParseOptions::kReplaceWithUnk &&
|
||||
(options_.unk_symbol <= 0 ||
|
||||
options_.unk_symbol == options_.bos_symbol ||
|
||||
options_.unk_symbol == options_.eos_symbol))
|
||||
KALDI_ERR << "When symbol table is given and OOV mode is kReplaceWithUnk, "
|
||||
<< "UNK symbol is required, must not be epsilon, and "
|
||||
<< "differ from both BOS and EOS symbols. Given:"
|
||||
<< " UNK=" << options_.unk_symbol
|
||||
<< " BOS=" << options_.bos_symbol
|
||||
<< " EOS=" << options_.eos_symbol;
|
||||
if (symbols_ != NULL && symbols_->Find(options_.bos_symbol).empty())
|
||||
KALDI_ERR << "BOS symbol must exist in symbol table";
|
||||
if (symbols_ != NULL && symbols_->Find(options_.eos_symbol).empty())
|
||||
KALDI_ERR << "EOS symbol must exist in symbol table";
|
||||
if (symbols_ != NULL && options_.unk_symbol > 0 &&
|
||||
symbols_->Find(options_.unk_symbol).empty())
|
||||
KALDI_ERR << "UNK symbol must exist in symbol table";
|
||||
|
||||
ngram_counts_.clear();
|
||||
line_number_ = 0;
|
||||
warning_count_ = 0;
|
||||
current_line_.clear();
|
||||
|
||||
#define PARSE_ERR KALDI_ERR << LineReference() << ": "
|
||||
|
||||
// Give derived class an opportunity to prepare its state.
|
||||
ReadStarted();
|
||||
|
||||
// Processes "\data\" section.
|
||||
bool keyword_found = false;
|
||||
while (++line_number_, getline(is, current_line_) && !is.eof()) {
|
||||
if (current_line_.find_first_not_of(" \t\n\r") == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TrimTrailingWhitespace(&current_line_);
|
||||
|
||||
// Continue skipping lines until the \data\ marker alone on a line is found.
|
||||
if (!keyword_found) {
|
||||
if (current_line_ == "\\data\\") {
|
||||
KALDI_LOG << "Reading \\data\\ section.";
|
||||
keyword_found = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (current_line_[0] == '\\') break;
|
||||
|
||||
// Enters "\data\" section, and looks for patterns like "ngram 1=1000",
|
||||
// which means there are 1000 unigrams.
|
||||
std::size_t equal_symbol_pos = current_line_.find("=");
|
||||
if (equal_symbol_pos != std::string::npos)
|
||||
// Guaranteed spaces around the "=".
|
||||
current_line_.replace(equal_symbol_pos, 1, " = ");
|
||||
std::vector<std::string> col;
|
||||
SplitStringToVector(current_line_, " \t", true, &col);
|
||||
if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") {
|
||||
int32 order, ngram_count = 0;
|
||||
if (!ConvertStringToInteger(col[1], &order) ||
|
||||
!ConvertStringToInteger(col[3], &ngram_count)) {
|
||||
PARSE_ERR << "cannot parse ngram count";
|
||||
}
|
||||
if (ngram_counts_.size() <= order) {
|
||||
ngram_counts_.resize(order);
|
||||
}
|
||||
ngram_counts_[order - 1] = ngram_count;
|
||||
} else {
|
||||
KALDI_WARN << LineReference()
|
||||
<< ": uninterpretable line in \\data\\ section";
|
||||
}
|
||||
}
|
||||
|
||||
if (ngram_counts_.size() == 0)
|
||||
PARSE_ERR << "\\data\\ section missing or empty.";
|
||||
|
||||
// Signal that grammar order and n-gram counts are known.
|
||||
HeaderAvailable();
|
||||
|
||||
NGram ngram;
|
||||
ngram.words.reserve(ngram_counts_.size());
|
||||
|
||||
// Processes "\N-grams:" section.
|
||||
for (int32 cur_order = 1; cur_order <= ngram_counts_.size(); ++cur_order) {
|
||||
// Skips n-grams with zero count.
|
||||
if (ngram_counts_[cur_order - 1] == 0)
|
||||
KALDI_WARN << "Zero ngram count in ngram order " << cur_order
|
||||
<< "(look for 'ngram " << cur_order << "=0' in the \\data\\ "
|
||||
<< " section). There is possibly a problem with the file.";
|
||||
|
||||
// Must be looking at a \k-grams: directive at this point.
|
||||
std::ostringstream keyword;
|
||||
keyword << "\\" << cur_order << "-grams:";
|
||||
if (current_line_ != keyword.str()) {
|
||||
PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'";
|
||||
}
|
||||
KALDI_LOG << "Reading " << current_line_ << " section.";
|
||||
|
||||
int32 ngram_count = 0;
|
||||
while (++line_number_, getline(is, current_line_) && !is.eof()) {
|
||||
if (current_line_.find_first_not_of(" \n\t\r") == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (current_line_[0] == '\\') {
|
||||
TrimTrailingWhitespace(&current_line_);
|
||||
std::ostringstream next_keyword;
|
||||
next_keyword << "\\" << cur_order + 1 << "-grams:";
|
||||
if ((current_line_ != next_keyword.str()) &&
|
||||
(current_line_ != "\\end\\")) {
|
||||
if (ShouldWarn()) {
|
||||
KALDI_WARN << "ignoring possible directive '" << current_line_
|
||||
<< "' expecting '" << next_keyword.str() << "'";
|
||||
|
||||
if (warning_count_ > 0 &&
|
||||
warning_count_ > static_cast<uint32>(options_.max_warnings)) {
|
||||
KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
|
||||
<< options_.max_warnings << " were reported. "
|
||||
<< "Run program with --max-arpa-warnings=-1 "
|
||||
<< "to see all warnings";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> col;
|
||||
SplitStringToVector(current_line_, " \t", true, &col);
|
||||
|
||||
if (col.size() < 1 + cur_order ||
|
||||
col.size() > 2 + cur_order ||
|
||||
(cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) {
|
||||
PARSE_ERR << "Invalid n-gram data line";
|
||||
}
|
||||
++ngram_count;
|
||||
|
||||
// Parse out n-gram logprob and, if present, backoff weight.
|
||||
if (!ConvertStringToReal(col[0], &ngram.logprob)) {
|
||||
PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'";
|
||||
}
|
||||
ngram.backoff = 0.0;
|
||||
if (col.size() > cur_order + 1) {
|
||||
if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff))
|
||||
PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'";
|
||||
}
|
||||
// Convert to natural log.
|
||||
ngram.logprob *= M_LN10;
|
||||
ngram.backoff *= M_LN10;
|
||||
|
||||
ngram.words.resize(cur_order);
|
||||
bool skip_ngram = false;
|
||||
for (int32 index = 0; !skip_ngram && index < cur_order; ++index) {
|
||||
int32 word;
|
||||
if (symbols_) {
|
||||
// Symbol table provided, so symbol labels are expected.
|
||||
if (options_.oov_handling == ArpaParseOptions::kAddToSymbols) {
|
||||
word = symbols_->AddSymbol(col[1 + index]);
|
||||
} else {
|
||||
word = symbols_->Find(col[1 + index]);
|
||||
if (word == -1) { // fst::kNoSymbol
|
||||
switch (options_.oov_handling) {
|
||||
case ArpaParseOptions::kReplaceWithUnk:
|
||||
word = options_.unk_symbol;
|
||||
break;
|
||||
case ArpaParseOptions::kSkipNGram:
|
||||
if (ShouldWarn())
|
||||
KALDI_WARN << LineReference() << " skipped: word '"
|
||||
<< col[1 + index] << "' not in symbol table";
|
||||
skip_ngram = true;
|
||||
break;
|
||||
default:
|
||||
PARSE_ERR << "word '" << col[1 + index]
|
||||
<< "' not in symbol table";
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Symbols not provided, LM file should contain integers.
|
||||
if (!ConvertStringToInteger(col[1 + index], &word) || word < 0) {
|
||||
PARSE_ERR << "invalid symbol '" << col[1 + index] << "'";
|
||||
}
|
||||
}
|
||||
// Whichever way we got it, an epsilon is invalid.
|
||||
if (word == 0) {
|
||||
PARSE_ERR << "epsilon symbol '" << col[1 + index]
|
||||
<< "' is illegal in ARPA LM";
|
||||
}
|
||||
ngram.words[index] = word;
|
||||
}
|
||||
if (!skip_ngram) {
|
||||
ConsumeNGram(ngram);
|
||||
}
|
||||
}
|
||||
if (ngram_count > ngram_counts_[cur_order - 1]) {
|
||||
PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1]
|
||||
<< " n-grams of order " << cur_order
|
||||
<< ", but we saw more already.";
|
||||
}
|
||||
}
|
||||
|
||||
if (current_line_ != "\\end\\") {
|
||||
PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\";
|
||||
}
|
||||
|
||||
if (warning_count_ > 0 &&
|
||||
warning_count_ > static_cast<uint32>(options_.max_warnings)) {
|
||||
KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
|
||||
<< options_.max_warnings << " were reported. Run program with "
|
||||
<< "--max-arpa-warnings=-1 to see all warnings";
|
||||
}
|
||||
|
||||
current_line_.clear();
|
||||
ReadComplete();
|
||||
|
||||
#undef PARSE_ERR
|
||||
}
|
||||
|
||||
std::string ArpaFileParser::LineReference() const {
|
||||
std::ostringstream ss;
|
||||
ss << "line " << line_number_ << " [" << current_line_ << "]";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool ArpaFileParser::ShouldWarn() {
|
||||
return (warning_count_ != -1) &&
|
||||
(++warning_count_ <= static_cast<uint32>(options_.max_warnings));
|
||||
}
|
||||
|
||||
} // namespace kaldi
|
@ -0,0 +1,146 @@
|
||||
// lm/arpa-file-parser.h
|
||||
|
||||
// Copyright 2014 Guoguo Chen
|
||||
// Copyright 2016 Smart Action Company LLC (kkm)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_LM_ARPA_FILE_PARSER_H_
|
||||
#define KALDI_LM_ARPA_FILE_PARSER_H_
|
||||
|
||||
#include <fst/fst-decl.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "base/kaldi-types.h"
|
||||
#include "util/options-itf.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
/**
|
||||
Options that control ArpaFileParser
|
||||
*/
|
||||
struct ArpaParseOptions {
|
||||
enum OovHandling {
|
||||
kRaiseError, ///< Abort on OOV words
|
||||
kAddToSymbols, ///< Add novel words to the symbol table.
|
||||
kReplaceWithUnk, ///< Replace OOV words with <unk>.
|
||||
kSkipNGram ///< Skip n-gram with OOV word and continue.
|
||||
};
|
||||
|
||||
ArpaParseOptions():
|
||||
bos_symbol(-1), eos_symbol(-1), unk_symbol(-1),
|
||||
oov_handling(kRaiseError), max_warnings(30) { }
|
||||
|
||||
void Register(OptionsItf *opts) {
|
||||
// Registering only the max_warnings count, since other options are
|
||||
// treated differently by client programs: some want integer symbols,
|
||||
// while other are passed words in their command line.
|
||||
opts->Register("max-arpa-warnings", &max_warnings,
|
||||
"Maximum warnings to report on ARPA parsing, "
|
||||
"0 to disable, -1 to show all");
|
||||
}
|
||||
|
||||
int32 bos_symbol; ///< Symbol for <s>, Required non-epsilon.
|
||||
int32 eos_symbol; ///< Symbol for </s>, Required non-epsilon.
|
||||
int32 unk_symbol; ///< Symbol for <unk>, Required for kReplaceWithUnk.
|
||||
OovHandling oov_handling; ///< How to handle OOV words in the file.
|
||||
int32 max_warnings; ///< Maximum warnings to report, <0 unlimited.
|
||||
};
|
||||
|
||||
/**
|
||||
A parsed n-gram from ARPA LM file.
|
||||
*/
|
||||
struct NGram {
|
||||
NGram() : logprob(0.0), backoff(0.0) { }
|
||||
std::vector<int32> words; ///< Symbols in left to right order.
|
||||
float logprob; ///< Log-prob of the n-gram.
|
||||
float backoff; ///< log-backoff weight of the n-gram.
|
||||
///< Defaults to zero if not specified.
|
||||
};
|
||||
|
||||
/**
|
||||
ArpaFileParser is an abstract base class for ARPA LM file conversion.
|
||||
|
||||
See ConstArpaLmBuilder and ArpaLmCompiler for usage examples.
|
||||
*/
|
||||
class ArpaFileParser {
|
||||
public:
|
||||
/// Constructs the parser with the given options and optional symbol table.
|
||||
/// If symbol table is provided, then the file should contain text n-grams,
|
||||
/// and the words are mapped to symbols through it. bos_symbol and
|
||||
/// eos_symbol in the options structure must be valid symbols in the table,
|
||||
/// and so must be unk_symbol if provided. The table is not owned by the
|
||||
/// parser, but may be augmented, if oov_handling is set to kAddToSymbols.
|
||||
/// If symbol table is a null pointer, the file should contain integer
|
||||
/// symbol values, and oov_handling has no effect. bos_symbol and eos_symbol
|
||||
/// must be valid symbols still.
|
||||
ArpaFileParser(const ArpaParseOptions& options, fst::SymbolTable* symbols);
|
||||
virtual ~ArpaFileParser();
|
||||
|
||||
/// Read ARPA LM file from a stream.
|
||||
void Read(std::istream &is);
|
||||
|
||||
/// Parser options.
|
||||
const ArpaParseOptions& Options() const { return options_; }
|
||||
|
||||
protected:
|
||||
/// Override called before reading starts. This is the point to prepare
|
||||
/// any state in the derived class.
|
||||
virtual void ReadStarted() { }
|
||||
|
||||
/// Override function called to signal that ARPA header with the expected
|
||||
/// number of n-grams has been read, and ngram_counts() is now valid.
|
||||
virtual void HeaderAvailable() { }
|
||||
|
||||
/// Pure override that must be implemented to process current n-gram. The
|
||||
/// n-grams are sent in the file order, which guarantees that all
|
||||
/// (k-1)-grams are processed before the first k-gram is.
|
||||
virtual void ConsumeNGram(const NGram&) = 0;
|
||||
|
||||
/// Override function called after the last n-gram has been consumed.
|
||||
virtual void ReadComplete() { }
|
||||
|
||||
/// Read-only access to symbol table. Not owned, do not make public.
|
||||
const fst::SymbolTable* Symbols() const { return symbols_; }
|
||||
|
||||
/// Inside ConsumeNGram(), provides the current line number.
|
||||
int32 LineNumber() const { return line_number_; }
|
||||
|
||||
/// Inside ConsumeNGram(), returns a formatted reference to the line being
|
||||
/// compiled, to print out as part of diagnostics.
|
||||
std::string LineReference() const;
|
||||
|
||||
/// Increments warning count, and returns true if a warning should be
|
||||
/// printed or false if the count has exceeded the set maximum.
|
||||
bool ShouldWarn();
|
||||
|
||||
/// N-gram counts. Valid from the point when HeaderAvailable() is called.
|
||||
const std::vector<int32>& NgramCounts() const { return ngram_counts_; }
|
||||
|
||||
private:
|
||||
ArpaParseOptions options_;
|
||||
fst::SymbolTable* symbols_; // the pointer is not owned here.
|
||||
int32 line_number_;
|
||||
uint32 warning_count_;
|
||||
std::string current_line_;
|
||||
std::vector<int32> ngram_counts_;
|
||||
};
|
||||
|
||||
} // namespace kaldi
|
||||
|
||||
#endif // KALDI_LM_ARPA_FILE_PARSER_H_
|
@ -0,0 +1,377 @@
|
||||
// lm/arpa-lm-compiler.cc
|
||||
|
||||
// Copyright 2009-2011 Gilles Boulianne
|
||||
// Copyright 2016 Smart Action LLC (kkm)
|
||||
// Copyright 2017 Xiaohui Zhang
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "base/kaldi-math.h"
|
||||
#include "lm/arpa-lm-compiler.h"
|
||||
#include "util/stl-utils.h"
|
||||
#include "util/text-utils.h"
|
||||
#include "fstext/remove-eps-local.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
class ArpaLmCompilerImplInterface {
|
||||
public:
|
||||
virtual ~ArpaLmCompilerImplInterface() { }
|
||||
virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
typedef int32 StateId;
|
||||
typedef int32 Symbol;
|
||||
|
||||
// GeneralHistKey can represent state history in an arbitrarily large n
|
||||
// n-gram model with symbol ids fitting int32.
|
||||
class GeneralHistKey {
|
||||
public:
|
||||
// Construct key from being and end iterators.
|
||||
template<class InputIt>
|
||||
GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) { }
|
||||
// Construct empty history key.
|
||||
GeneralHistKey() : vector_() { }
|
||||
// Return tails of the key as a GeneralHistKey. The tails of an n-gram
|
||||
// w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
|
||||
// key class does not need this operartion).
|
||||
GeneralHistKey Tails() const {
|
||||
return GeneralHistKey(vector_.begin() + 1, vector_.end());
|
||||
}
|
||||
// Keys are equal if represent same state.
|
||||
friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
|
||||
return a.vector_ == b.vector_;
|
||||
}
|
||||
// Public typename HashType for hashing.
|
||||
struct HashType : public std::unary_function<GeneralHistKey, size_t> {
|
||||
size_t operator()(const GeneralHistKey& key) const {
|
||||
return VectorHasher<Symbol>().operator()(key.vector_);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
std::vector<Symbol> vector_;
|
||||
};
|
||||
|
||||
// OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
|
||||
// machine word. allowing significant memory reduction and some runtime
|
||||
// benefit over GeneralHistKey. Since 3 symbols are enough to track history
|
||||
// in a 4-gram model, this optimized key is used for smaller models with up
|
||||
// to 4-gram and symbol values up to 2^21-1.
|
||||
//
|
||||
// See GeneralHistKey for interface requirements of a key class.
|
||||
class OptimizedHistKey {
|
||||
public:
|
||||
enum {
|
||||
kShift = 21, // 21 * 3 = 63 bits for data.
|
||||
kMaxData = (1 << kShift) - 1
|
||||
};
|
||||
template<class InputIt>
|
||||
OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
|
||||
for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
|
||||
data_ |= static_cast<uint64>(*begin) << shift;
|
||||
}
|
||||
}
|
||||
OptimizedHistKey() : data_(0) { }
|
||||
OptimizedHistKey Tails() const {
|
||||
return OptimizedHistKey(data_ >> kShift);
|
||||
}
|
||||
friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
|
||||
return a.data_ == b.data_;
|
||||
}
|
||||
struct HashType : public std::unary_function<OptimizedHistKey, size_t> {
|
||||
size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
|
||||
};
|
||||
|
||||
private:
|
||||
explicit OptimizedHistKey(uint64 data) : data_(data) { }
|
||||
uint64 data_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <class HistKey>
|
||||
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
|
||||
public:
|
||||
ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
|
||||
Symbol sub_eps);
|
||||
|
||||
virtual void ConsumeNGram(const NGram &ngram, bool is_highest);
|
||||
|
||||
private:
|
||||
StateId AddStateWithBackoff(HistKey key, float backoff);
|
||||
void CreateBackoff(HistKey key, StateId state, float weight);
|
||||
|
||||
ArpaLmCompiler *parent_; // Not owned.
|
||||
fst::StdVectorFst* fst_; // Not owned.
|
||||
Symbol bos_symbol_;
|
||||
Symbol eos_symbol_;
|
||||
Symbol sub_eps_;
|
||||
|
||||
StateId eos_state_;
|
||||
typedef unordered_map<HistKey, StateId,
|
||||
typename HistKey::HashType> HistoryMap;
|
||||
HistoryMap history_;
|
||||
};
|
||||
|
||||
template <class HistKey>
|
||||
ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
|
||||
ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
|
||||
: parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
|
||||
eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) {
|
||||
// The algorithm maintains state per history. The 0-gram is a special state
|
||||
// for empty history. All unigrams (including BOS) backoff into this state.
|
||||
StateId zerogram = fst_->AddState();
|
||||
history_[HistKey()] = zerogram;
|
||||
|
||||
// Also, if </s> is not treated as epsilon, create a common end state for
|
||||
// all transitions accepting the </s>, since they do not back off. This small
|
||||
// optimization saves about 2% states in an average grammar.
|
||||
if (sub_eps_ == 0) {
|
||||
eos_state_ = fst_->AddState();
|
||||
fst_->SetFinal(eos_state_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <class HistKey>
|
||||
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
|
||||
bool is_highest) {
|
||||
// Generally, we do the following. Suppose we are adding an n-gram "A B
|
||||
// C". Then find the node for "A B", add a new node for "A B C", and connect
|
||||
// them with the arc accepting "C" with the specified weight. Also, add a
|
||||
// backoff arc from the new "A B C" node to its backoff state "B C".
|
||||
//
|
||||
// Two notable exceptions are the highest order n-grams, and final n-grams.
|
||||
//
|
||||
// When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
|
||||
// the following optimization is performed. There is no point adding a node
|
||||
// for "A B C" with a "C" arc from "A B", since there will be no other
|
||||
// arcs ingoing to this node, and an epsilon backoff arc into the backoff
|
||||
// model "B C", with the weight of \bar{1}. To save a node, create an arc
|
||||
// accepting "C" directly from "A B" to "B C". This saves as many nodes
|
||||
// as there are the highest order n-grams, which is typically about half
|
||||
// the size of a large 3-gram model.
|
||||
//
|
||||
// Indeed, this does not apply to n-grams ending in EOS, since they do not
|
||||
// back off. These are special, as they do not have a back-off state, and
|
||||
// the node for "(..anything..) </s>" is always final. These are handled
|
||||
// in one of the two possible ways, If symbols <s> and </s> are being
|
||||
// replaced by epsilons, neither node nor arc is created, and the logprob
|
||||
// of the n-gram is applied to its source node as final weight. If <s> and
|
||||
// </s> are preserved, then a special final node for </s> is allocated and
|
||||
// used as the destination of the "</s>" acceptor arc.
|
||||
HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
|
||||
typename HistoryMap::iterator source_it = history_.find(heads);
|
||||
if (source_it == history_.end()) {
|
||||
// There was no "A B", therefore the probability of "A B C" is zero.
|
||||
// Print a warning and discard current n-gram.
|
||||
if (parent_->ShouldWarn())
|
||||
KALDI_WARN << parent_->LineReference()
|
||||
<< " skipped: no parent (n-1)-gram exists";
|
||||
return;
|
||||
}
|
||||
|
||||
StateId source = source_it->second;
|
||||
StateId dest;
|
||||
Symbol sym = ngram.words.back();
|
||||
float weight = -ngram.logprob;
|
||||
if (sym == sub_eps_ || sym == 0) {
|
||||
KALDI_ERR << " <eps> or disambiguation symbol " << sym << "found in the ARPA file. ";
|
||||
}
|
||||
if (sym == eos_symbol_) {
|
||||
if (sub_eps_ == 0) {
|
||||
// Keep </s> as a real symbol when not substituting.
|
||||
dest = eos_state_;
|
||||
} else {
|
||||
// Treat </s> as if it was epsilon: mark source final, with the weight
|
||||
// of the n-gram.
|
||||
fst_->SetFinal(source, weight);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// For the highest order n-gram, this may find an existing state, for
|
||||
// non-highest, will create one (unless there are duplicate n-grams
|
||||
// in the grammar, which cannot be reliably detected if highest order,
|
||||
// so we better do not do that at all).
|
||||
dest = AddStateWithBackoff(
|
||||
HistKey(ngram.words.begin() + (is_highest ? 1 : 0),
|
||||
ngram.words.end()),
|
||||
-ngram.backoff);
|
||||
}
|
||||
|
||||
if (sym == bos_symbol_) {
|
||||
weight = 0; // Accepting <s> is always free.
|
||||
if (sub_eps_ == 0) {
|
||||
// <s> is as a real symbol, only accepted in the start state.
|
||||
source = fst_->AddState();
|
||||
fst_->SetStart(source);
|
||||
} else {
|
||||
// The new state for <s> unigram history *is* the start state.
|
||||
fst_->SetStart(dest);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Add arc from source to dest, whichever way it was found.
|
||||
fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
|
||||
return;
|
||||
}
|
||||
|
||||
// Find or create a new state for n-gram defined by key, and ensure it has a
|
||||
// backoff transition. The key is either the current n-gram for all but
|
||||
// highest orders, or the tails of the n-gram for the highest order. The
|
||||
// latter arises from the chain-collapsing optimization described above.
|
||||
template <class HistKey>
|
||||
StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
|
||||
float backoff) {
|
||||
typename HistoryMap::iterator dest_it = history_.find(key);
|
||||
if (dest_it != history_.end()) {
|
||||
// Found an existing state in the history map. Invariant: if the state in
|
||||
// the map, then its backoff arc is in the FST. We are done.
|
||||
return dest_it->second;
|
||||
}
|
||||
// Otherwise create a new state and its backoff arc, and register in the map.
|
||||
StateId dest = fst_->AddState();
|
||||
history_[key] = dest;
|
||||
CreateBackoff(key.Tails(), dest, backoff);
|
||||
return dest;
|
||||
}
|
||||
|
||||
// Create a backoff arc for a state. Key is a backoff destination that may or
|
||||
// may not exist. When the destination is not found, naturally fall back to
|
||||
// the lower order model, and all the way down until one is found (since the
|
||||
// 0-gram model is always present, the search is guaranteed to terminate).
|
||||
template <class HistKey>
|
||||
inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(
|
||||
HistKey key, StateId state, float weight) {
|
||||
typename HistoryMap::iterator dest_it = history_.find(key);
|
||||
while (dest_it == history_.end()) {
|
||||
key = key.Tails();
|
||||
dest_it = history_.find(key);
|
||||
}
|
||||
|
||||
// The arc should transduce either <eos> or #0 to <eps>, depending on the
|
||||
// epsilon substitution mode. This is the only case when input and output
|
||||
// label may differ.
|
||||
fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
|
||||
}
|
||||
|
||||
// Releases the owned implementation object. Deleting a null pointer is a
// no-op, so no explicit guard is needed.
ArpaLmCompiler::~ArpaLmCompiler() {
  delete impl_;
}
|
||||
|
||||
void ArpaLmCompiler::HeaderAvailable() {
|
||||
KALDI_ASSERT(impl_ == NULL);
|
||||
// Use optimized implementation if the grammar is 4-gram or less, and the
|
||||
// maximum attained symbol id will fit into the optimized range.
|
||||
int64 max_symbol = 0;
|
||||
if (Symbols() != NULL)
|
||||
max_symbol = Symbols()->AvailableKey() - 1;
|
||||
// If augmenting the symbol table, assume the worst case when all words in
|
||||
// the model being read are novel.
|
||||
if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
|
||||
max_symbol += NgramCounts()[0];
|
||||
|
||||
if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
|
||||
impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
|
||||
} else {
|
||||
impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
|
||||
KALDI_LOG << "Reverting to slower state tracking because model is large: "
|
||||
<< NgramCounts().size() << "-gram with symbols up to "
|
||||
<< max_symbol;
|
||||
}
|
||||
}
|
||||
|
||||
void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
|
||||
// <s> is invalid in tails, </s> in heads of an n-gram.
|
||||
for (int i = 0; i < ngram.words.size(); ++i) {
|
||||
if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
|
||||
(i + 1 < ngram.words.size()
|
||||
&& ngram.words[i] == Options().eos_symbol)) {
|
||||
if (ShouldWarn())
|
||||
KALDI_WARN << LineReference()
|
||||
<< " skipped: n-gram has invalid BOS/EOS placement";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_highest = ngram.words.size() == NgramCounts().size();
|
||||
impl_->ConsumeNGram(ngram, is_highest);
|
||||
}
|
||||
|
||||
void ArpaLmCompiler::RemoveRedundantStates() {
|
||||
fst::StdArc::Label backoff_symbol = sub_eps_;
|
||||
if (backoff_symbol == 0) {
|
||||
// The method of removing redundant states implemented in this function
|
||||
// leads to slow determinization of L o G when people use the older style of
|
||||
// usage of arpa2fst where the --disambig-symbol option was not specified.
|
||||
// The issue seems to be that it creates a non-deterministic FST, while G is
|
||||
// supposed to be deterministic. By 'return'ing below, we just disable this
|
||||
// method if people were using an older script. This method isn't really
|
||||
// that consequential anyway, and people will move to the newer-style
|
||||
// scripts (see current utils/format_lm.sh), so this isn't much of a
|
||||
// problem.
|
||||
return;
|
||||
}
|
||||
|
||||
fst::StdArc::StateId num_states = fst_.NumStates();
|
||||
|
||||
|
||||
// replace the #0 symbols on the input of arcs out of redundant states (states
|
||||
// that are not final and have only a backoff arc leaving them), with <eps>.
|
||||
for (fst::StdArc::StateId state = 0; state < num_states; state++) {
|
||||
if (fst_.NumArcs(state) == 1 && fst_.Final(state) == fst::TropicalWeight::Zero()) {
|
||||
fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
|
||||
fst::StdArc arc = iter.Value();
|
||||
if (arc.ilabel == backoff_symbol) {
|
||||
arc.ilabel = 0;
|
||||
iter.SetValue(arc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// we could call fst::RemoveEps, and it would have the same effect in normal
|
||||
// cases, where backoff_symbol != 0 and there are no epsilons in unexpected
|
||||
// places, but RemoveEpsLocal is a bit safer in case something weird is going
|
||||
// on; it guarantees not to blow up the FST.
|
||||
fst::RemoveEpsLocal(&fst_);
|
||||
KALDI_LOG << "Reduced num-states from " << num_states << " to "
|
||||
<< fst_.NumStates();
|
||||
}
|
||||
|
||||
void ArpaLmCompiler::Check() const {
|
||||
if (fst_.Start() == fst::kNoStateId) {
|
||||
KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol "
|
||||
<< Symbols()->Find(Options().bos_symbol) << ".";
|
||||
}
|
||||
}
|
||||
|
||||
void ArpaLmCompiler::ReadComplete() {
|
||||
fst_.SetInputSymbols(Symbols());
|
||||
fst_.SetOutputSymbols(Symbols());
|
||||
RemoveRedundantStates();
|
||||
Check();
|
||||
}
|
||||
|
||||
} // namespace kaldi
|
@ -0,0 +1,65 @@
|
||||
// lm/arpa-lm-compiler.h
|
||||
|
||||
// Copyright 2009-2011 Gilles Boulianne
|
||||
// Copyright 2016 Smart Action LLC (kkm)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_LM_ARPA_LM_COMPILER_H_
|
||||
#define KALDI_LM_ARPA_LM_COMPILER_H_
|
||||
|
||||
#include <fst/fstlib.h>
|
||||
|
||||
#include "lm/arpa-file-parser.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
class ArpaLmCompilerImplInterface;
|
||||
|
||||
class ArpaLmCompiler : public ArpaFileParser {
|
||||
public:
|
||||
ArpaLmCompiler(const ArpaParseOptions& options, int sub_eps,
|
||||
fst::SymbolTable* symbols)
|
||||
: ArpaFileParser(options, symbols),
|
||||
sub_eps_(sub_eps), impl_(NULL) {
|
||||
}
|
||||
~ArpaLmCompiler();
|
||||
|
||||
const fst::StdVectorFst& Fst() const { return fst_; }
|
||||
fst::StdVectorFst* MutableFst() { return &fst_; }
|
||||
|
||||
protected:
|
||||
// ArpaFileParser overrides.
|
||||
virtual void HeaderAvailable();
|
||||
virtual void ConsumeNGram(const NGram& ngram);
|
||||
virtual void ReadComplete();
|
||||
|
||||
|
||||
private:
|
||||
// this function removes states that only have a backoff arc coming
|
||||
// out of them.
|
||||
void RemoveRedundantStates();
|
||||
void Check() const;
|
||||
|
||||
int sub_eps_;
|
||||
ArpaLmCompilerImplInterface* impl_; // Owned.
|
||||
fst::StdVectorFst fst_;
|
||||
template <class HistKey> friend class ArpaLmCompilerImpl;
|
||||
};
|
||||
|
||||
} // namespace kaldi
|
||||
|
||||
#endif // KALDI_LM_ARPA_LM_COMPILER_H_
|
@ -1,5 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

# arpa2fst executable, built from arpa2fst.cc in this directory.
add_executable(arpa2fst ${CMAKE_CURRENT_SOURCE_DIR}/arpa2fst.cc)
target_include_directories(arpa2fst PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# A single keyword-form link call; the former item-less
# target_link_libraries(arpa2fst ) call linked nothing and has been dropped.
target_link_libraries(arpa2fst PUBLIC kaldi-lm glog gflags fst)
|
Loading…
Reference in new issue