diff --git a/speechx/examples/aishell/run.sh b/speechx/examples/aishell/run.sh
index a21ba086..8a16a865 100755
--- a/speechx/examples/aishell/run.sh
+++ b/speechx/examples/aishell/run.sh
@@ -48,7 +48,7 @@ wer=./aishell_wer
 nj=40
 export GLOG_logtostderr=1
-./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
+#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
 data=$PWD/data
 # 3. gen linear feat
@@ -72,10 +72,42 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
     --param_path=$aishell_online_model/avg_1.jit.pdiparams \
     --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
     --dict_file=$lm_model_dir/vocab.txt \
-    --lm_path=$lm_model_dir/avg_1.jit.klm \
     --result_wspecifier=ark,t:$data/split${nj}/JOB/result
-cat $data/split${nj}/*/result > $label_file
+cat $data/split${nj}/*/result > ${label_file}
+local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
+
+# 4. decode with lm
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
+    offline_decoder_sliding_chunk_main \
+    --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
+    --model_path=$aishell_online_model/avg_1.jit.pdmodel \
+    --param_path=$aishell_online_model/avg_1.jit.pdiparams \
+    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
+    --dict_file=$lm_model_dir/vocab.txt \
+    --lm_path=$lm_model_dir/avg_1.jit.klm \
+    --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
+
+cat $data/split${nj}/*/result_lm > ${label_file}_lm
+local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
+
+graph_dir=./aishell_graph
+if [ ! -d $graph_dir ]; then
+    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
+    unzip aishell_graph.zip
+fi
+
+# 5. 
test TLG decoder +utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \ + offline_wfst_decoder_main \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ + --model_path=$aishell_online_model/avg_1.jit.pdmodel \ + --param_path=$aishell_online_model/avg_1.jit.pdiparams \ + --word_symbol_table=$graph_dir/words.txt \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --graph_path=$graph_dir/TLG.fst --max_active=7500 \ + --acoustic_scale=1.2 \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg -local/compute-wer.py --char=1 --v=1 $label_file $text > $wer -tail $wer +cat $data/split${nj}/*/result_tlg > ${label_file}_tlg +local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg \ No newline at end of file diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt index ded423e9..d446a671 100644 --- a/speechx/examples/decoder/CMakeLists.txt +++ b/speechx/examples/decoder/CMakeLists.txt @@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_ target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) +add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc) +target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS}) + add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc) target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) diff --git 
a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc index be56342f..40092de3 100644 --- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc +++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc @@ -27,7 +27,7 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); +DEFINE_string(lm_path, "", "language model"); DEFINE_int32(receptive_field_length, 7, "receptive field of two CNN(kernel=5) downsampling module."); @@ -45,7 +45,6 @@ using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; - // test ds2 online decoder by feeding speech feature int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -63,7 +62,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "lm path: " << lm_path; - int32 num_done = 0, num_err = 0; ppspeech::CTCBeamSearchOptions opts; @@ -138,10 +136,16 @@ int main(int argc, char* argv[]) { } std::string result; result = decoder.GetFinalBestPath(); - KALDI_LOG << " the result of " << utt << " is " << result; - result_writer.Write(utt, result); decodable->Reset(); decoder.Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. 
+ ++num_err; + KALDI_LOG << " the result of " << utt << " is empty"; + continue; + } + KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); ++num_done; } diff --git a/speechx/examples/decoder/offline_wfst_decoder_main.cc b/speechx/examples/decoder/offline_wfst_decoder_main.cc new file mode 100644 index 00000000..06460a45 --- /dev/null +++ b/speechx/examples/decoder/offline_wfst_decoder_main.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+// todo refactor, replace with gtest
+
+#include "base/flags.h"
+#include "base/log.h"
+#include "decoder/ctc_tlg_decoder.h"
+#include "frontend/audio/data_cache.h"
+#include "kaldi/util/table-types.h"
+#include "nnet/decodable.h"
+#include "nnet/paddle_nnet.h"
+
+DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
+DEFINE_string(result_wspecifier, "", "test result wspecifier");
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
+DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
+DEFINE_string(graph_path, "TLG", "decoder graph");
+DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
+DEFINE_int32(max_active, 7500, "decoder max active");
+DEFINE_int32(receptive_field_length,
+             7,
+             "receptive field of two CNN(kernel=5) downsampling module.");
+DEFINE_int32(downsampling_rate,
+             4,
+             "two CNN(kernel=5) module downsampling rate.");
+DEFINE_string(model_output_names,
+              "save_infer_model/scale_0.tmp_1,save_infer_model/"
+              "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
+              "scale_3.tmp_1",
+              "model output names");
+DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
+
+using kaldi::BaseFloat;
+using kaldi::Matrix;
+using std::vector;
+
+// test TLG decoder by feeding speech feature. 
+int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + std::string model_graph = FLAGS_model_path; + std::string model_params = FLAGS_param_path; + std::string word_symbol_table = FLAGS_word_symbol_table; + std::string graph_path = FLAGS_graph_path; + LOG(INFO) << "model path: " << model_graph; + LOG(INFO) << "model param: " << model_params; + LOG(INFO) << "word symbol path: " << word_symbol_table; + LOG(INFO) << "graph path: " << graph_path; + + int32 num_done = 0, num_err = 0; + + ppspeech::TLGDecoderOptions opts; + opts.word_symbol_table = word_symbol_table; + opts.fst_path = graph_path; + opts.opts.max_active = FLAGS_max_active; + opts.opts.beam = 15.0; + opts.opts.lattice_beam = 7.5; + ppspeech::TLGDecoder decoder(opts); + + ppspeech::ModelOptions model_opts; + model_opts.model_path = model_graph; + model_opts.params_path = model_params; + model_opts.cache_shape = FLAGS_model_cache_names; + model_opts.output_names = FLAGS_model_output_names; + std::shared_ptr nnet( + new ppspeech::PaddleNnet(model_opts)); + std::shared_ptr raw_data(new ppspeech::DataCache()); + std::shared_ptr decodable( + new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); + + int32 chunk_size = FLAGS_receptive_field_length; + int32 chunk_stride = FLAGS_downsampling_rate; + int32 receptive_field_length = FLAGS_receptive_field_length; + LOG(INFO) << "chunk size (frame): " << chunk_size; + LOG(INFO) << "chunk stride (frame): " << chunk_stride; + LOG(INFO) << "receptive field (frame): " << receptive_field_length; + decoder.InitDecoder(); + + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + kaldi::Matrix feature = feature_reader.Value(); + raw_data->SetDim(feature.NumCols()); + LOG(INFO) << "process utt: " << utt; + 
LOG(INFO) << "rows: " << feature.NumRows(); + LOG(INFO) << "cols: " << feature.NumCols(); + + int32 row_idx = 0; + int32 padding_len = 0; + int32 ori_feature_len = feature.NumRows(); + if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { + padding_len = + chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; + feature.Resize(feature.NumRows() + padding_len, + feature.NumCols(), + kaldi::kCopyData); + } + int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + kaldi::Vector feature_chunk(chunk_size * + feature.NumCols()); + int32 feature_chunk_size = 0; + if (ori_feature_len > chunk_idx * chunk_stride) { + feature_chunk_size = std::min( + ori_feature_len - chunk_idx * chunk_stride, chunk_size); + } + if (feature_chunk_size < receptive_field_length) break; + + int32 start = chunk_idx * chunk_stride; + for (int row_id = 0; row_id < chunk_size; ++row_id) { + kaldi::SubVector tmp(feature, start); + kaldi::SubVector f_chunk_tmp( + feature_chunk.Data() + row_id * feature.NumCols(), + feature.NumCols()); + f_chunk_tmp.CopyFromVec(tmp); + ++start; + } + raw_data->Accept(feature_chunk); + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + decoder.AdvanceDecode(decodable); + } + std::string result; + result = decoder.GetFinalBestPath(); + decodable->Reset(); + decoder.Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + KALDI_LOG << " the result of " << utt << " is empty"; + continue; + } + KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 
0 : 1); +} diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 7cd281b6..ee0863fd 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -6,5 +6,6 @@ add_library(decoder STATIC ctc_decoders/decoder_utils.cpp ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp + ctc_tlg_decoder.cc ) -target_link_libraries(decoder PUBLIC kenlm utils fst) \ No newline at end of file +target_link_libraries(decoder PUBLIC kenlm utils fst) diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 5d7a4f77..b4caa8e7 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode( vector> likelihood; vector frame_prob; bool flag = - decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); + decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); if (flag == false) break; likelihood.push_back(frame_prob); AdvanceDecoding(likelihood); diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 1387eee7..9d0a5d14 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -15,7 +15,7 @@ #include "base/common.h" #include "decoder/ctc_decoders/path_trie.h" #include "decoder/ctc_decoders/scorer.h" -#include "nnet/decodable-itf.h" +#include "kaldi/decoder/decodable-itf.h" #include "util/parse-options.h" #pragma once diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc new file mode 100644 index 00000000..5365e709 --- /dev/null +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/ctc_tlg_decoder.h" +namespace ppspeech { + +TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { + fst_.reset(fst::Fst::Read(opts.fst_path)); + CHECK(fst_ != nullptr); + word_symbol_table_.reset( + fst::SymbolTable::ReadText(opts.word_symbol_table)); + decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); + decoder_->InitDecoding(); + frame_decoded_size_ = 0; +} + +void TLGDecoder::InitDecoder() { + decoder_->InitDecoding(); + frame_decoded_size_ = 0; +} + +void TLGDecoder::AdvanceDecode( + const std::shared_ptr& decodable) { + while (!decodable->IsLastFrame(frame_decoded_size_)) { + LOG(INFO) << "num frame decode: " << frame_decoded_size_; + AdvanceDecoding(decodable.get()); + } +} + +void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { + decoder_->AdvanceDecoding(decodable, 1); + frame_decoded_size_++; +} + +void TLGDecoder::Reset() { + InitDecoder(); + return; +} + +std::string TLGDecoder::GetFinalBestPath() { + decoder_->FinalizeDecoding(); + kaldi::Lattice lat; + kaldi::LatticeWeight weight; + std::vector alignment; + std::vector words_id; + decoder_->GetBestPath(&lat, true); + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + std::string words; + for (int32 idx = 0; idx < words_id.size(); ++idx) { + std::string word = word_symbol_table_->Find(words_id[idx]); + words += word; + } + return words; +} +} \ No newline at end of file 
diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h new file mode 100644 index 00000000..361c44af --- /dev/null +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/basic_types.h" +#include "kaldi/decoder/decodable-itf.h" +#include "kaldi/decoder/lattice-faster-online-decoder.h" +#include "util/parse-options.h" + +namespace ppspeech { + +struct TLGDecoderOptions { + kaldi::LatticeFasterDecoderConfig opts; + // todo remove later, add into decode resource + std::string word_symbol_table; + std::string fst_path; + + TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} +}; + +class TLGDecoder { + public: + explicit TLGDecoder(TLGDecoderOptions opts); + void InitDecoder(); + void Decode(); + std::string GetBestPath(); + std::vector> GetNBestPath(); + std::string GetFinalBestPath(); + int NumFrameDecoded(); + int DecodeLikelihoods(const std::vector>& probs, + std::vector& nbest_words); + void AdvanceDecode( + const std::shared_ptr& decodable); + void Reset(); + + private: + void AdvanceDecoding(kaldi::DecodableInterface* decodable); + + std::shared_ptr decoder_; + std::shared_ptr> fst_; + std::shared_ptr word_symbol_table_; + // the frame size which have decoded starts from 0. 
+ int32 frame_decoded_size_; +}; + + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/cmvn.cc b/speechx/speechx/frontend/audio/cmvn.cc index 4c1ffd6a..c7e446c9 100644 --- a/speechx/speechx/frontend/audio/cmvn.cc +++ b/speechx/speechx/frontend/audio/cmvn.cc @@ -120,4 +120,4 @@ void CMVN::ApplyCMVN(kaldi::MatrixBase* feats) { ApplyCmvn(stats_, var_norm_, feats); } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt index 414a6fa0..6f7398cd 100644 --- a/speechx/speechx/kaldi/CMakeLists.txt +++ b/speechx/speechx/kaldi/CMakeLists.txt @@ -4,3 +4,6 @@ add_subdirectory(base) add_subdirectory(util) add_subdirectory(feat) add_subdirectory(matrix) +add_subdirectory(lat) +add_subdirectory(fstext) +add_subdirectory(decoder) diff --git a/speechx/speechx/kaldi/decoder/CMakeLists.txt b/speechx/speechx/kaldi/decoder/CMakeLists.txt new file mode 100644 index 00000000..f1ee6eab --- /dev/null +++ b/speechx/speechx/kaldi/decoder/CMakeLists.txt @@ -0,0 +1,6 @@ + +add_library(kaldi-decoder +lattice-faster-decoder.cc +lattice-faster-online-decoder.cc +) +target_link_libraries(kaldi-decoder PUBLIC kaldi-lat) diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/kaldi/decoder/decodable-itf.h similarity index 98% rename from speechx/speechx/nnet/decodable-itf.h rename to speechx/speechx/kaldi/decoder/decodable-itf.h index 8e9a5a72..b8ce9143 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/kaldi/decoder/decodable-itf.h @@ -121,7 +121,7 @@ class DecodableInterface { /// decoding-from-matrix setting where we want to allow the last delta or /// LDA /// features to be flushed out for compatibility with the baseline setup. 
- virtual bool IsLastFrame(int32 frame) const = 0; + virtual bool IsLastFrame(int32 frame) = 0; /// The call NumFramesReady() will return the number of frames currently /// available @@ -143,7 +143,7 @@ class DecodableInterface { /// this is for compatibility with OpenFst). virtual int32 NumIndices() const = 0; - virtual bool FrameLogLikelihood( + virtual bool FrameLikelihood( int32 frame, std::vector* likelihood) = 0; diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc index 42d1d2af..ae6b7160 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc @@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl, decoder::StdToken> template class LatticeFasterDecoderTpl, decoder::StdToken >; template class LatticeFasterDecoderTpl, decoder::StdToken >; -template class LatticeFasterDecoderTpl; -template class LatticeFasterDecoderTpl; template class LatticeFasterDecoderTpl , decoder::BackpointerToken>; template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; -template class LatticeFasterDecoderTpl; -template class LatticeFasterDecoderTpl; } // end namespace kaldi. 
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h index 2016ad57..d142a8c7 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h @@ -23,11 +23,10 @@ #ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #define KALDI_DECODER_LATTICE_FASTER_DECODER_H_ -#include "decoder/grammar-fst.h" #include "fst/fstlib.h" #include "fst/memory.h" #include "fstext/fstext-lib.h" -#include "itf/decodable-itf.h" +#include "decoder/decodable-itf.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" #include "util/hash-list.h" diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc index ebdace7e..b5261503 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc @@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl::GetRawLatticePruned( template class LatticeFasterOnlineDecoderTpl >; template class LatticeFasterOnlineDecoderTpl >; template class LatticeFasterOnlineDecoderTpl >; -template class LatticeFasterOnlineDecoderTpl; -template class LatticeFasterOnlineDecoderTpl; +//template class LatticeFasterOnlineDecoderTpl; +//template class LatticeFasterOnlineDecoderTpl; } // end namespace kaldi. 
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h index 8b10996f..f57368a4 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h @@ -30,7 +30,7 @@ #include "util/stl-utils.h" #include "util/hash-list.h" #include "fst/fstlib.h" -#include "itf/decodable-itf.h" +#include "decoder/decodable-itf.h" #include "fstext/fstext-lib.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/speechx/speechx/kaldi/fstext/CMakeLists.txt new file mode 100644 index 00000000..af91fd98 --- /dev/null +++ b/speechx/speechx/kaldi/fstext/CMakeLists.txt @@ -0,0 +1,5 @@ + +add_library(kaldi-fstext +kaldi-fst-io.cc +) +target_link_libraries(kaldi-fstext PUBLIC kaldi-util) diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h new file mode 100644 index 00000000..0bfbc8f4 --- /dev/null +++ b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h @@ -0,0 +1,1357 @@ +// fstext/determinize-lattice-inl.h + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2013 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ +#define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ +// Do not include this file directly. It is included by determinize-lattice.h + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// This class maps back and forth from/to integer id's to sequences of strings. +// used in determinization algorithm. It is constructed in such a way that +// finding the string-id of the successor of (string, next-label) has constant +// time. + +// Note: class IntType, typically int32, is the type of the element in the +// string (typically a template argument of the CompactLatticeWeightTpl). + +template +class LatticeStringRepository { + public: + struct Entry { + const Entry *parent; // NULL for empty string. + IntType i; + inline bool operator==(const Entry &other) const { + return (parent == other.parent && i == other.i); + } + Entry() {} + Entry(const Entry &e) : parent(e.parent), i(e.i) {} + }; + // Note: all Entry* pointers returned in function calls are + // owned by the repository itself, not by the caller! + + // Interface guarantees empty string is NULL. + inline const Entry *EmptyString() { return NULL; } + + // Returns string of "parent" with i appended. Pointer + // owned by repository + const Entry *Successor(const Entry *parent, IntType i) { + new_entry_->parent = parent; + new_entry_->i = i; + + std::pair pr = set_.insert(new_entry_); + if (pr.second) { // Was successfully inserted (was not there). We need to + // replace the element we inserted, which resides on the + // stack, with one from the heap. + const Entry *ans = new_entry_; + new_entry_ = new Entry(); + return ans; + } else { // Was not inserted because an equivalent Entry already + // existed. 
+ return *pr.first; + } + } + + const Entry *Concatenate(const Entry *a, const Entry *b) { + if (a == NULL) + return b; + else if (b == NULL) + return a; + std::vector v; + ConvertToVector(b, &v); + const Entry *ans = a; + for (size_t i = 0; i < v.size(); i++) ans = Successor(ans, v[i]); + return ans; + } + const Entry *CommonPrefix(const Entry *a, const Entry *b) { + std::vector a_vec, b_vec; + ConvertToVector(a, &a_vec); + ConvertToVector(b, &b_vec); + const Entry *ans = NULL; + for (size_t i = 0; + i < a_vec.size() && i < b_vec.size() && a_vec[i] == b_vec[i]; i++) + ans = Successor(ans, a_vec[i]); + return ans; + } + + // removes any elements from b that are not part of + // a common prefix with a. + void ReduceToCommonPrefix(const Entry *a, std::vector *b) { + size_t a_size = Size(a), b_size = b->size(); + while (a_size > b_size) { + a = a->parent; + a_size--; + } + if (b_size > a_size) b_size = a_size; + typename std::vector::iterator b_begin = b->begin(); + while (a_size != 0) { + if (a->i != *(b_begin + a_size - 1)) b_size = a_size - 1; + a = a->parent; + a_size--; + } + if (b_size != b->size()) b->resize(b_size); + } + + // removes the first n elements of a. + const Entry *RemovePrefix(const Entry *a, size_t n) { + if (n == 0) return a; + std::vector a_vec; + ConvertToVector(a, &a_vec); + assert(a_vec.size() >= n); + const Entry *ans = NULL; + for (size_t i = n; i < a_vec.size(); i++) ans = Successor(ans, a_vec[i]); + return ans; + } + + // Returns true if a is a prefix of b. If a is prefix of b, + // time taken is |b| - |a|. Else, time taken is |b|. + bool IsPrefixOf(const Entry *a, const Entry *b) const { + if (a == NULL) return true; // empty string prefix of all. 
+ if (a == b) return true; + if (b == NULL) return false; + return IsPrefixOf(a, b->parent); + } + + inline size_t Size(const Entry *entry) const { + size_t ans = 0; + while (entry != NULL) { + ans++; + entry = entry->parent; + } + return ans; + } + + void ConvertToVector(const Entry *entry, std::vector *out) const { + size_t length = Size(entry); + out->resize(length); + if (entry != NULL) { + typename std::vector::reverse_iterator iter = out->rbegin(); + while (entry != NULL) { + *iter = entry->i; + entry = entry->parent; + ++iter; + } + } + } + + const Entry *ConvertFromVector(const std::vector &vec) { + const Entry *e = NULL; + for (size_t i = 0; i < vec.size(); i++) e = Successor(e, vec[i]); + return e; + } + + LatticeStringRepository() { new_entry_ = new Entry; } + + void Destroy() { + for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); + ++iter) + delete *iter; + SetType tmp; + tmp.swap(set_); + if (new_entry_) { + delete new_entry_; + new_entry_ = NULL; + } + } + + // Rebuild will rebuild this object, guaranteeing only + // to preserve the Entry values that are in the vector pointed + // to (this list does not have to be unique). The point of + // this is to save memory. + void Rebuild(const std::vector &to_keep) { + SetType tmp_set; + for (typename std::vector::const_iterator iter = + to_keep.begin(); + iter != to_keep.end(); ++iter) + RebuildHelper(*iter, &tmp_set); + // Now delete all elems not in tmp_set. + for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); + ++iter) { + if (tmp_set.count(*iter) == 0) + delete (*iter); // delete the Entry; not needed. + } + set_.swap(tmp_set); + } + + ~LatticeStringRepository() { Destroy(); } + int32 MemSize() const { + return set_.size() * sizeof(Entry) * 2; // this is a lower bound + // on the size this structure might take. + } + + private: + class EntryKey { // Hash function object. 
+ public: + inline size_t operator()(const Entry *entry) const { + size_t prime = 49109; + return static_cast(entry->i) + + prime * reinterpret_cast(entry->parent); + } + }; + class EntryEqual { + public: + inline bool operator()(const Entry *e1, const Entry *e2) const { + return (*e1 == *e2); + } + }; + typedef std::unordered_set SetType; + + void RebuildHelper(const Entry *to_add, SetType *tmp_set) { + while (true) { + if (to_add == NULL) return; + typename SetType::iterator iter = tmp_set->find(to_add); + if (iter == tmp_set->end()) { // not in tmp_set. + tmp_set->insert(to_add); + to_add = to_add->parent; // and loop. + } else { + return; + } + } + } + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository); + Entry *new_entry_; // We always have a pre-allocated Entry ready to use, + // to avoid unnecessary news and deletes. + SetType set_; +}; + +// class LatticeDeterminizer is templated on the same types that +// CompactLatticeWeight is templated on: the base weight (Weight), typically +// LatticeWeightTpl etc. but could also be e.g. TropicalWeight, and the +// IntType, typically int32, used for the output symbols in the compact +// representation of strings [note: the output symbols would usually be +// p.d.f. id's in the anticipated use of this code] It has a special requirement +// on the Weight type: that there should be a Compare function on the weights +// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 +// > w2. This requires that there be a total order on the weights. + +template +class LatticeDeterminizer { + public: + // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 + // correspondence between our states and the states in ofst. If destroy == + // true, release memory as we go (but we cannot output again). 
+ + typedef CompactLatticeWeightTpl CompactWeight; + typedef ArcTpl + CompactArc; // arc in compact, acceptor form of lattice + typedef ArcTpl Arc; // arc in non-compact version of lattice + + // Output to standard FST with CompactWeightTpl as its weight type + // (the weight stores the original output-symbol strings). If destroy == + // true, release memory as we go (but we cannot output again). + void Output(MutableFst *ofst, bool destroy = true) { + assert(determinized_); + typedef typename Arc::StateId StateId; + StateId nStates = static_cast(output_arcs_.size()); + if (destroy) FreeMostMemory(); + ofst->DeleteStates(); + ofst->SetStart(kNoStateId); + if (nStates == 0) { + return; + } + for (StateId s = 0; s < nStates; s++) { + OutputStateId news = ofst->AddState(); + assert(news == s); + } + ofst->SetStart(0); + // now process transitions. + for (StateId this_state = 0; this_state < nStates; this_state++) { + std::vector &this_vec(output_arcs_[this_state]); + typename std::vector::const_iterator iter = this_vec.begin(), + end = this_vec.end(); + + for (; iter != end; ++iter) { + const TempArc &temp_arc(*iter); + CompactArc new_arc; + std::vector