add custom asr script

pull/1946/head
Yang Zhou 3 years ago
parent 2fbb7ec569
commit daadec0c63

@ -7,6 +7,8 @@ In some cases, we need to recognize specific rare words with high accuracy.
This demo is customized for the expense-report scenario, which requires recognizing rare addresses.
+The scripts are in PaddleSpeech/speechx/examples/custom_asr.
* G with slot: 打车到 "address_slot"。 (take a taxi to "address_slot")
![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4)

@ -6,6 +6,8 @@
This demo recognizes the taxi expense-report scenario, which requires recognizing some rare place names; this can be achieved with the steps below.
+Related scripts: PaddleSpeech/speechx/examples/custom_asr
* G with slot: 打车到 "address_slot"。 (take a taxi to "address_slot")
![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4)

@ -57,7 +57,7 @@ include(gtest)
include(absl)
# libsndfile
-include(libsndfile)
+#include(libsndfile)
# boost
# include(boost) # not work
@ -73,9 +73,17 @@ find_package(Eigen3 REQUIRED)
# Kenlm
include(kenlm)
add_dependencies(kenlm eigen boost)
+#set(kenlm_install_dir $(fc_patch)/kenlm-build)
+#link_directories(${Kenlm_install_dir}/lib)
+#include_directories(${fc_patch}/kenlm-src)
#openblas
-include(openblas)
+#include(openblas)
+set(OpenBLAS_INSTALL_PREFIX ${fc_patch}/openblas-install)
+link_directories(${OpenBLAS_INSTALL_PREFIX}/lib)
+include_directories(${OpenBLAS_INSTALL_PREFIX}/include)
# openfst
include(openfst)

@ -0,0 +1,32 @@
# Customized Automatic Speech Recognition
## Introduction
These scripts are a tutorial showing how to make your own decoding graph.
For example:
* G with slot: 打车到 "address_slot"。 (take a taxi to "address_slot")
![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4)
* This is the address-slot WFST; you can add the addresses you want to recognize.
![](https://ai-studio-static-online.cdn.bcebos.com/47c89100ef8c465bac733605ffc53d76abefba33d62f4d818d351f8cea3c8fe2)
* After the replace operation, G = fstreplace(G_with_slot, address_slot), we get the customized graph.
![](https://ai-studio-static-online.cdn.bcebos.com/60a3095293044f10b73039ab10c7950d139a6717580a44a3ba878c6e74de402b)
These operations are performed by the scripts; please check them out. We will release more detailed scripts later.
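A minimal sketch of the replace step, using the labels defined in local/mk_tlg_with_slot.sh (file paths are illustrative):
```
# Look up the ROOT non-terminal and the slot label in words.txt.
root_label=`grep ROOT words.txt | awk '{print $2}'`
address_slot_label=`grep \<ADDRESS_SLOT\> words.txt | awk '{print $2}'`
# Splice the address-slot FST into G wherever an <ADDRESS_SLOT> arc appears.
fstreplace --epsilon_on_replace G_with_slot.fst \
    $root_label address_slot.fst $address_slot_label G.fst
```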
## How to run
```
bash run.sh
```
## Results
### CTC WFST
```
Overall -> 1.23 % N=1134 C=1126 S=6 D=2 I=6
Mandarin -> 1.24 % N=1132 C=1124 S=6 D=2 I=6
English -> 0.00 % N=2 C=2 S=0 D=0 I=0
```
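In the results above (the output of utils/compute-wer.py), N is the number of reference characters, C the number recognized correctly, and S/D/I the substitution, deletion, and insertion counts; the leading percentage is the corresponding character error rate.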

@ -0,0 +1,89 @@
#!/bin/bash
# Copyright 2015 Yajie Miao (Carnegie Mellon University)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script compiles the lexicon and CTC tokens into FSTs. FST compiling slightly differs between the
# phoneme and character-based lexicons.
set -eo pipefail
. utils/parse_options.sh
if [ $# -ne 3 ]; then
echo "usage: utils/fst/compile_lexicon_token_fst.sh <dict-src-dir> <tmp-dir> <lang-dir>"
echo "e.g.: utils/fst/compile_lexicon_token_fst.sh data/local/dict data/local/lang_tmp data/lang"
echo "<dict-src-dir> should contain the following files:"
echo "lexicon.txt lexicon_numbers.txt units.txt"
echo "options: "
exit 1;
fi
srcdir=$1
tmpdir=$2
dir=$3
mkdir -p $dir $tmpdir
[ -f path.sh ] && . ./path.sh
cp $srcdir/units.txt $dir
# Add probabilities to lexicon entries. There is in fact no point in doing this here, since all the entries have probability 1.0.
# But utils/make_lexicon_fst.pl requires a probabilistic version, so we just leave it as it is.
perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $tmpdir/lexiconp.txt || exit 1;
# Add disambiguation symbols to the lexicon. This is necessary for determinizing the composition of L.fst and G.fst.
# Without these symbols, determinization will fail.
# default first disambiguation is #1
ndisambig=`utils/fst/add_lex_disambig.pl $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt`
# Add #0 (#0 is reserved as the backoff disambiguation symbol in the grammar G).
ndisambig=$[$ndisambig+1];
( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) > $tmpdir/disambig.list
# Get the full list of CTC tokens used in the FST. These tokens include <eps>, the blank <blk>,
# the actual model units, and the disambiguation symbols.
cat $srcdir/units.txt | awk '{print $1}' > $tmpdir/units.list
(echo '<eps>';) | cat - $tmpdir/units.list $tmpdir/disambig.list | awk '{print $1 " " (NR-1)}' > $dir/tokens.txt
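# tokens.txt now typically looks like the following (assuming the vocab begins
# with the blank unit; disambiguation symbols come last):
#   <eps> 0
#   <blk> 1
#   ...model units...
#   #0 ... #N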
# ctc_token_fst_corrected is too big and too slow for character-based Chinese modeling,
# so we just use the simple ctc_token_fst here.
utils/fst/ctc_token_fst.py --token_file $dir/tokens.txt | \
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/tokens.txt --keep_isymbols=false --keep_osymbols=false | \
fstarcsort --sort_type=olabel > $dir/T.fst || exit 1;
# Encode the words with indices. Will be used in lexicon and language model FST compiling.
cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | awk '
BEGIN {
print "<eps> 0";
}
{
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
printf("ROOT %d\n", NR+4);
}' > $dir/words.txt || exit 1;
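# The resulting words.txt assigns <eps>=0, then the sorted lexicon words, and
# finally #0, <s>, </s>, and ROOT (ROOT is used later as the fstreplace root label).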
# Now compile the lexicon FST. Depending on the size of your lexicon, it may take some time.
token_disambig_symbol=`grep \#0 $dir/tokens.txt | awk '{print $2}'`
word_disambig_symbol=`grep \#0 $dir/words.txt | awk '{print $2}'`
utils/fst/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt 0 "sil" '#'$ndisambig | \
fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/words.txt \
--keep_isymbols=false --keep_osymbols=false | \
fstaddselfloops "echo $token_disambig_symbol |" "echo $word_disambig_symbol |" | \
fstarcsort --sort_type=olabel > $dir/L.fst || exit 1;
echo "Lexicon and Token FSTs compiling succeeded"

@ -0,0 +1,74 @@
#!/bin/bash
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
graph_slot=$1
dir=$2
[ -f path.sh ] && . ./path.sh
sym=$dir/../lang/words.txt
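# The slot FSTs below are written in OpenFst text format: one arc per line as
# "src dst ilabel olabel [weight]"; a line with a single state id marks that
# state as final. State 0 is the start state.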
cat > $dir/address_slot.txt <<EOF
0 1 南山 南山
0 1 南京 南京
0 1 光明 光明
0 1 龙岗 龙岗
0 1 北苑 北苑
0 1 北京 北京
0 1 酒店 酒店
0 1 合肥 合肥
0 1 望京搜后 望京搜后
0 1 地铁站 地铁站
0 1 海淀黄庄 海淀黄庄
0 1 佛山 佛山
0 1 广州 广州
0 1 苏州 苏州
0 1 百度大厦 百度大厦
0 1 龙泽苑东区 龙泽苑东区
0 1 首都机场 首都机场
0 1 朝来家园 朝来家园
0 1 深大 深大
0 1 双龙 双龙
0 1 公司 公司
0 1 上海 上海
0 1 家 家
0 1 机场 机场
0 1 华祝 华祝
0 1 上海虹桥 上海虹桥
0 2 检验 检验
2 1 中心 中心
0 3 苏州 苏州
3 1 街 街
3 8 高铁 高铁
8 1 站 站
0 4 杭州 杭州
4 1 东站 东站
4 1 <eps> <eps>
0 5 上海 上海
0 5 北京 北京
0 5 合肥 合肥
5 1 南站 南站
0 6 立水 立水
6 1 桥 桥
0 7 青岛 青岛
7 1 站 站
1
EOF
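# Compile each text-format slot into a binary FST. All slots share the same
# words.txt symbol table so they can later be spliced into G by fstreplace.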
fstcompile --isymbols=$sym --osymbols=$sym $dir/address_slot.txt $dir/address_slot.fst
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/time_slot.txt $dir/time_slot.fst
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/date_slot.txt $dir/date_slot.fst
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/money_slot.txt $dir/money_slot.fst
fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/year_slot.txt $dir/year_slot.fst

@ -0,0 +1,61 @@
#!/bin/bash
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
lm=$1
lang=$2
tgt_lang=$3
unset GREP_OPTIONS
sym=$lang/words.txt
arpa_lm=$lm/lm.arpa
# Convert the ARPA language model into G_with_slot.fst: eps2disambig.pl rewrites
# backoff epsilon input labels as #0, and s2eps.pl maps <s>/</s> to <eps>.
cat $arpa_lm | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
grep -v -i '<unk>' | \
grep -v -i '<spoken_noise>' | \
arpa2fst --read-symbol-table=$sym --keep-symbols=true - | fstprint | \
utils/fst/eps2disambig.pl | utils/fst/s2eps.pl | fstcompile --isymbols=$sym \
--osymbols=$sym --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G_with_slot.fst
root_label=`grep ROOT $sym | awk '{print $2}'`
address_slot_label=`grep \<ADDRESS_SLOT\> $sym | awk '{print $2}'`
time_slot_label=`grep \<TIME_SLOT\> $sym | awk '{print $2}'`
date_slot_label=`grep \<DATE_SLOT\> $sym | awk '{print $2}'`
money_slot_label=`grep \<MONEY_SLOT\> $sym | awk '{print $2}'`
year_slot_label=`grep \<YEAR_SLOT\> $sym | awk '{print $2}'`
fstisstochastic $tgt_lang/G_with_slot.fst
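# fstreplace splices each slot FST into G_with_slot.fst: every arc labeled with
# a slot symbol (e.g. <ADDRESS_SLOT>) is expanded into the corresponding slot
# FST, with ROOT as the root non-terminal.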
fstreplace --epsilon_on_replace $tgt_lang/G_with_slot.fst \
$root_label $tgt_lang/address_slot.fst $address_slot_label \
$tgt_lang/date_slot.fst $date_slot_label \
$tgt_lang/money_slot.fst $money_slot_label \
$tgt_lang/time_slot.fst $time_slot_label \
$tgt_lang/year_slot.fst $year_slot_label $tgt_lang/G.fst
fstisstochastic $tgt_lang/G.fst
# Compose the token, lexicon and language-model FST into the final decoding graph
fsttablecompose $lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \
fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1;
fsttablecompose $lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1;
rm $tgt_lang/LG.fst
echo "Composing decoding graph TLG.fst succeeded"

@ -0,0 +1,55 @@
#!/bin/bash
# To be run from one directory above this script.
. ./path.sh
src=ds2_graph_with_slot
text=$src/train_text
lexicon=$src/local/dict/lexicon.txt
dir=$src/local/lm
mkdir -p $dir
for f in "$text" "$lexicon"; do
[ ! -f $f ] && echo "$0: No such file $f" && exit 1;
done
# Check SRILM tools
if ! which ngram-count > /dev/null; then
pushd $MAIN_ROOT/tools
make srilm.done
popd
fi
# This script takes no arguments. It assumes data preparation has already been run.
# It takes as input the files
# data/local/lm/text
# data/local/dict/lexicon.txt
cleantext=$dir/text.no_oov
cat $text | awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
> $cleantext || exit 1;
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \
sort -nr > $dir/word.counts || exit 1;
# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM-- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1;
# Keep only words that actually occur in the text (lexicon-only words have count 1).
cat $dir/unigram.counts | awk '$1>1{print $0}' | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
# Prepare the LM training text (strip utterance ids) and train a trigram LM with SRILM.
mkdir -p $dir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' > $dir/train
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
-map-unk "<UNK>" -gt3max 0 -gt2max 0 -gt1max 0 -lm $dir/lm.arpa
#ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
# -map-unk "<UNK>" -lm $dir/lm2.arpa

@ -0,0 +1,17 @@
# This file sets the locations of the binaries required for running the examples.
MAIN_ROOT=`realpath $PWD/../../../`
SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx`
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
export LC_ALL=C
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
# kaldi lm
KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src
export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin:$SPEECHX_EXAMPLES/ds2_ol/decoder

@ -0,0 +1,88 @@
#!/bin/bash
set +x
set -e
export GLOG_logtostderr=1
. ./path.sh || exit 1;
# ds2 means deepspeech2 (acoustic model type)
dir=$PWD/ds2_graph_with_slot
data=$PWD/data
stage=0
stop_stage=10
mkdir -p $dir
model_dir=$PWD/resource/model
vocab=$model_dir/vocab.txt
cmvn=$data/cmvn.ark
text_with_slot=$data/text_with_slot
resource=$PWD/resource
# download resource
if [ ! -f $cmvn ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/resource.tar.gz
tar xzfv resource.tar.gz
ln -s ./resource/data .
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# make dict
unit_file=$vocab
mkdir -p $dir/local/dict
cp $unit_file $dir/local/dict/units.txt
cp $text_with_slot $dir/train_text
utils/fst/prepare_dict.py --unit_file $unit_file --in_lexicon $data/lexicon.txt \
--out_lexicon $dir/local/dict/lexicon.txt
# Add the slot symbols to the lexicon (with a dummy pronunciation) in case the LM training script filters them out.
echo "<MONEY_SLOT> 一" >> $dir/local/dict/lexicon.txt
echo "<DATE_SLOT> 一" >> $dir/local/dict/lexicon.txt
echo "<ADDRESS_SLOT> 一" >> $dir/local/dict/lexicon.txt
echo "<YEAR_SLOT> 一" >> $dir/local/dict/lexicon.txt
echo "<TIME_SLOT> 一" >> $dir/local/dict/lexicon.txt
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# train lm
lm=$dir/local/lm
mkdir -p $lm
# this script differs from the common LM training script
local/train_lm_with_slot.sh
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# make T & L
local/compile_lexicon_token_fst.sh $dir/local/dict $dir/local/tmp $dir/local/lang
mkdir -p $dir/local/lang_test
# make slot graph
local/mk_slot_graph.sh $resource/graph $dir/local/lang_test
# make TLG
local/mk_tlg_with_slot.sh $dir/local/lm $dir/local/lang $dir/local/lang_test || exit 1;
mv $dir/local/lang_test/TLG.fst $dir/local/lang/
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test TLG
model_dir=$PWD/resource/model
cmvn=$data/cmvn.ark
wav_scp=$data/wav.scp
graph=$dir/local/lang
recognizer_test_main \
--wav_rspecifier=scp:$wav_scp \
--cmvn_file=$cmvn \
--streaming_chunk=30 \
--use_fbank=true \
--model_path=$model_dir/avg_10.jit.pdmodel \
--param_path=$model_dir/avg_10.jit.pdiparams \
--model_cache_shapes="5-1-2048,5-1-2048" \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--word_symbol_table=$graph/words.txt \
--graph_path=$graph/TLG.fst --max_active=7500 \
--acoustic_scale=12 \
--result_wspecifier=ark,t:./result_run.txt
# data/wav.trans contains the reference transcripts.
utils/compute-wer.py --char=1 --v=1 data/wav.trans result_run.txt > wer_run
tail -n 7 wer_run
fi

@ -7,3 +7,7 @@ add_subdirectory(matrix)
add_subdirectory(lat)
add_subdirectory(fstext)
add_subdirectory(decoder)
+add_subdirectory(lm)
+add_subdirectory(fstbin)
+add_subdirectory(lmbin)

@ -0,0 +1,15 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(BINS
fstaddselfloops
fstisstochastic
fstminimizeencoded
fstdeterminizestar
fsttablecompose
)
foreach(binary IN LISTS BINS)
add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc)
target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl)
endforeach()

@ -1,5 +1,5 @@
add_library(kaldi-fstext
kaldi-fst-io.cc
)
target_link_libraries(kaldi-fstext PUBLIC kaldi-util)

@ -0,0 +1,6 @@
add_library(kaldi-lm
arpa-file-parser.cc
arpa-lm-compiler.cc
)
target_link_libraries(kaldi-lm PUBLIC kaldi-util)

@ -0,0 +1,281 @@
// lm/arpa-file-parser.cc
// Copyright 2014 Guoguo Chen
// Copyright 2016 Smart Action Company LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <fst/fstlib.h>
#include <sstream>
#include "base/kaldi-error.h"
#include "base/kaldi-math.h"
#include "lm/arpa-file-parser.h"
#include "util/text-utils.h"
namespace kaldi {
ArpaFileParser::ArpaFileParser(const ArpaParseOptions& options,
fst::SymbolTable* symbols)
: options_(options), symbols_(symbols),
line_number_(0), warning_count_(0) {
}
ArpaFileParser::~ArpaFileParser() {
}
void TrimTrailingWhitespace(std::string *str) {
str->erase(str->find_last_not_of(" \n\r\t") + 1);
}
void ArpaFileParser::Read(std::istream &is) {
// Argument sanity checks.
if (options_.bos_symbol <= 0 || options_.eos_symbol <= 0 ||
options_.bos_symbol == options_.eos_symbol)
KALDI_ERR << "BOS and EOS symbols are required, must not be epsilons, and "
<< "differ from each other. Given:"
<< " BOS=" << options_.bos_symbol
<< " EOS=" << options_.eos_symbol;
if (symbols_ != NULL &&
options_.oov_handling == ArpaParseOptions::kReplaceWithUnk &&
(options_.unk_symbol <= 0 ||
options_.unk_symbol == options_.bos_symbol ||
options_.unk_symbol == options_.eos_symbol))
KALDI_ERR << "When symbol table is given and OOV mode is kReplaceWithUnk, "
<< "UNK symbol is required, must not be epsilon, and "
<< "differ from both BOS and EOS symbols. Given:"
<< " UNK=" << options_.unk_symbol
<< " BOS=" << options_.bos_symbol
<< " EOS=" << options_.eos_symbol;
if (symbols_ != NULL && symbols_->Find(options_.bos_symbol).empty())
KALDI_ERR << "BOS symbol must exist in symbol table";
if (symbols_ != NULL && symbols_->Find(options_.eos_symbol).empty())
KALDI_ERR << "EOS symbol must exist in symbol table";
if (symbols_ != NULL && options_.unk_symbol > 0 &&
symbols_->Find(options_.unk_symbol).empty())
KALDI_ERR << "UNK symbol must exist in symbol table";
ngram_counts_.clear();
line_number_ = 0;
warning_count_ = 0;
current_line_.clear();
#define PARSE_ERR KALDI_ERR << LineReference() << ": "
// Give derived class an opportunity to prepare its state.
ReadStarted();
// Processes "\data\" section.
bool keyword_found = false;
while (++line_number_, getline(is, current_line_) && !is.eof()) {
if (current_line_.find_first_not_of(" \t\n\r") == std::string::npos) {
continue;
}
TrimTrailingWhitespace(&current_line_);
// Continue skipping lines until the \data\ marker alone on a line is found.
if (!keyword_found) {
if (current_line_ == "\\data\\") {
KALDI_LOG << "Reading \\data\\ section.";
keyword_found = true;
}
continue;
}
if (current_line_[0] == '\\') break;
// Enters "\data\" section, and looks for patterns like "ngram 1=1000",
// which means there are 1000 unigrams.
std::size_t equal_symbol_pos = current_line_.find("=");
if (equal_symbol_pos != std::string::npos)
// Guarantee spaces around the "=".
current_line_.replace(equal_symbol_pos, 1, " = ");
std::vector<std::string> col;
SplitStringToVector(current_line_, " \t", true, &col);
if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") {
int32 order, ngram_count = 0;
if (!ConvertStringToInteger(col[1], &order) ||
!ConvertStringToInteger(col[3], &ngram_count)) {
PARSE_ERR << "cannot parse ngram count";
}
if (ngram_counts_.size() <= order) {
ngram_counts_.resize(order);
}
ngram_counts_[order - 1] = ngram_count;
} else {
KALDI_WARN << LineReference()
<< ": uninterpretable line in \\data\\ section";
}
}
if (ngram_counts_.size() == 0)
PARSE_ERR << "\\data\\ section missing or empty.";
// Signal that grammar order and n-gram counts are known.
HeaderAvailable();
NGram ngram;
ngram.words.reserve(ngram_counts_.size());
// Processes "\N-grams:" section.
for (int32 cur_order = 1; cur_order <= ngram_counts_.size(); ++cur_order) {
// Skips n-grams with zero count.
if (ngram_counts_[cur_order - 1] == 0)
KALDI_WARN << "Zero ngram count in ngram order " << cur_order
<< "(look for 'ngram " << cur_order << "=0' in the \\data\\ "
<< " section). There is possibly a problem with the file.";
// Must be looking at a \k-grams: directive at this point.
std::ostringstream keyword;
keyword << "\\" << cur_order << "-grams:";
if (current_line_ != keyword.str()) {
PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'";
}
KALDI_LOG << "Reading " << current_line_ << " section.";
int32 ngram_count = 0;
while (++line_number_, getline(is, current_line_) && !is.eof()) {
if (current_line_.find_first_not_of(" \n\t\r") == std::string::npos) {
continue;
}
if (current_line_[0] == '\\') {
TrimTrailingWhitespace(&current_line_);
std::ostringstream next_keyword;
next_keyword << "\\" << cur_order + 1 << "-grams:";
if ((current_line_ != next_keyword.str()) &&
(current_line_ != "\\end\\")) {
if (ShouldWarn()) {
KALDI_WARN << "ignoring possible directive '" << current_line_
<< "' expecting '" << next_keyword.str() << "'";
if (warning_count_ > 0 &&
warning_count_ > static_cast<uint32>(options_.max_warnings)) {
KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
<< options_.max_warnings << " were reported. "
<< "Run program with --max-arpa-warnings=-1 "
<< "to see all warnings";
}
}
} else {
break;
}
}
std::vector<std::string> col;
SplitStringToVector(current_line_, " \t", true, &col);
if (col.size() < 1 + cur_order ||
col.size() > 2 + cur_order ||
(cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) {
PARSE_ERR << "Invalid n-gram data line";
}
++ngram_count;
// Parse out n-gram logprob and, if present, backoff weight.
if (!ConvertStringToReal(col[0], &ngram.logprob)) {
PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'";
}
ngram.backoff = 0.0;
if (col.size() > cur_order + 1) {
if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff))
PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'";
}
// Convert to natural log.
ngram.logprob *= M_LN10;
ngram.backoff *= M_LN10;
ngram.words.resize(cur_order);
bool skip_ngram = false;
for (int32 index = 0; !skip_ngram && index < cur_order; ++index) {
int32 word;
if (symbols_) {
// Symbol table provided, so symbol labels are expected.
if (options_.oov_handling == ArpaParseOptions::kAddToSymbols) {
word = symbols_->AddSymbol(col[1 + index]);
} else {
word = symbols_->Find(col[1 + index]);
if (word == -1) { // fst::kNoSymbol
switch (options_.oov_handling) {
case ArpaParseOptions::kReplaceWithUnk:
word = options_.unk_symbol;
break;
case ArpaParseOptions::kSkipNGram:
if (ShouldWarn())
KALDI_WARN << LineReference() << " skipped: word '"
<< col[1 + index] << "' not in symbol table";
skip_ngram = true;
break;
default:
PARSE_ERR << "word '" << col[1 + index]
<< "' not in symbol table";
}
}
}
} else {
// Symbols not provided, LM file should contain integers.
if (!ConvertStringToInteger(col[1 + index], &word) || word < 0) {
PARSE_ERR << "invalid symbol '" << col[1 + index] << "'";
}
}
// Whichever way we got it, an epsilon is invalid.
if (word == 0) {
PARSE_ERR << "epsilon symbol '" << col[1 + index]
<< "' is illegal in ARPA LM";
}
ngram.words[index] = word;
}
if (!skip_ngram) {
ConsumeNGram(ngram);
}
}
if (ngram_count > ngram_counts_[cur_order - 1]) {
PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1]
<< " n-grams of order " << cur_order
<< ", but we saw more already.";
}
}
if (current_line_ != "\\end\\") {
PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\";
}
if (warning_count_ > 0 &&
warning_count_ > static_cast<uint32>(options_.max_warnings)) {
KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
<< options_.max_warnings << " were reported. Run program with "
<< "--max-arpa-warnings=-1 to see all warnings";
}
current_line_.clear();
ReadComplete();
#undef PARSE_ERR
}
std::string ArpaFileParser::LineReference() const {
std::ostringstream ss;
ss << "line " << line_number_ << " [" << current_line_ << "]";
return ss.str();
}
bool ArpaFileParser::ShouldWarn() {
return (warning_count_ != -1) &&
(++warning_count_ <= static_cast<uint32>(options_.max_warnings));
}
} // namespace kaldi

@ -0,0 +1,146 @@
// lm/arpa-file-parser.h
// Copyright 2014 Guoguo Chen
// Copyright 2016 Smart Action Company LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LM_ARPA_FILE_PARSER_H_
#define KALDI_LM_ARPA_FILE_PARSER_H_
#include <fst/fst-decl.h>
#include <string>
#include <vector>
#include "base/kaldi-types.h"
#include "util/options-itf.h"
namespace kaldi {
/**
Options that control ArpaFileParser
*/
struct ArpaParseOptions {
enum OovHandling {
kRaiseError, ///< Abort on OOV words
kAddToSymbols, ///< Add novel words to the symbol table.
kReplaceWithUnk, ///< Replace OOV words with <unk>.
kSkipNGram ///< Skip n-gram with OOV word and continue.
};
ArpaParseOptions():
bos_symbol(-1), eos_symbol(-1), unk_symbol(-1),
oov_handling(kRaiseError), max_warnings(30) { }
void Register(OptionsItf *opts) {
// Registering only the max_warnings count, since other options are
// treated differently by client programs: some want integer symbols,
// while other are passed words in their command line.
opts->Register("max-arpa-warnings", &max_warnings,
"Maximum warnings to report on ARPA parsing, "
"0 to disable, -1 to show all");
}
int32 bos_symbol; ///< Symbol for <s>, Required non-epsilon.
int32 eos_symbol; ///< Symbol for </s>, Required non-epsilon.
int32 unk_symbol; ///< Symbol for <unk>, Required for kReplaceWithUnk.
OovHandling oov_handling; ///< How to handle OOV words in the file.
int32 max_warnings; ///< Maximum warnings to report, <0 unlimited.
};
/**
A parsed n-gram from ARPA LM file.
*/
struct NGram {
NGram() : logprob(0.0), backoff(0.0) { }
std::vector<int32> words; ///< Symbols in left to right order.
float logprob; ///< Log-prob of the n-gram.
float backoff; ///< log-backoff weight of the n-gram.
///< Defaults to zero if not specified.
};
/**
ArpaFileParser is an abstract base class for ARPA LM file conversion.
See ConstArpaLmBuilder and ArpaLmCompiler for usage examples.
*/
class ArpaFileParser {
public:
/// Constructs the parser with the given options and optional symbol table.
/// If symbol table is provided, then the file should contain text n-grams,
/// and the words are mapped to symbols through it. bos_symbol and
/// eos_symbol in the options structure must be valid symbols in the table,
/// and so must be unk_symbol if provided. The table is not owned by the
/// parser, but may be augmented, if oov_handling is set to kAddToSymbols.
/// If symbol table is a null pointer, the file should contain integer
/// symbol values, and oov_handling has no effect. bos_symbol and eos_symbol
/// must be valid symbols still.
ArpaFileParser(const ArpaParseOptions& options, fst::SymbolTable* symbols);
virtual ~ArpaFileParser();
/// Read ARPA LM file from a stream.
void Read(std::istream &is);
/// Parser options.
const ArpaParseOptions& Options() const { return options_; }
protected:
/// Override called before reading starts. This is the point to prepare
/// any state in the derived class.
virtual void ReadStarted() { }
/// Override function called to signal that ARPA header with the expected
/// number of n-grams has been read, and ngram_counts() is now valid.
virtual void HeaderAvailable() { }
/// Pure override that must be implemented to process current n-gram. The
/// n-grams are sent in the file order, which guarantees that all
/// (k-1)-grams are processed before the first k-gram is.
virtual void ConsumeNGram(const NGram&) = 0;
/// Override function called after the last n-gram has been consumed.
virtual void ReadComplete() { }
/// Read-only access to symbol table. Not owned, do not make public.
const fst::SymbolTable* Symbols() const { return symbols_; }
/// Inside ConsumeNGram(), provides the current line number.
int32 LineNumber() const { return line_number_; }
/// Inside ConsumeNGram(), returns a formatted reference to the line being
/// compiled, to print out as part of diagnostics.
std::string LineReference() const;
/// Increments warning count, and returns true if a warning should be
/// printed or false if the count has exceeded the set maximum.
bool ShouldWarn();
/// N-gram counts. Valid from the point when HeaderAvailable() is called.
const std::vector<int32>& NgramCounts() const { return ngram_counts_; }
private:
ArpaParseOptions options_;
fst::SymbolTable* symbols_; // the pointer is not owned here.
int32 line_number_;
uint32 warning_count_;
std::string current_line_;
std::vector<int32> ngram_counts_;
};
} // namespace kaldi
#endif // KALDI_LM_ARPA_FILE_PARSER_H_

@ -0,0 +1,377 @@
// lm/arpa-lm-compiler.cc
// Copyright 2009-2011 Gilles Boulianne
// Copyright 2016 Smart Action LLC (kkm)
// Copyright 2017 Xiaohui Zhang
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <limits>
#include <sstream>
#include <utility>
#include "base/kaldi-math.h"
#include "lm/arpa-lm-compiler.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"
#include "fstext/remove-eps-local.h"
namespace kaldi {
class ArpaLmCompilerImplInterface {
public:
virtual ~ArpaLmCompilerImplInterface() { }
virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;
};
namespace {
typedef int32 StateId;
typedef int32 Symbol;
// GeneralHistKey can represent state history in an arbitrarily large
// n-gram model with symbol ids fitting int32.
class GeneralHistKey {
public:
// Construct key from begin and end iterators.
template<class InputIt>
GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) { }
// Construct empty history key.
GeneralHistKey() : vector_() { }
// Return tails of the key as a GeneralHistKey. The tails of an n-gram
// w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
// key class does not need this operation).
GeneralHistKey Tails() const {
return GeneralHistKey(vector_.begin() + 1, vector_.end());
}
// Keys are equal if they represent the same state.
friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
return a.vector_ == b.vector_;
}
// Public typename HashType for hashing.
struct HashType : public std::unary_function<GeneralHistKey, size_t> {
size_t operator()(const GeneralHistKey& key) const {
return VectorHasher<Symbol>().operator()(key.vector_);
}
};
private:
std::vector<Symbol> vector_;
};
// OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
// machine word, allowing significant memory reduction and some runtime
// benefit over GeneralHistKey. Since 3 symbols are enough to track history
// in a 4-gram model, this optimized key is used for smaller models with up
// to 4-gram and symbol values up to 2^21-1.
//
// See GeneralHistKey for interface requirements of a key class.
class OptimizedHistKey {
public:
enum {
kShift = 21, // 21 * 3 = 63 bits for data.
kMaxData = (1 << kShift) - 1
};
template<class InputIt>
OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
data_ |= static_cast<uint64>(*begin) << shift;
}
}
OptimizedHistKey() : data_(0) { }
OptimizedHistKey Tails() const {
return OptimizedHistKey(data_ >> kShift);
}
friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
return a.data_ == b.data_;
}
struct HashType : public std::unary_function<OptimizedHistKey, size_t> {
size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
};
private:
explicit OptimizedHistKey(uint64 data) : data_(data) { }
uint64 data_;
};
} // namespace
template <class HistKey>
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
public:
ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
Symbol sub_eps);
virtual void ConsumeNGram(const NGram &ngram, bool is_highest);
private:
StateId AddStateWithBackoff(HistKey key, float backoff);
void CreateBackoff(HistKey key, StateId state, float weight);
ArpaLmCompiler *parent_; // Not owned.
fst::StdVectorFst* fst_; // Not owned.
Symbol bos_symbol_;
Symbol eos_symbol_;
Symbol sub_eps_;
StateId eos_state_;
typedef unordered_map<HistKey, StateId,
typename HistKey::HashType> HistoryMap;
HistoryMap history_;
};
template <class HistKey>
ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
: parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) {
// The algorithm maintains state per history. The 0-gram is a special state
// for empty history. All unigrams (including BOS) backoff into this state.
StateId zerogram = fst_->AddState();
history_[HistKey()] = zerogram;
// Also, if </s> is not treated as epsilon, create a common end state for
// all transitions accepting the </s>, since they do not back off. This small
// optimization saves about 2% states in an average grammar.
if (sub_eps_ == 0) {
eos_state_ = fst_->AddState();
fst_->SetFinal(eos_state_, 0);
}
}
template <class HistKey>
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
bool is_highest) {
// Generally, we do the following. Suppose we are adding an n-gram "A B
// C". Then find the node for "A B", add a new node for "A B C", and connect
// them with the arc accepting "C" with the specified weight. Also, add a
// backoff arc from the new "A B C" node to its backoff state "B C".
//
// Two notable exceptions are the highest order n-grams, and final n-grams.
//
// When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
// the following optimization is performed. There is no point adding a node
// for "A B C" with a "C" arc from "A B", since there will be no other
// arcs ingoing to this node, and an epsilon backoff arc into the backoff
// model "B C", with the weight of \bar{1}. To save a node, create an arc
// accepting "C" directly from "A B" to "B C". This saves as many nodes
// as there are the highest order n-grams, which is typically about half
// the size of a large 3-gram model.
//
// Indeed, this does not apply to n-grams ending in EOS, since they do not
// back off. These are special, as they do not have a back-off state, and
// the node for "(..anything..) </s>" is always final. These are handled
// in one of two possible ways. If symbols <s> and </s> are being
// replaced by epsilons, neither node nor arc is created, and the logprob
// of the n-gram is applied to its source node as final weight. If <s> and
// </s> are preserved, then a special final node for </s> is allocated and
// used as the destination of the "</s>" acceptor arc.
HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
typename HistoryMap::iterator source_it = history_.find(heads);
if (source_it == history_.end()) {
// There was no "A B", therefore the probability of "A B C" is zero.
// Print a warning and discard current n-gram.
if (parent_->ShouldWarn())
KALDI_WARN << parent_->LineReference()
<< " skipped: no parent (n-1)-gram exists";
return;
}
StateId source = source_it->second;
StateId dest;
Symbol sym = ngram.words.back();
float weight = -ngram.logprob;
if (sym == sub_eps_ || sym == 0) {
KALDI_ERR << " <eps> or disambiguation symbol " << sym << "found in the ARPA file. ";
}
if (sym == eos_symbol_) {
if (sub_eps_ == 0) {
// Keep </s> as a real symbol when not substituting.
dest = eos_state_;
} else {
// Treat </s> as if it was epsilon: mark source final, with the weight
// of the n-gram.
fst_->SetFinal(source, weight);
return;
}
} else {
// For the highest order n-gram, this may find an existing state, for
// non-highest, will create one (unless there are duplicate n-grams
// in the grammar, which cannot be reliably detected if highest order,
// so we better do not do that at all).
dest = AddStateWithBackoff(
HistKey(ngram.words.begin() + (is_highest ? 1 : 0),
ngram.words.end()),
-ngram.backoff);
}
if (sym == bos_symbol_) {
weight = 0; // Accepting <s> is always free.
if (sub_eps_ == 0) {
// <s> is a real symbol, only accepted in the start state.
source = fst_->AddState();
fst_->SetStart(source);
} else {
// The new state for <s> unigram history *is* the start state.
fst_->SetStart(dest);
return;
}
}
// Add arc from source to dest, whichever way it was found.
fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
return;
}
// Find or create a new state for n-gram defined by key, and ensure it has a
// backoff transition. The key is either the current n-gram for all but
// highest orders, or the tails of the n-gram for the highest order. The
// latter arises from the chain-collapsing optimization described above.
template <class HistKey>
StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
float backoff) {
typename HistoryMap::iterator dest_it = history_.find(key);
if (dest_it != history_.end()) {
// Found an existing state in the history map. Invariant: if the state in
// the map, then its backoff arc is in the FST. We are done.
return dest_it->second;
}
// Otherwise create a new state and its backoff arc, and register in the map.
StateId dest = fst_->AddState();
history_[key] = dest;
CreateBackoff(key.Tails(), dest, backoff);
return dest;
}
// Create a backoff arc for a state. Key is a backoff destination that may or
// may not exist. When the destination is not found, naturally fall back to
// the lower order model, and all the way down until one is found (since the
// 0-gram model is always present, the search is guaranteed to terminate).
template <class HistKey>
inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(
HistKey key, StateId state, float weight) {
typename HistoryMap::iterator dest_it = history_.find(key);
while (dest_it == history_.end()) {
key = key.Tails();
dest_it = history_.find(key);
}
// The arc should transduce either <eos> or #0 to <eps>, depending on the
// epsilon substitution mode. This is the only case when input and output
// label may differ.
fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
}
ArpaLmCompiler::~ArpaLmCompiler() {
if (impl_ != NULL)
delete impl_;
}
void ArpaLmCompiler::HeaderAvailable() {
KALDI_ASSERT(impl_ == NULL);
// Use optimized implementation if the grammar is 4-gram or less, and the
// maximum attained symbol id will fit into the optimized range.
int64 max_symbol = 0;
if (Symbols() != NULL)
max_symbol = Symbols()->AvailableKey() - 1;
// If augmenting the symbol table, assume the worst case when all words in
// the model being read are novel.
if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
max_symbol += NgramCounts()[0];
if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
} else {
impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
KALDI_LOG << "Reverting to slower state tracking because model is large: "
<< NgramCounts().size() << "-gram with symbols up to "
<< max_symbol;
}
}
void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
// <s> is invalid in tails, </s> in heads of an n-gram.
for (int i = 0; i < ngram.words.size(); ++i) {
if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
(i + 1 < ngram.words.size()
&& ngram.words[i] == Options().eos_symbol)) {
if (ShouldWarn())
KALDI_WARN << LineReference()
<< " skipped: n-gram has invalid BOS/EOS placement";
return;
}
}
bool is_highest = ngram.words.size() == NgramCounts().size();
impl_->ConsumeNGram(ngram, is_highest);
}
void ArpaLmCompiler::RemoveRedundantStates() {
fst::StdArc::Label backoff_symbol = sub_eps_;
if (backoff_symbol == 0) {
// The method of removing redundant states implemented in this function
// leads to slow determinization of L o G when people use the older style of
// usage of arpa2fst where the --disambig-symbol option was not specified.
// The issue seems to be that it creates a non-deterministic FST, while G is
// supposed to be deterministic. By 'return'ing below, we just disable this
// method if people were using an older script. This method isn't really
// that consequential anyway, and people will move to the newer-style
// scripts (see current utils/format_lm.sh), so this isn't much of a
// problem.
return;
}
fst::StdArc::StateId num_states = fst_.NumStates();
// replace the #0 symbols on the input of arcs out of redundant states (states
// that are not final and have only a backoff arc leaving them), with <eps>.
for (fst::StdArc::StateId state = 0; state < num_states; state++) {
if (fst_.NumArcs(state) == 1 && fst_.Final(state) == fst::TropicalWeight::Zero()) {
fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
fst::StdArc arc = iter.Value();
if (arc.ilabel == backoff_symbol) {
arc.ilabel = 0;
iter.SetValue(arc);
}
}
}
// we could call fst::RemoveEps, and it would have the same effect in normal
// cases, where backoff_symbol != 0 and there are no epsilons in unexpected
// places, but RemoveEpsLocal is a bit safer in case something weird is going
// on; it guarantees not to blow up the FST.
fst::RemoveEpsLocal(&fst_);
KALDI_LOG << "Reduced num-states from " << num_states << " to "
<< fst_.NumStates();
}
void ArpaLmCompiler::Check() const {
if (fst_.Start() == fst::kNoStateId) {
KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol "
<< Symbols()->Find(Options().bos_symbol) << ".";
}
}
void ArpaLmCompiler::ReadComplete() {
fst_.SetInputSymbols(Symbols());
fst_.SetOutputSymbols(Symbols());
RemoveRedundantStates();
Check();
}
} // namespace kaldi

@ -0,0 +1,65 @@
// lm/arpa-lm-compiler.h
// Copyright 2009-2011 Gilles Boulianne
// Copyright 2016 Smart Action LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LM_ARPA_LM_COMPILER_H_
#define KALDI_LM_ARPA_LM_COMPILER_H_
#include <fst/fstlib.h>
#include "lm/arpa-file-parser.h"
namespace kaldi {
class ArpaLmCompilerImplInterface;
class ArpaLmCompiler : public ArpaFileParser {
public:
ArpaLmCompiler(const ArpaParseOptions& options, int sub_eps,
fst::SymbolTable* symbols)
: ArpaFileParser(options, symbols),
sub_eps_(sub_eps), impl_(NULL) {
}
~ArpaLmCompiler();
const fst::StdVectorFst& Fst() const { return fst_; }
fst::StdVectorFst* MutableFst() { return &fst_; }
protected:
// ArpaFileParser overrides.
virtual void HeaderAvailable();
virtual void ConsumeNGram(const NGram& ngram);
virtual void ReadComplete();
private:
// this function removes states that only have a backoff arc coming
// out of them.
void RemoveRedundantStates();
void Check() const;
int sub_eps_;
ArpaLmCompilerImplInterface* impl_; // Owned.
fst::StdVectorFst fst_;
template <class HistKey> friend class ArpaLmCompilerImpl;
};
} // namespace kaldi
#endif // KALDI_LM_ARPA_LM_COMPILER_H_
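For orientation, here is a minimal sketch of how this compiler is typically driven, modeled on Kaldi's arpa2fst.cc (built by the lmbin target below); file names are illustrative and error handling is omitted:
```
#include <fstream>
#include "lm/arpa-lm-compiler.h"

int main() {
  // Shared symbol table (the same words.txt used by L.fst and the slot FSTs).
  fst::SymbolTable* symbols = fst::SymbolTable::ReadText("words.txt");
  kaldi::ArpaParseOptions options;
  options.bos_symbol = symbols->Find("<s>");
  options.eos_symbol = symbols->Find("</s>");
  options.oov_handling = kaldi::ArpaParseOptions::kSkipNGram;
  // sub_eps = 0 keeps <s>/</s> as real symbols and leaves backoff arcs as
  // plain epsilons; the shell pipeline in mk_tlg_with_slot.sh then applies
  // eps2disambig.pl and s2eps.pl to get the standard G.fst form.
  kaldi::ArpaLmCompiler compiler(options, /*sub_eps=*/0, symbols);
  std::ifstream is("lm.arpa");
  compiler.Read(is);
  compiler.Fst().Write("G.fst");
  delete symbols;
  return 0;
}
```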

@ -1,5 +1,4 @@
-cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(arpa2fst ${CMAKE_CURRENT_SOURCE_DIR}/arpa2fst.cc)
target_include_directories(arpa2fst PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(arpa2fst )
+target_link_libraries(arpa2fst PUBLIC kaldi-lm glog gflags fst)