Merge pull request #1400 from SmileGoat/feature_dev
[speechx]add linear spectrogram feature extractorpull/1495/head
commit
b584b9690f
@ -0,0 +1,36 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
#include <iostream>
|
||||||
|
#include <istream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <ostream>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <stack>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
#include "base/log.h"
|
||||||
|
#include "base/flags.h"
|
||||||
|
#include "base/basic_types.h"
|
||||||
|
#include "base/macros.h"
|
@ -0,0 +1,17 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "gflags/gflags.h"
|
@ -0,0 +1,17 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "glog/logging.h"
|
@ -0,0 +1,120 @@
|
|||||||
|
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
|
||||||
|
|
||||||
|
// This software is provided 'as-is', without any express or implied
|
||||||
|
// warranty. In no event will the authors be held liable for any damages
|
||||||
|
// arising from the use of this software.
|
||||||
|
|
||||||
|
// Permission is granted to anyone to use this software for any purpose,
|
||||||
|
// including commercial applications, and to alter it and redistribute it
|
||||||
|
// freely, subject to the following restrictions:
|
||||||
|
|
||||||
|
// 1. The origin of this software must not be misrepresented; you must not
|
||||||
|
// claim that you wrote the original software. If you use this software
|
||||||
|
// in a product, an acknowledgment in the product documentation would be
|
||||||
|
// appreciated but is not required.
|
||||||
|
|
||||||
|
// 2. Altered source versions must be plainly marked as such, and must not be
|
||||||
|
// misrepresented as being the original software.
|
||||||
|
|
||||||
|
// 3. This notice may not be removed or altered from any source
|
||||||
|
// distribution.
|
||||||
|
// this code is from https://github.com/progschj/ThreadPool
|
||||||
|
|
||||||
|
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

// Fixed-size pool of worker threads executing queued tasks.
// From https://github.com/progschj/ThreadPool (zlib license, see above).
class ThreadPool {
  public:
    // Launches `threads` workers that pop and run tasks until destruction.
    ThreadPool(size_t threads);
    // Schedules f(args...) and returns a future for its result.
    // Throws std::runtime_error if the pool is already stopped.
    // NOTE: std::result_of is deprecated in C++17 and removed in C++20;
    // std::invoke_result is the supported equivalent.
    template <class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::invoke_result<F, Args...>::type>;
    // Drains remaining tasks, then joins all workers.
    ~ThreadPool();

  private:
    // need to keep track of threads so we can join them
    std::vector<std::thread> workers;
    // the task queue
    std::queue<std::function<void()>> tasks;

    // synchronization
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};

// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
    for (size_t i = 0; i < threads; ++i) {
        workers.emplace_back([this] {
            for (;;) {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(this->queue_mutex);
                    this->condition.wait(lock, [this] {
                        return this->stop || !this->tasks.empty();
                    });
                    // Exit only when stopped AND the queue is drained.
                    if (this->stop && this->tasks.empty()) return;
                    task = std::move(this->tasks.front());
                    this->tasks.pop();
                }
                // Run the task outside the lock so workers stay parallel.
                task();
            }
        });
    }
}

// add new work item to the pool
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::invoke_result<F, Args...>::type> {
    using return_type = typename std::invoke_result<F, Args...>::type;

    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));

    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        // don't allow enqueueing after stopping the pool
        if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_one();
    return res;
}

// the destructor joins all threads
inline ThreadPool::~ThreadPool() {
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers) {
        worker.join();
    }
}

#endif  // BASE_THREAD_POOL_H
|
@ -0,0 +1,58 @@
|
|||||||
|
// todo refactor, replace with gtest

#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"

DEFINE_string(feature_respecifier, "", "test nnet prob");

using kaldi::BaseFloat;

// Splits `feature` into row chunks of `chunk_size` frames, appended to
// *feature_chunks.  TODO(SmileGoat): implement; currently a stub.
// NOTE(review): the original took the output vector by value but the call
// site passes its address; an output pointer matches the caller.
void SplitFeature(const kaldi::Matrix<BaseFloat>& feature,
                  int32 chunk_size,
                  std::vector<kaldi::Matrix<BaseFloat>>* feature_chunks) {
}

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_respecifier);

    // test nnet_output --> decoder result
    int32 num_done = 0, num_err = 0;

    CTCBeamSearchOptions opts;
    CTCBeamSearch decoder(opts);

    ModelOptions model_opts;
    std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));

    // NOTE(review): `Decodable decodable();` in the original declared a
    // function (most vexing parse), so the member calls below could not
    // compile; it must be an object.
    Decodable decodable;
    decodable.SetNnet(nnet);

    int32 chunk_size = 0;
    for (; !feature_reader.Done(); feature_reader.Next()) {
        std::string utt = feature_reader.Key();
        const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
        std::vector<kaldi::Matrix<BaseFloat>> feature_chunks;
        SplitFeature(feature, chunk_size, &feature_chunks);
        // Reset decoder state once per utterance; the original re-ran
        // InitDecoder() inside the chunk loop, discarding earlier chunks.
        decoder.InitDecoder();
        for (const auto& feature_chunk : feature_chunks) {
            decodable.FeedFeatures(feature_chunk);
            decoder.AdvanceDecode(decodable, chunk_size);
        }
        decodable.InputFinished();
        std::string result;
        result = decoder.GetFinalBestPath();
        KALDI_LOG << " the result of " << utt << " is " << result;
        decodable.Reset();
        ++num_done;
    }

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
|
@ -0,0 +1,47 @@
|
|||||||
|
// todo refactor, replace with gtest

#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h"

DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);

    // test feature linear_spectrogram:
    //   wave --> decibel_normalizer --> hanning window --> linear_spectrogram
    //   --> cmvn
    int32 num_done = 0, num_err = 0;
    ppspeech::LinearSpectrogramOptions opt;
    ppspeech::DecibelNormalizerOptions db_norm_opt;
    std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
        new ppspeech::DecibelNormalizer(db_norm_opt));
    ppspeech::LinearSpectrogram linear_spectrogram(
        opt, std::move(base_feature_extractor));

    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        // NOTE(review): the original wrote Matrix<BaseFloat> with BaseFloat
        // unqualified; it is only declared in the kaldi namespace here.
        kaldi::Matrix<kaldi::BaseFloat> features;
        linear_spectrogram.AcceptWaveform(waveform);
        linear_spectrogram.ReadFeats(&features);

        feat_writer.Write(utt, features);
        if (num_done % 50 == 0 && num_done != 0)
            KALDI_VLOG(2) << "Processed " << num_done << " utterances";
        num_done++;
    }
    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
|
@ -1,2 +1,10 @@
|
|||||||
project(decoder)

# NOTE(review): the original wrote ${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders} —
# the closing brace was misplaced, so the variable never expanded.
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders)

add_library(decoder
    ctc_beam_search_decoder.cc
    ctc_decoders/decoder_utils.cpp
    ctc_decoders/path_trie.cpp
    ctc_decoders/scorer.cpp
)

target_link_libraries(decoder kenlm)
|
@ -0,0 +1,7 @@
|
|||||||
|
#pragma once

#include <utility>
#include <vector>

#include "base/basic_types.h"

// Result of decoding one utterance.
struct DecoderResult {
    // Acoustic log-likelihood of the best path.
    BaseFloat acoustic_score;
    // Decoded token ids, in output order.
    std::vector<int32> words_idx;
    // Per-token (begin, end) positions.  NOTE(review): presumably frame
    // indices — confirm against the decoder that fills this in.
    std::vector<std::pair<int32, int32>> time_stamp;
};
|
@ -0,0 +1,300 @@
|
|||||||
|
#include "decoder/ctc_beam_search_decoder.h"
|
||||||
|
|
||||||
|
#include "base/basic_types.h"
|
||||||
|
#include "decoder/ctc_decoders/decoder_utils.h"
|
||||||
|
#include "utils/file_utils.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||||
|
|
||||||
|
// Loads the vocabulary and builds the external LM scorer from `opts`.
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
    : opts_(opts),
      init_ext_scorer_(nullptr),
      blank_id(-1),
      space_id(-1),
      num_frame_decoded_(0),
      root(nullptr) {
    LOG(INFO) << "dict path: " << opts_.dict_file;
    if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
        LOG(INFO) << "load the dict failed";
    }
    LOG(INFO) << "read the vocabulary success, dict size: "
              << vocabulary_.size();

    LOG(INFO) << "language model path: " << opts_.lm_path;
    init_ext_scorer_ = std::make_shared<Scorer>(
        opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}
|
||||||
|
|
||||||
|
// Clears per-utterance state so the next utterance starts fresh.
void CTCBeamSearch::Reset() {
    num_frame_decoded_ = 0;
    ResetPrefixes();
}
|
||||||
|
|
||||||
|
void CTCBeamSearch::InitDecoder() {
|
||||||
|
|
||||||
|
blank_id = 0;
|
||||||
|
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
|
||||||
|
|
||||||
|
space_id = it - vocabulary_.begin();
|
||||||
|
// if no space in vocabulary
|
||||||
|
if ((size_t)space_id >= vocabulary_.size()) {
|
||||||
|
space_id = -2;
|
||||||
|
}
|
||||||
|
|
||||||
|
ResetPrefixes();
|
||||||
|
|
||||||
|
root = std::make_shared<PathTrie>();
|
||||||
|
root->score = root->log_prob_b_prev = 0.0;
|
||||||
|
prefixes.push_back(root.get());
|
||||||
|
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
|
||||||
|
auto fst_dict =
|
||||||
|
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
|
||||||
|
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
|
||||||
|
root->set_dictionary(dict_ptr);
|
||||||
|
|
||||||
|
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
||||||
|
root->set_matcher(matcher);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(SmileGoat): not implemented; decoding is driven via AdvanceDecode.
void CTCBeamSearch::Decode(
    std::shared_ptr<kaldi::DecodableInterface> decodable) {
    return;
}
|
||||||
|
|
||||||
|
// Number of acoustic frames consumed so far in the current utterance.
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
|
||||||
|
|
||||||
|
// todo rename, refactor
|
||||||
|
// todo rename, refactor
// Consumes at most `max_frames` frames from `decodable`, one at a time.
void CTCBeamSearch::AdvanceDecode(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable,
    int max_frames) {
    for (; max_frames > 0; --max_frames) {
        if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
            break;
        }
        vector<vector<BaseFloat>> likelihood;
        likelihood.push_back(
            decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
        AdvanceDecoding(likelihood);
    }
}
|
||||||
|
|
||||||
|
void CTCBeamSearch::ResetPrefixes() {
|
||||||
|
for (size_t i = 0; i < prefixes.size(); i++) {
|
||||||
|
if (prefixes[i] != nullptr) {
|
||||||
|
delete prefixes[i];
|
||||||
|
prefixes[i] = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs beam search over a batch of per-frame posteriors, timing the pass.
// `nbest_words` is currently unused.  Always returns 0.
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
                                     vector<string>& nbest_words) {
    kaldi::Timer timer;
    timer.Reset();
    AdvanceDecoding(probs);
    LOG(INFO) << "ctc decoding elapsed time(s) "
              << static_cast<float>(timer.Elapsed()) / 1000.0f;
    return 0;
}
|
||||||
|
|
||||||
|
// Returns ranked (score, transcript) pairs for the current beam.
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
    return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
}
|
||||||
|
|
||||||
|
// Returns the highest-scoring transcript of the current beam, or an empty
// string when the beam is empty.
string CTCBeamSearch::GetBestPath() {
    std::vector<std::pair<double, std::string>> result;
    result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
    // Guard: the original indexed result[0] unconditionally, which is UB if
    // no prefix has been produced yet (e.g. called before decoding).
    if (result.empty()) {
        return "";
    }
    return result[0].second;
}
|
||||||
|
|
||||||
|
// Finalizes scores (approximate CTC score, then LM rescoring) and returns
// the best transcript.
string CTCBeamSearch::GetFinalBestPath() {
    CalculateApproxScore();
    LMRescore();
    return GetBestPath();
}
|
||||||
|
|
||||||
|
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
|
||||||
|
size_t num_time_steps = probs.size();
|
||||||
|
size_t beam_size = opts_.beam_size;
|
||||||
|
double cutoff_prob = opts_.cutoff_prob;
|
||||||
|
size_t cutoff_top_n = opts_.cutoff_top_n;
|
||||||
|
|
||||||
|
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
|
||||||
|
|
||||||
|
int row = probs.size();
|
||||||
|
int col = probs[0].size();
|
||||||
|
for(int i = 0; i < row; i++) {
|
||||||
|
for (int j = 0; j < col; j++){
|
||||||
|
probs_seq[i][j] = static_cast<double>(probs[i][j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
|
||||||
|
const auto& prob = probs_seq[time_step];
|
||||||
|
|
||||||
|
float min_cutoff = -NUM_FLT_INF;
|
||||||
|
bool full_beam = false;
|
||||||
|
if (init_ext_scorer_ != nullptr) {
|
||||||
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
||||||
|
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
|
||||||
|
prefix_compare);
|
||||||
|
|
||||||
|
if (num_prefixes == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
min_cutoff = prefixes[num_prefixes - 1]->score +
|
||||||
|
std::log(prob[blank_id]) -
|
||||||
|
std::max(0.0, init_ext_scorer_->beta);
|
||||||
|
|
||||||
|
full_beam = (num_prefixes == beam_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<std::pair<size_t, float>> log_prob_idx =
|
||||||
|
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
|
||||||
|
|
||||||
|
// loop over chars
|
||||||
|
size_t log_prob_idx_len = log_prob_idx.size();
|
||||||
|
for (size_t index = 0; index < log_prob_idx_len; index++) {
|
||||||
|
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixes.clear();
|
||||||
|
|
||||||
|
// update log probs
|
||||||
|
root->iterate_to_vec(prefixes);
|
||||||
|
// only preserve top beam_size prefixes
|
||||||
|
if (prefixes.size() >= beam_size) {
|
||||||
|
std::nth_element(prefixes.begin(),
|
||||||
|
prefixes.begin() + beam_size,
|
||||||
|
prefixes.end(),
|
||||||
|
prefix_compare);
|
||||||
|
for (size_t i = beam_size; i < prefixes.size(); ++i) {
|
||||||
|
prefixes[i]->remove();
|
||||||
|
}
|
||||||
|
} // if
|
||||||
|
num_frame_decoded_++;
|
||||||
|
} // for probs_seq
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extends every beam prefix with one candidate character `log_prob_idx`,
// updating blank/non-blank probabilities and applying LM scores at word
// boundaries.  Always returns 0.
int32 CTCBeamSearch::SearchOneChar(
    const bool& full_beam,
    const std::pair<size_t, BaseFloat>& log_prob_idx,
    const BaseFloat& min_cutoff) {
    size_t beam_size = opts_.beam_size;
    const auto& c = log_prob_idx.first;
    const auto& log_prob_c = log_prob_idx.second;
    size_t prefixes_len = std::min(prefixes.size(), beam_size);

    for (size_t i = 0; i < prefixes_len; ++i) {
        auto* prefix = prefixes[i];
        // Prefixes are sorted by score; once below the cutoff, stop early.
        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
            break;
        }

        if (c == blank_id) {
            // Blank emission: accumulate into the blank-ending probability.
            prefix->log_prob_b_cur = log_sum_exp(prefix->log_prob_b_cur,
                                                 log_prob_c + prefix->score);
            continue;
        }

        // repeated character
        if (c == prefix->character) {
            // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
            prefix->log_prob_nb_cur =
                log_sum_exp(prefix->log_prob_nb_cur,
                            log_prob_c + prefix->log_prob_nb_prev);
        }

        // get new prefix
        auto* prefix_new = prefix->get_path_trie(c);
        if (prefix_new != nullptr) {
            float log_p = -NUM_FLT_INF;
            if (c == prefix->character &&
                prefix->log_prob_b_prev > -NUM_FLT_INF) {
                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
                log_p = log_prob_c + prefix->log_prob_b_prev;
            } else if (c != prefix->character) {
                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
                log_p = log_prob_c + prefix->score;
            }

            // language model scoring
            if (init_ext_scorer_ != nullptr &&
                (c == space_id || init_ext_scorer_->is_character_based())) {
                PathTrie* prefix_to_score = nullptr;
                // skip scoring the space
                if (init_ext_scorer_->is_character_based()) {
                    prefix_to_score = prefix_new;
                } else {
                    prefix_to_score = prefix;
                }

                float score = 0.0;
                vector<string> ngram =
                    init_ext_scorer_->make_ngram(prefix_to_score);
                // lm score: p_{lm}(W)^{\alpha} + \beta
                score = init_ext_scorer_->get_log_cond_prob(ngram) *
                        init_ext_scorer_->alpha;
                log_p += score;
                log_p += init_ext_scorer_->beta;
            }
            // p_{nb}(l;x_{1:t})
            prefix_new->log_prob_nb_cur =
                log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
        }
    }  // end of loop over prefix
    return 0;
}
|
||||||
|
|
||||||
|
void CTCBeamSearch::CalculateApproxScore() {
|
||||||
|
size_t beam_size = opts_.beam_size;
|
||||||
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
||||||
|
std::sort(
|
||||||
|
prefixes.begin(),
|
||||||
|
prefixes.begin() + num_prefixes,
|
||||||
|
prefix_compare);
|
||||||
|
|
||||||
|
// compute aproximate ctc score as the return score, without affecting the
|
||||||
|
// return order of decoding result. To delete when decoder gets stable.
|
||||||
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||||
|
double approx_ctc = prefixes[i]->score;
|
||||||
|
if (init_ext_scorer_ != nullptr) {
|
||||||
|
vector<int> output;
|
||||||
|
prefixes[i]->get_path_vec(output);
|
||||||
|
auto prefix_length = output.size();
|
||||||
|
auto words = init_ext_scorer_->split_labels(output);
|
||||||
|
// remove word insert
|
||||||
|
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
|
||||||
|
// remove language model weight:
|
||||||
|
approx_ctc -=
|
||||||
|
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
|
||||||
|
}
|
||||||
|
prefixes[i]->approx_ctc = approx_ctc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CTCBeamSearch::LMRescore() {
|
||||||
|
size_t beam_size = opts_.beam_size;
|
||||||
|
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
|
||||||
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||||
|
auto prefix = prefixes[i];
|
||||||
|
if (!prefix->is_empty() && prefix->character != space_id) {
|
||||||
|
float score = 0.0;
|
||||||
|
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
|
||||||
|
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
|
||||||
|
score += init_ext_scorer_->beta;
|
||||||
|
prefix->score += score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,79 @@
|
|||||||
|
#pragma once

#include "base/common.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"

namespace ppspeech {

// Options for CTC prefix beam search, registrable on a kaldi option parser.
struct CTCBeamSearchOptions {
    std::string dict_file;  // vocabulary file, one symbol per line
    std::string lm_path;    // n-gram language model path
    BaseFloat alpha;        // LM weight
    BaseFloat beta;         // word-insertion bonus
    BaseFloat cutoff_prob;  // cumulative probability mass kept per frame
    int beam_size;
    int cutoff_top_n;       // max candidate characters per frame
    int num_proc_bsearch;   // worker count for batched beam search

    // NOTE: member initializers follow the declaration order above; the
    // original listed beam_size before cutoff_prob, triggering -Wreorder.
    CTCBeamSearchOptions()
        : dict_file("./model/words.txt"),
          lm_path("./model/lm.arpa"),
          alpha(1.9f),
          beta(5.0),
          cutoff_prob(0.99f),
          beam_size(300),
          cutoff_top_n(40),
          num_proc_bsearch(0) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register("dict", &dict_file, "dict file ");
        opts->Register("lm-path", &lm_path, "language model file");
        opts->Register("alpha", &alpha, "alpha");
        opts->Register("beta", &beta, "beta");
        opts->Register("beam-size", &beam_size,
                       "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
        opts->Register("num-proc-bsearch", &num_proc_bsearch,
                       "num proc bsearch");
    }
};

// CTC prefix beam search decoder with optional external LM scorer.
class CTCBeamSearch {
  public:
    explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
    ~CTCBeamSearch() {}
    // Resets per-utterance state and (re)builds the prefix trie.
    void InitDecoder();
    void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
    std::string GetBestPath();
    std::vector<std::pair<double, std::string>> GetNBestPath();
    std::string GetFinalBestPath();
    int NumFrameDecoded();
    int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                          std::vector<std::string>& nbest_words);
    // Consumes up to max_frames frames from the decodable object.
    void AdvanceDecode(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable,
        int max_frames);
    void Reset();

  private:
    void ResetPrefixes();
    int32 SearchOneChar(const bool& full_beam,
                        const std::pair<size_t, BaseFloat>& log_prob_idx,
                        const BaseFloat& min_cutoff);
    void CalculateApproxScore();
    void LMRescore();
    void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);

    CTCBeamSearchOptions opts_;
    std::shared_ptr<Scorer> init_ext_scorer_;  // todo separate later
    // std::vector<DecodeResult> decoder_results_;
    std::vector<std::string> vocabulary_;  // todo remove later
    size_t blank_id;
    int space_id;
    std::shared_ptr<PathTrie> root;
    std::vector<PathTrie*> prefixes;
    int num_frame_decoded_;
    DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};

}  // namespace ppspeech
|
@ -0,0 +1 @@
|
|||||||
|
../../../third_party/ctc_decoders
|
@ -0,0 +1,8 @@
|
|||||||
|
project(frontend)

add_library(frontend
    normalizer.cc
    linear_spectrogram.cc
)

target_link_libraries(frontend kaldi-matrix)
|
@ -0,0 +1,36 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// wrap the fbank feat of kaldi, todo (SmileGoat)
|
||||||
|
|
||||||
|
#pragma once

#include "kaldi/feat/feature-mfcc.h"
// NOTE(review): the original spelled this directive "#incldue", which does
// not compile.
#include "kaldi/matrix/kaldi-vector.h"

namespace ppspeech {

// Fbank extractor wrapping kaldi; still a stub (todo SmileGoat).
// NOTE(review): inheritance was implicitly private in the original, which
// would prevent using FbankExtractor through the interface pointer.
class FbankExtractor : public FeatureExtractorInterface {
  public:
    explicit FbankExtractor(
        const FbankOptions& opts,
        std::shared_ptr<FeatureExtractorInterface> pre_extractor);
    virtual void AcceptWaveform(
        const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
    virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
    virtual size_t Dim() const = 0;

  private:
    bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
                 kaldi::Vector<kaldi::BaseFloat>* feat) const;
};

}  // namespace ppspeech
|
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/basic_types.h"
|
||||||
|
#include "kaldi/matrix/kaldi-vector.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class FeatureExtractorInterface {
|
||||||
|
public:
|
||||||
|
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
|
||||||
|
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
|
||||||
|
virtual size_t Dim() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,179 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "frontend/linear_spectrogram.h"
|
||||||
|
#include "kaldi/base/kaldi-math.h"
|
||||||
|
#include "kaldi/matrix/matrix-functions.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
using kaldi::int32;
|
||||||
|
using kaldi::BaseFloat;
|
||||||
|
using kaldi::Vector;
|
||||||
|
using kaldi::VectorBase;
|
||||||
|
using kaldi::Matrix;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
//todo remove later
|
||||||
|
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
|
||||||
|
vector<BaseFloat>* output) {
|
||||||
|
if (input.Dim() == 0) return;
|
||||||
|
output->resize(input.Dim());
|
||||||
|
for (size_t idx = 0; idx < input.Dim(); ++idx) {
|
||||||
|
(*output)[idx] = input(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CopyStdVector2Vector_(const vector<BaseFloat>& input,
|
||||||
|
Vector<BaseFloat>* output) {
|
||||||
|
if (input.empty()) return;
|
||||||
|
output->Resize(input.size());
|
||||||
|
for (size_t idx = 0; idx < input.size(); ++idx) {
|
||||||
|
(*output)(idx) = input[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the spectrogram extractor on top of a base extractor (typically the
// decibel normalizer).  Precomputes the Hanning window and its energy, which
// Compute() uses to scale the power spectrum.
LinearSpectrogram::LinearSpectrogram(
    const LinearSpectrogramOptions& opts,
    std::unique_ptr<FeatureExtractorInterface> base_extractor) {
    // Bug fix: opts_ was never assigned, so Compute() read a
    // default-constructed LinearSpectrogramOptions instead of the caller's
    // frame options (window size / shift / sample rate).
    opts_ = opts;
    base_extractor_ = std::move(base_extractor);
    int32 window_size = opts.frame_opts.WindowSize();
    fft_points_ = window_size;
    hanning_window_.resize(window_size);

    // Hanning window: w[i] = 0.5 - 0.5*cos(2*pi*i/(N-1)).
    double a = M_2PI / (window_size - 1);
    hanning_window_energy_ = 0;
    for (int i = 0; i < window_size; ++i) {
        hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
        hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
    }

    dim_ = fft_points_ / 2 + 1;  // the dimension is Fs/2 Hz
}
|
||||||
|
|
||||||
|
// Forwards raw waveform samples to the wrapped base extractor; the
// spectrogram itself is computed later, in Compute()/ReadFeats().
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
    base_extractor_->AcceptWaveform(input);
}
|
||||||
|
|
||||||
|
// Applies the precomputed Hanning window in place to the first
// hanning_window_.size() samples of *data.  The frame must be at least as
// long as the window.
void LinearSpectrogram::Hanning(vector<float>* data) const {
    CHECK_GE(data->size(), hanning_window_.size());

    size_t idx = 0;
    for (const float coeff : hanning_window_) {
        (*data)[idx] *= coeff;
        ++idx;
    }
}
|
||||||
|
|
||||||
|
// Computes a real FFT of *v (in place, via kaldi::RealFft) and unpacks the
// packed output into separate real/imaginary vectors, mimicking numpy.fft's
// layout for a real input: bins 0..N/2 inclusive.
// NOTE(review): this relies on kaldi RealFft's packed format — v[0] is the
// DC component, v[1] the Nyquist component, and interior bins are stored as
// interleaved (re, im) pairs; confirm against matrix-functions.h if RealFft
// changes.  Always returns true (success).
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
                                 vector<BaseFloat>* real,
                                 vector<BaseFloat>* img) const {
    Vector<BaseFloat> v_tmp;
    CopyStdVector2Vector_(*v, &v_tmp);
    RealFft(&v_tmp, true);
    CopyVector2StdVector_(v_tmp, v);
    // DC bin: purely real, packed at v[0].
    real->push_back(v->at(0));
    img->push_back(0);
    // Interior bins: interleaved (re, im) pairs.
    for (int i = 1; i < v->size() / 2; i++) {
        real->push_back(v->at(2 * i));
        img->push_back(v->at(2 * i + 1));
    }
    // Nyquist bin: purely real, packed at v[1].
    real->push_back(v->at(1));
    img->push_back(0);

    return true;
}
|
||||||
|
|
||||||
|
// todo remove later
|
||||||
|
// todo remove later
// Pulls the full normalized waveform from the base extractor, computes the
// spectrogram for all frames, and copies the result into *feats
// (frames x (fft_points/2 + 1)).  Clears the internal waveform buffer when
// done.
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
    Vector<BaseFloat> tmp;
    waveform_.Resize(base_extractor_->Dim());
    Compute(tmp, &waveform_);
    vector<vector<BaseFloat>> result;
    vector<BaseFloat> feats_vec;
    CopyVector2StdVector_(waveform_, &feats_vec);
    Compute(feats_vec, result);
    // Guard: Compute() produces no frames for short input; result[0] below
    // would be undefined behavior on an empty result.
    if (result.empty()) {
        waveform_.Resize(0);
        return;
    }
    feats->Resize(result.size(), result[0].size());
    for (size_t row_idx = 0; row_idx < result.size(); ++row_idx) {
        // Bug fix: the column loop was bounded by result.size() (number of
        // rows) instead of the row width, truncating or overrunning rows
        // whenever frames != feature dim.
        for (size_t col_idx = 0; col_idx < result[row_idx].size(); ++col_idx) {
            (*feats)(row_idx, col_idx) = result[row_idx][col_idx];
        }
    }
    waveform_.Resize(0);
}
|
||||||
|
|
||||||
|
// Streaming read of spectrogram frames.  Not implemented yet — callers
// currently use ReadFeats() instead; *feat is left unmodified.
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
    // todo
    return;
}
|
||||||
|
|
||||||
|
// only for test, remove later
|
||||||
|
// todo: compute the feature frame by frame.
|
||||||
|
// only for test, remove later
// todo: compute the feature frame by frame.
// Currently just drains the base extractor's output into *feature; the
// `input` argument is ignored.
void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
                                VectorBase<kaldi::BaseFloat>* feature) {
    base_extractor_->Read(feature);
}
|
||||||
|
|
||||||
|
// Compute spectrogram feat, only for test, remove later
|
||||||
|
// todo: refactor later (SmileGoat)
|
||||||
|
bool LinearSpectrogram::Compute(const vector<float>& wave,
|
||||||
|
vector<vector<float>>& feat) {
|
||||||
|
int num_samples = wave.size();
|
||||||
|
const int& frame_length = opts_.frame_opts.WindowSize();
|
||||||
|
const int& sample_rate = opts_.frame_opts.samp_freq;
|
||||||
|
const int& frame_shift = opts_.frame_opts.WindowShift();
|
||||||
|
const int& fft_points = fft_points_;
|
||||||
|
const float scale = hanning_window_energy_ * frame_shift;
|
||||||
|
|
||||||
|
if (num_samples < frame_length) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
|
||||||
|
feat.resize(num_frames);
|
||||||
|
vector<float> fft_real((fft_points_ / 2 + 1), 0);
|
||||||
|
vector<float> fft_img((fft_points_ / 2 + 1), 0);
|
||||||
|
vector<float> v(frame_length, 0);
|
||||||
|
vector<float> power((fft_points / 2 + 1));
|
||||||
|
|
||||||
|
for (int i = 0; i < num_frames; ++i) {
|
||||||
|
vector<float> data(wave.data() + i * frame_shift,
|
||||||
|
wave.data() + i * frame_shift + frame_length);
|
||||||
|
Hanning(&data);
|
||||||
|
fft_img.clear();
|
||||||
|
fft_real.clear();
|
||||||
|
v.assign(data.begin(), data.end());
|
||||||
|
if (NumpyFft(&v, &fft_real, &fft_img)) {
|
||||||
|
LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
|
||||||
|
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
|
||||||
|
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
|
||||||
|
feat[i][j] = power[j];
|
||||||
|
|
||||||
|
if (j == 0 || j == feat[0].size() - 1) {
|
||||||
|
feat[i][j] /= scale;
|
||||||
|
} else {
|
||||||
|
feat[i][j] *= (2.0 / scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// log added eps=1e-14
|
||||||
|
feat[i][j] = std::log(feat[i][j] + 1e-14);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,50 @@
|
|||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "frontend/feature_extractor_interface.h"
|
||||||
|
#include "kaldi/feat/feature-window.h"
|
||||||
|
#include "base/common.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
// Options for the linear spectrogram extractor.  Currently just wraps
// kaldi's frame extraction options (window size/shift, sample rate, ...).
struct LinearSpectrogramOptions {
    kaldi::FrameExtractionOptions frame_opts;
    LinearSpectrogramOptions():
      frame_opts() {}

    // Registers the frame options on a kaldi option parser so they can be
    // set from the command line.
    void Register(kaldi::OptionsItf* opts) {
        frame_opts.Register(opts);
    }
};
|
||||||
|
|
||||||
|
// Log power spectrogram feature extractor.  Wraps a base extractor (e.g. a
// waveform normalizer) and converts its output waveform into per-frame
// log power spectra using a Hanning window and a real FFT.
class LinearSpectrogram : public FeatureExtractorInterface {
  public:
    explicit LinearSpectrogram(const LinearSpectrogramOptions& opts,
                               std::unique_ptr<FeatureExtractorInterface> base_extractor);
    // Forwards samples to the base extractor.
    virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
    // Streaming read; not implemented yet (see .cc).
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
    // Feature dimension: fft_points/2 + 1 bins per frame.
    virtual size_t Dim() const { return dim_; }
    // Batch API: computes the full spectrogram matrix (frames x bins).
    void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);

  private:
    // Applies the precomputed Hanning window in place.
    void Hanning(std::vector<kaldi::BaseFloat>* data) const;
    // Computes the log power spectrogram of `wave` into `feat`.
    bool Compute(const std::vector<kaldi::BaseFloat>& wave,
                 std::vector<std::vector<kaldi::BaseFloat>>& feat);
    // Test-only: drains the base extractor's output into *feature.
    void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
                 kaldi::VectorBase<kaldi::BaseFloat>* feature);
    // Real FFT with numpy-style separate real/imag outputs (bins 0..N/2).
    bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
                  std::vector<kaldi::BaseFloat>* real,
                  std::vector<kaldi::BaseFloat>* img) const;

    kaldi::int32 fft_points_;        // FFT size == analysis window size
    size_t dim_;                     // fft_points_/2 + 1
    std::vector<kaldi::BaseFloat> hanning_window_;
    kaldi::BaseFloat hanning_window_energy_;  // sum of squared window coeffs
    LinearSpectrogramOptions opts_;
    kaldi::Vector<kaldi::BaseFloat> waveform_;  // remove later, todo(SmileGoat)
    std::unique_ptr<FeatureExtractorInterface> base_extractor_;
    DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,16 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// wrap the mfcc feat of kaldi, todo (SmileGoat)
|
||||||
|
#include "kaldi/feat/feature-mfcc.h"
|
@ -0,0 +1,130 @@
|
|||||||
|
|
||||||
|
#include "frontend/normalizer.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
using kaldi::Vector;
|
||||||
|
using kaldi::VectorBase;
|
||||||
|
using kaldi::BaseFloat;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
// Stores the normalization options; the waveform dimension starts at zero
// and is updated by AcceptWaveform().
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
    dim_ = 0;
    opts_ = opts;
}
|
||||||
|
|
||||||
|
// Buffers a copy of the incoming waveform; normalization happens lazily in
// Read()/Compute().  Each call replaces the previous buffer.
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) {
    dim_ = input.Dim();
    waveform_.Resize(input.Dim());
    waveform_.CopyFromVec(input);
}
|
||||||
|
|
||||||
|
// Writes the normalized waveform into *feat.  A no-op when no waveform has
// been buffered yet.
// NOTE(review): Compute()'s bool result (false when the gain limit is
// exceeded) is ignored here, so failures are silent — confirm intended.
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
    if (waveform_.Dim() == 0) return;
    Compute(waveform_, feat);
}
|
||||||
|
|
||||||
|
//todo remove later
|
||||||
|
// Copies a kaldi vector into a std::vector, resizing the destination.
// An empty input leaves *output untouched.
// todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
                          vector<BaseFloat>* output) {
    const size_t dim = input.Dim();
    if (dim == 0) return;
    output->resize(dim);
    for (size_t i = 0; i < dim; ++i) {
        (*output)[i] = input(i);
    }
}
|
||||||
|
|
||||||
|
void CopyStdVector2Vector(const vector<BaseFloat>& input,
|
||||||
|
VectorBase<BaseFloat>* output) {
|
||||||
|
if (input.empty()) return;
|
||||||
|
assert(input.size() == output->Dim());
|
||||||
|
for (size_t idx = 0; idx < input.size(); ++idx) {
|
||||||
|
(*output)(idx) = input[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
|
||||||
|
VectorBase<BaseFloat>* feat) const {
|
||||||
|
// calculate db rms
|
||||||
|
BaseFloat rms_db = 0.0;
|
||||||
|
BaseFloat mean_square = 0.0;
|
||||||
|
BaseFloat gain = 0.0;
|
||||||
|
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
|
||||||
|
|
||||||
|
vector<BaseFloat> samples;
|
||||||
|
samples.resize(input.Dim());
|
||||||
|
for (int32 i = 0; i < samples.size(); ++i) {
|
||||||
|
samples[i] = input(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// square
|
||||||
|
for (auto &d : samples) {
|
||||||
|
if (opts_.convert_int_float) {
|
||||||
|
d = d * wave_float_normlization;
|
||||||
|
}
|
||||||
|
mean_square += d * d;
|
||||||
|
}
|
||||||
|
|
||||||
|
// mean
|
||||||
|
mean_square /= samples.size();
|
||||||
|
rms_db = 10 * std::log10(mean_square);
|
||||||
|
gain = opts_.target_db - rms_db;
|
||||||
|
|
||||||
|
if (gain > opts_.max_gain_db) {
|
||||||
|
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB,"
|
||||||
|
<< "because the the probable gain have exceeds opts_.max_gain_db"
|
||||||
|
<< opts_.max_gain_db << "dB.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note that this is an in-place transformation.
|
||||||
|
for (auto &item : samples) {
|
||||||
|
// python item *= 10.0 ** (gain / 20.0)
|
||||||
|
item *= std::pow(10.0, gain / 20.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
CopyStdVector2Vector(samples, feat);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
PPNormalizer::PPNormalizer(
|
||||||
|
const PPNormalizerOptions& opts,
|
||||||
|
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void PPNormalizer::AcceptWavefrom(const Vector<BaseFloat>& input) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void PPNormalizer::Read(Vector<BaseFloat>* feat) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PPNormalizer::Compute(const Vector<BaseFloat>& input,
|
||||||
|
Vector<BaseFloat>>* feat) {
|
||||||
|
if ((input.Dim() % mean_.Dim()) == 0) {
|
||||||
|
LOG(ERROR) << "CMVN dimension is wrong!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
int32 size = mean_.Dim();
|
||||||
|
feat->Resize(input.Dim());
|
||||||
|
for (int32 row_idx = 0; row_idx < j; ++row_idx) {
|
||||||
|
int32 base_idx = row_idx * size;
|
||||||
|
for (int32 idx = 0; idx < mean_.Dim(); ++idx) {
|
||||||
|
(*feat)(base_idx + idx) = (input(base_dix + idx) - mean_(idx))* variance_(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch(const std::exception& e) {
|
||||||
|
std::cerr << e.what() << '\n';
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}*/
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,72 @@
|
|||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "frontend/feature_extractor_interface.h"
|
||||||
|
#include "kaldi/util/options-itf.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
|
||||||
|
// Options for decibel (RMS level) normalization of a waveform.
struct DecibelNormalizerOptions {
    float target_db;         // desired RMS level in dB after normalization
    float max_gain_db;       // refuse to apply more than this much gain
    bool convert_int_float;  // if true, scale int16-range samples to [-1, 1)
    DecibelNormalizerOptions() :
      target_db(-20),
      max_gain_db(300.0),
      convert_int_float(false) {}

    // Registers the options on a kaldi option parser.
    void Register(kaldi::OptionsItf* opts) {
      opts->Register("target-db", &target_db, "target db for db normalization");
      opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization");
      opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float");
    }
};
|
||||||
|
|
||||||
|
class DecibelNormalizer : public FeatureExtractorInterface {
|
||||||
|
public:
|
||||||
|
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
|
||||||
|
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
|
||||||
|
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
|
||||||
|
virtual size_t Dim() const { return 0; }
|
||||||
|
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
|
||||||
|
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
|
||||||
|
private:
|
||||||
|
DecibelNormalizerOptions opts_;
|
||||||
|
size_t dim_;
|
||||||
|
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
|
||||||
|
kaldi::Vector<kaldi::BaseFloat> waveform_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
struct NormalizerOptions {
|
||||||
|
std::string mean_std_path;
|
||||||
|
NormalizerOptions() :
|
||||||
|
mean_std_path("") {}
|
||||||
|
|
||||||
|
void Register(kaldi::OptionsItf* opts) {
|
||||||
|
opts->Register("mean-std", &mean_std_path, "mean std file");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// todo refactor later (SmileGoat)
|
||||||
|
class PPNormalizer : public FeatureExtractorInterface {
|
||||||
|
public:
|
||||||
|
explicit PPNormalizer(const NormalizerOptions& opts,
|
||||||
|
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor);
|
||||||
|
~PPNormalizer() {}
|
||||||
|
virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input);
|
||||||
|
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat);
|
||||||
|
virtual size_t Dim() const;
|
||||||
|
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& input,
|
||||||
|
kaldi::Vector<kaldi::BaseFloat>>& feat);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool _initialized;
|
||||||
|
kaldi::Vector<float> mean_;
|
||||||
|
kaldi::Vector<float> variance_;
|
||||||
|
NormalizerOptions _opts;
|
||||||
|
};
|
||||||
|
*/
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,16 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// extract the window of kaldi feat.
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,549 @@
|
|||||||
|
// decoder/lattice-faster-decoder.h
|
||||||
|
|
||||||
|
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
|
||||||
|
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
|
||||||
|
// 2014 Guoguo Chen
|
||||||
|
// 2018 Zhehuai Chen
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
|
||||||
|
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
|
||||||
|
|
||||||
|
#include "decoder/grammar-fst.h"
|
||||||
|
#include "fst/fstlib.h"
|
||||||
|
#include "fst/memory.h"
|
||||||
|
#include "fstext/fstext-lib.h"
|
||||||
|
#include "itf/decodable-itf.h"
|
||||||
|
#include "lat/determinize-lattice-pruned.h"
|
||||||
|
#include "lat/kaldi-lattice.h"
|
||||||
|
#include "util/hash-list.h"
|
||||||
|
#include "util/stl-utils.h"
|
||||||
|
|
||||||
|
namespace kaldi {
|
||||||
|
|
||||||
|
// Configuration for LatticeFasterDecoderTpl.  Field meanings mirror the
// help strings registered in Register() below.
struct LatticeFasterDecoderConfig {
  BaseFloat beam;        // decoding beam; larger -> slower, more accurate
  int32 max_active;      // max active states per frame
  int32 min_active;      // minimum number of active states
  BaseFloat lattice_beam;   // lattice generation beam
  int32 prune_interval;     // prune tokens every this many frames
  bool determinize_lattice; // not inspected by this class... used in
                            // command-line program.
  BaseFloat beam_delta;  // increment used when applying max-active constraint
  BaseFloat hash_ratio;  // controls hash-list resizing behavior
  // Note: we don't make prune_scale configurable on the command line, it's not
  // a very important parameter.  It affects the algorithm that prunes the
  // tokens as we go.
  BaseFloat prune_scale;

  // Number of elements in the block for Token and ForwardLink memory
  // pool allocation.
  int32 memory_pool_tokens_block_size;
  int32 memory_pool_links_block_size;

  // Most of the options inside det_opts are not actually queried by the
  // LatticeFasterDecoder class itself, but by the code that calls it, for
  // example in the function DecodeUtteranceLatticeFaster.
  fst::DeterminizeLatticePhonePrunedOptions det_opts;

  LatticeFasterDecoderConfig()
      : beam(16.0),
        max_active(std::numeric_limits<int32>::max()),
        min_active(200),
        lattice_beam(10.0),
        prune_interval(25),
        determinize_lattice(true),
        beam_delta(0.5),
        hash_ratio(2.0),
        prune_scale(0.1),
        memory_pool_tokens_block_size(1 << 8),
        memory_pool_links_block_size(1 << 8) {}
  // Registers all options (including det_opts) on a kaldi option parser.
  void Register(OptionsItf *opts) {
    det_opts.Register(opts);
    opts->Register("beam", &beam, "Decoding beam.  Larger->slower, more accurate.");
    opts->Register("max-active", &max_active, "Decoder max active states.  Larger->slower; "
                   "more accurate");
    opts->Register("min-active", &min_active, "Decoder minimum #active states.");
    opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam.  Larger->slower, "
                   "and deeper lattices");
    opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
                   "which to prune tokens");
    opts->Register("determinize-lattice", &determinize_lattice, "If true, "
                   "determinize the lattice (lattice-determinization, keeping only "
                   "best pdf-sequence for each word-sequence).");
    opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
                   "parameter is obscure and relates to a speedup in the way the "
                   "max-active constraint is applied.  Larger is more accurate.");
    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
                   "control hash behavior");
    opts->Register("memory-pool-tokens-block-size", &memory_pool_tokens_block_size,
                   "Memory pool block size suggestion for storing tokens (in elements). "
                   "Smaller uses less memory but increases cache misses.");
    opts->Register("memory-pool-links-block-size", &memory_pool_links_block_size,
                   "Memory pool block size suggestion for storing links (in elements). "
                   "Smaller uses less memory but increases cache misses.");
  }
  // Sanity-checks option values; aborts (KALDI_ASSERT) on invalid settings.
  void Check() const {
    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
                 && min_active <= max_active
                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
                 && prune_scale > 0.0 && prune_scale < 1.0);
  }
};
|
||||||
|
|
||||||
|
namespace decoder {
|
||||||
|
// We will template the decoder on the token type as well as the FST type; this
|
||||||
|
// is a mechanism so that we can use the same underlying decoder code for
|
||||||
|
// versions of the decoder that support quickly getting the best path
|
||||||
|
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
|
||||||
|
// those that do not (LatticeFasterDecoder).
|
||||||
|
|
||||||
|
|
||||||
|
// ForwardLinks are the links from a token to a token on the next frame.
|
||||||
|
// or sometimes on the current frame (for input-epsilon links).
|
||||||
|
// A forward arc in the state-level lattice, linking a token to a token on
// the next frame (or the current frame, for input-epsilon links).  Stored
// as a singly-linked list hanging off each token.
template <typename Token>
struct ForwardLink {
  using Label = fst::StdArc::Label;

  Token *next_tok;  // the next token [or NULL if represents final-state]
  Label ilabel;  // ilabel on arc
  Label olabel;  // olabel on arc
  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
                      // in the state-level lattice) from a token.
  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
                     BaseFloat graph_cost, BaseFloat acoustic_cost,
                     ForwardLink *next):
      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
      next(next) { }
};
|
||||||
|
|
||||||
|
|
||||||
|
struct StdToken {
  using ForwardLinkT = ForwardLink<StdToken>;
  using Token = StdToken;

  // Standard token type for LatticeFasterDecoder.  Each active HCLG
  // (decoding-graph) state on each frame has one token.

  // tot_cost is the total (LM + acoustic) cost from the beginning of the
  // utterance up to this point.  (but see cost_offset_, which is subtracted
  // to keep it in a good numerical range).
  BaseFloat tot_cost;

  // exta_cost is >= 0.  After calling PruneForwardLinks, this equals the
  // minimum difference between the cost of the best path that this link is a
  // part of, and the cost of the absolute best path, under the assumption that
  // any of the currently active states at the decoding front may eventually
  // succeed (e.g. if you were to take the currently active states one by one
  // and compute this difference, and then take the minimum).
  BaseFloat extra_cost;

  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
  // use for lattice generation.
  ForwardLinkT *links;

  //'next' is the next in the singly-linked list of tokens for this frame.
  Token *next;

  // This function does nothing and should be optimized out; it's needed
  // so we can share the regular LatticeFasterDecoderTpl code and the code
  // for LatticeFasterOnlineDecoder that supports fast traceback.
  inline void SetBackpointer (Token *backpointer) { }

  // This constructor just ignores the 'backpointer' argument.  That argument is
  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
  // fast way to obtain the best path).
  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
                  Token *next, Token *backpointer):
      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
};
|
||||||
|
|
||||||
|
struct BackpointerToken {
  using ForwardLinkT = ForwardLink<BackpointerToken>;
  using Token = BackpointerToken;

  // BackpointerToken is like Token but also
  // Standard token type for LatticeFasterDecoder.  Each active HCLG
  // (decoding-graph) state on each frame has one token.

  // tot_cost is the total (LM + acoustic) cost from the beginning of the
  // utterance up to this point.  (but see cost_offset_, which is subtracted
  // to keep it in a good numerical range).
  BaseFloat tot_cost;

  // exta_cost is >= 0.  After calling PruneForwardLinks, this equals
  // the minimum difference between the cost of the best path, and the cost of
  // this is on, and the cost of the absolute best path, under the assumption
  // that any of the currently active states at the decoding front may
  // eventually succeed (e.g. if you were to take the currently active states
  // one by one and compute this difference, and then take the minimum).
  BaseFloat extra_cost;

  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
  // use for lattice generation.
  ForwardLinkT *links;

  //'next' is the next in the singly-linked list of tokens for this frame.
  BackpointerToken *next;

  // Best preceding BackpointerToken (could be a on this frame, connected to
  // this via an epsilon transition, or on a previous frame).  This is only
  // required for an efficient GetBestPath function in
  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
  // (the "links" list is what stores the forward links, for that).
  Token *backpointer;

  // Records the best preceding token (used by fast traceback).
  inline void SetBackpointer (Token *backpointer) {
    this->backpointer = backpointer;
  }

  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
                          Token *next, Token *backpointer):
      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
      backpointer(backpointer) { }
};
|
||||||
|
|
||||||
|
} // namespace decoder
|
||||||
|
|
||||||
|
|
||||||
|
/** This is the "normal" lattice-generating decoder.
|
||||||
|
See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
|
||||||
|
for more information.
|
||||||
|
|
||||||
|
The decoder is templated on the FST type and the token type. The token type
|
||||||
|
will normally be StdToken, but also may be BackpointerToken which is to support
|
||||||
|
quick lookup of the current best path (see lattice-faster-online-decoder.h)
|
||||||
|
|
||||||
|
The FST you invoke this decoder which is expected to equal
|
||||||
|
Fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst. If you invoke it with
|
||||||
|
FST == StdFst and it notices that the actual FST type is
|
||||||
|
fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder object
|
||||||
|
will internally cast itself to one that is templated on those more specific
|
||||||
|
types; this is an optimization for speed.
|
||||||
|
*/
|
||||||
|
template <typename FST, typename Token = decoder::StdToken>
|
||||||
|
class LatticeFasterDecoderTpl {
|
||||||
|
public:
|
||||||
|
using Arc = typename FST::Arc;
|
||||||
|
using Label = typename Arc::Label;
|
||||||
|
using StateId = typename Arc::StateId;
|
||||||
|
using Weight = typename Arc::Weight;
|
||||||
|
using ForwardLinkT = decoder::ForwardLink<Token>;
|
||||||
|
|
||||||
|
// Instantiate this class once for each thing you have to decode.
|
||||||
|
// This version of the constructor does not take ownership of
|
||||||
|
// 'fst'.
|
||||||
|
LatticeFasterDecoderTpl(const FST &fst,
|
||||||
|
const LatticeFasterDecoderConfig &config);
|
||||||
|
|
||||||
|
// This version of the constructor takes ownership of the fst, and will delete
|
||||||
|
// it when this object is destroyed.
|
||||||
|
LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config,
|
||||||
|
FST *fst);
|
||||||
|
|
||||||
|
void SetOptions(const LatticeFasterDecoderConfig &config) {
|
||||||
|
config_ = config;
|
||||||
|
}
|
||||||
|
|
||||||
|
const LatticeFasterDecoderConfig &GetOptions() const {
|
||||||
|
return config_;
|
||||||
|
}
|
||||||
|
|
||||||
|
~LatticeFasterDecoderTpl();
|
||||||
|
|
||||||
|
/// Decodes until there are no more frames left in the "decodable" object..
|
||||||
|
/// note, this may block waiting for input if the "decodable" object blocks.
|
||||||
|
/// Returns true if any kind of traceback is available (not necessarily from a
|
||||||
|
/// final state).
|
||||||
|
bool Decode(DecodableInterface *decodable);
|
||||||
|
|
||||||
|
|
||||||
|
/// says whether a final-state was active on the last frame. If it was not, the
|
||||||
|
/// lattice (or traceback) will end with states that are not final-states.
|
||||||
|
bool ReachedFinal() const {
|
||||||
|
return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Outputs an FST corresponding to the single best path through the lattice.
|
||||||
|
/// Returns true if result is nonempty (using the return status is deprecated,
|
||||||
|
/// it will become void). If "use_final_probs" is true AND we reached the
|
||||||
|
/// final-state of the graph then it will include those as final-probs, else
|
||||||
|
/// it will treat all final-probs as one. Note: this just calls GetRawLattice()
|
||||||
|
/// and figures out the shortest path.
|
||||||
|
bool GetBestPath(Lattice *ofst,
|
||||||
|
bool use_final_probs = true) const;
|
||||||
|
|
||||||
|
/// Outputs an FST corresponding to the raw, state-level
|
||||||
|
/// tracebacks. Returns true if result is nonempty.
|
||||||
|
/// If "use_final_probs" is true AND we reached the final-state
|
||||||
|
/// of the graph then it will include those as final-probs, else
|
||||||
|
/// it will treat all final-probs as one.
|
||||||
|
/// The raw lattice will be topologically sorted.
|
||||||
|
///
|
||||||
|
/// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
|
||||||
|
/// which also supports a pruning beam, in case for some reason
|
||||||
|
/// you want it pruned tighter than the regular lattice beam.
|
||||||
|
/// We could put that here in future needed.
|
||||||
|
bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/// [Deprecated, users should now use GetRawLattice and determinize it
|
||||||
|
/// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
|
||||||
|
/// Outputs an FST corresponding to the lattice-determinized
|
||||||
|
/// lattice (one path per word sequence). Returns true if result is nonempty.
|
||||||
|
/// If "use_final_probs" is true AND we reached the final-state of the graph
|
||||||
|
/// then it will include those as final-probs, else it will treat all
|
||||||
|
/// final-probs as one.
|
||||||
|
bool GetLattice(CompactLattice *ofst,
|
||||||
|
bool use_final_probs = true) const;
|
||||||
|
|
||||||
|
/// InitDecoding initializes the decoding, and should only be used if you
|
||||||
|
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
|
||||||
|
/// call this. You can also call InitDecoding if you have already decoded an
|
||||||
|
/// utterance and want to start with a new utterance.
|
||||||
|
void InitDecoding();
|
||||||
|
|
||||||
|
/// This will decode until there are no more frames ready in the decodable
|
||||||
|
/// object. You can keep calling it each time more frames become available.
|
||||||
|
/// If max_num_frames is specified, it specifies the maximum number of frames
|
||||||
|
/// the function will decode before returning.
|
||||||
|
void AdvanceDecoding(DecodableInterface *decodable,
|
||||||
|
int32 max_num_frames = -1);
|
||||||
|
|
||||||
|
/// This function may be optionally called after AdvanceDecoding(), when you
|
||||||
|
/// do not plan to decode any further. It does an extra pruning step that
|
||||||
|
/// will help to prune the lattices output by GetLattice and (particularly)
|
||||||
|
/// GetRawLattice more completely, particularly toward the end of the
|
||||||
|
/// utterance. If you call this, you cannot call AdvanceDecoding again (it
|
||||||
|
/// will fail), and you cannot call GetLattice() and related functions with
|
||||||
|
/// use_final_probs = false. Used to be called PruneActiveTokensFinal().
|
||||||
|
void FinalizeDecoding();
|
||||||
|
|
||||||
|
/// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
|
||||||
|
/// more information. It returns the difference between the best (final-cost
|
||||||
|
/// plus cost) of any token on the final frame, and the best cost of any token
|
||||||
|
/// on the final frame. If it is infinity it means no final-states were
|
||||||
|
/// present on the final frame. It will usually be nonnegative. If it not
|
||||||
|
/// too positive (e.g. < 5 is my first guess, but this is not tested) you can
|
||||||
|
/// take it as a good indication that we reached the final-state with
|
||||||
|
/// reasonable likelihood.
|
||||||
|
BaseFloat FinalRelativeCost() const;
|
||||||
|
|
||||||
|
|
||||||
|
// Returns the number of frames decoded so far. The value returned changes
|
||||||
|
// whenever we call ProcessEmitting().
|
||||||
|
inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// we make things protected instead of private, as code in
|
||||||
|
// LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
|
||||||
|
// internals.
|
||||||
|
|
||||||
|
// Deletes the elements of the singly linked list tok->links.
|
||||||
|
void DeleteForwardLinks(Token *tok);
|
||||||
|
|
||||||
|
// head of per-frame list of Tokens (list is in topological order),
|
||||||
|
// and something saying whether we ever pruned it using PruneForwardLinks.
|
||||||
|
struct TokenList {
|
||||||
|
Token *toks;
|
||||||
|
bool must_prune_forward_links;
|
||||||
|
bool must_prune_tokens;
|
||||||
|
TokenList(): toks(NULL), must_prune_forward_links(true),
|
||||||
|
must_prune_tokens(true) { }
|
||||||
|
};
|
||||||
|
|
||||||
|
using Elem = typename HashList<StateId, Token*>::Elem;
|
||||||
|
// Equivalent to:
|
||||||
|
// struct Elem {
|
||||||
|
// StateId key;
|
||||||
|
// Token *val;
|
||||||
|
// Elem *tail;
|
||||||
|
// };
|
||||||
|
|
||||||
|
void PossiblyResizeHash(size_t num_toks);
|
||||||
|
|
||||||
|
// FindOrAddToken either locates a token in hash of toks_, or if necessary
|
||||||
|
// inserts a new, empty token (i.e. with no forward links) for the current
|
||||||
|
// frame. [note: it's inserted if necessary into hash toks_ and also into the
|
||||||
|
// singly linked list of tokens active on this frame (whose head is at
|
||||||
|
// active_toks_[frame]). The frame_plus_one argument is the acoustic frame
|
||||||
|
// index plus one, which is used to index into the active_toks_ array.
|
||||||
|
// Returns the Token pointer. Sets "changed" (if non-NULL) to true if the
|
||||||
|
// token was newly created or the cost changed.
|
||||||
|
// If Token == StdToken, the 'backpointer' argument has no purpose (and will
|
||||||
|
// hopefully be optimized out).
|
||||||
|
inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
|
||||||
|
BaseFloat tot_cost, Token *backpointer,
|
||||||
|
bool *changed);
|
||||||
|
|
||||||
|
// prunes outgoing links for all tokens in active_toks_[frame]
|
||||||
|
// it's called by PruneActiveTokens
|
||||||
|
// all links, that have link_extra_cost > lattice_beam are pruned
|
||||||
|
// delta is the amount by which the extra_costs must change
|
||||||
|
// before we set *extra_costs_changed = true.
|
||||||
|
// If delta is larger, we'll tend to go back less far
|
||||||
|
// toward the beginning of the file.
|
||||||
|
// extra_costs_changed is set to true if extra_cost was changed for any token
|
||||||
|
// links_pruned is set to true if any link in any token was pruned
|
||||||
|
void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
|
||||||
|
bool *links_pruned,
|
||||||
|
BaseFloat delta);
|
||||||
|
|
||||||
|
// This function computes the final-costs for tokens active on the final
|
||||||
|
// frame. It outputs to final-costs, if non-NULL, a map from the Token*
|
||||||
|
// pointer to the final-prob of the corresponding state, for all Tokens
|
||||||
|
// that correspond to states that have final-probs. This map will be
|
||||||
|
// empty if there were no final-probs. It outputs to
|
||||||
|
// final_relative_cost, if non-NULL, the difference between the best
|
||||||
|
// forward-cost including the final-prob cost, and the best forward-cost
|
||||||
|
// without including the final-prob cost (this will usually be positive), or
|
||||||
|
// infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
|
||||||
|
// outputs this quanitity]. It outputs to final_best_cost, if
|
||||||
|
// non-NULL, the lowest for any token t active on the final frame, of
|
||||||
|
// forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
|
||||||
|
// the graph of the state corresponding to token t, or the best of
|
||||||
|
// forward-cost[t] if there were no final-probs active on the final frame.
|
||||||
|
// You cannot call this after FinalizeDecoding() has been called; in that
|
||||||
|
// case you should get the answer from class-member variables.
|
||||||
|
void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
|
||||||
|
BaseFloat *final_relative_cost,
|
||||||
|
BaseFloat *final_best_cost) const;
|
||||||
|
|
||||||
|
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
|
||||||
|
// on the final frame. If there are final tokens active, it uses
|
||||||
|
// the final-probs for pruning, otherwise it treats all tokens as final.
|
||||||
|
void PruneForwardLinksFinal();
|
||||||
|
|
||||||
|
// Prune away any tokens on this frame that have no forward links.
|
||||||
|
// [we don't do this in PruneForwardLinks because it would give us
|
||||||
|
// a problem with dangling pointers].
|
||||||
|
// It's called by PruneActiveTokens if any forward links have been pruned
|
||||||
|
void PruneTokensForFrame(int32 frame_plus_one);
|
||||||
|
|
||||||
|
|
||||||
|
// Go backwards through still-alive tokens, pruning them if the
|
||||||
|
// forward+backward cost is more than lat_beam away from the best path. It's
|
||||||
|
// possible to prove that this is "correct" in the sense that we won't lose
|
||||||
|
// anything outside of lat_beam, regardless of what happens in the future.
|
||||||
|
// delta controls when it considers a cost to have changed enough to continue
|
||||||
|
// going backward and propagating the change. larger delta -> will recurse
|
||||||
|
// less far.
|
||||||
|
void PruneActiveTokens(BaseFloat delta);
|
||||||
|
|
||||||
|
/// Gets the weight cutoff. Also counts the active tokens.
|
||||||
|
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
|
||||||
|
BaseFloat *adaptive_beam, Elem **best_elem);
|
||||||
|
|
||||||
|
/// Processes emitting arcs for one frame. Propagates from prev_toks_ to
|
||||||
|
/// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to
|
||||||
|
/// use.
|
||||||
|
BaseFloat ProcessEmitting(DecodableInterface *decodable);
|
||||||
|
|
||||||
|
/// Processes nonemitting (epsilon) arcs for one frame. Called after
|
||||||
|
/// ProcessEmitting() on each frame. The cost cutoff is computed by the
|
||||||
|
/// preceding ProcessEmitting().
|
||||||
|
void ProcessNonemitting(BaseFloat cost_cutoff);
|
||||||
|
|
||||||
|
// HashList defined in ../util/hash-list.h. It actually allows us to maintain
|
||||||
|
// more than one list (e.g. for current and previous frames), but only one of
|
||||||
|
// them at a time can be indexed by StateId. It is indexed by frame-index
|
||||||
|
// plus one, where the frame-index is zero-based, as used in decodable object.
|
||||||
|
// That is, the emitting probs of frame t are accounted for in tokens at
|
||||||
|
// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
|
||||||
|
// the graph.
|
||||||
|
HashList<StateId, Token*> toks_;
|
||||||
|
|
||||||
|
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
|
||||||
|
// frame (members of TokenList are toks, must_prune_forward_links,
|
||||||
|
// must_prune_tokens).
|
||||||
|
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
|
||||||
|
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
|
||||||
|
|
||||||
|
// fst_ is a pointer to the FST we are decoding from.
|
||||||
|
const FST *fst_;
|
||||||
|
// delete_fst_ is true if the pointer fst_ needs to be deleted when this
|
||||||
|
// object is destroyed.
|
||||||
|
bool delete_fst_;
|
||||||
|
|
||||||
|
std::vector<BaseFloat> cost_offsets_; // This contains, for each
|
||||||
|
// frame, an offset that was added to the acoustic log-likelihoods on that
|
||||||
|
// frame in order to keep everything in a nice dynamic range i.e. close to
|
||||||
|
// zero, to reduce roundoff errors.
|
||||||
|
LatticeFasterDecoderConfig config_;
|
||||||
|
int32 num_toks_; // current total #toks allocated...
|
||||||
|
bool warned_;
|
||||||
|
|
||||||
|
/// decoding_finalized_ is true if someone called FinalizeDecoding(). [note,
|
||||||
|
/// calling this is optional]. If true, it's forbidden to decode more. Also,
|
||||||
|
/// if this is set, then the output of ComputeFinalCosts() is in the next
|
||||||
|
/// three variables. The reason we need to do this is that after
|
||||||
|
/// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
|
||||||
|
/// of the tokens on the last frame are freed, so we free the list from toks_
|
||||||
|
/// to avoid having dangling pointers hanging around.
|
||||||
|
bool decoding_finalized_;
|
||||||
|
/// For the meaning of the next 3 variables, see the comment for
|
||||||
|
/// decoding_finalized_ above., and ComputeFinalCosts().
|
||||||
|
unordered_map<Token*, BaseFloat> final_costs_;
|
||||||
|
BaseFloat final_relative_cost_;
|
||||||
|
BaseFloat final_best_cost_;
|
||||||
|
|
||||||
|
// Memory pools for storing tokens and forward links.
|
||||||
|
// We use it to decrease the work put on allocator and to move some of data
|
||||||
|
// together. Too small block sizes will result in more work to allocator but
|
||||||
|
// bigger ones increase the memory usage.
|
||||||
|
fst::MemoryPool<Token> token_pool_;
|
||||||
|
fst::MemoryPool<ForwardLinkT> forward_link_pool_;
|
||||||
|
|
||||||
|
// There are various cleanup tasks... the toks_ structure contains
|
||||||
|
// singly linked lists of Token pointers, where Elem is the list type.
|
||||||
|
// It also indexes them in a hash, indexed by state (this hash is only
|
||||||
|
// maintained for the most recent frame). toks_.Clear()
|
||||||
|
// deletes them from the hash and returns the list of Elems. The
|
||||||
|
// function DeleteElems calls toks_.Delete(elem) for each elem in
|
||||||
|
// the list, which returns ownership of the Elem to the toks_ structure
|
||||||
|
// for reuse, but does not delete the Token pointer. The Token pointers
|
||||||
|
// are reference-counted and are ultimately deleted in PruneTokensForFrame,
|
||||||
|
// but are also linked together on each frame by their own linked-list,
|
||||||
|
// using the "next" pointer. We delete them manually.
|
||||||
|
void DeleteElems(Elem *list);
|
||||||
|
|
||||||
|
// This function takes a singly linked list of tokens for a single frame, and
|
||||||
|
// outputs a list of them in topological order (it will crash if no such order
|
||||||
|
// can be found, which will typically be due to decoding graphs with epsilon
|
||||||
|
// cycles, which are not allowed). Note: the output list may contain NULLs,
|
||||||
|
// which the caller should pass over; it just happens to be more efficient for
|
||||||
|
// the algorithm to output a list that contains NULLs.
|
||||||
|
static void TopSortTokens(Token *tok_list,
|
||||||
|
std::vector<Token*> *topsorted_list);
|
||||||
|
|
||||||
|
void ClearActiveTokens();
|
||||||
|
|
||||||
|
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken> LatticeFasterDecoder;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace kaldi.
|
||||||
|
|
||||||
|
#endif
|
@ -0,0 +1,285 @@
|
|||||||
|
// decoder/lattice-faster-online-decoder.cc
|
||||||
|
|
||||||
|
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
|
||||||
|
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
|
||||||
|
// 2014 Guoguo Chen
|
||||||
|
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
|
||||||
|
// 2018 Zhehuai Chen
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
|
||||||
|
// file in sync with lattice-faster-decoder.cc
|
||||||
|
|
||||||
|
#include "decoder/lattice-faster-online-decoder.h"
|
||||||
|
#include "lat/lattice-functions.h"
|
||||||
|
|
||||||
|
namespace kaldi {
|
||||||
|
|
||||||
|
template <typename FST>
|
||||||
|
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
|
||||||
|
bool use_final_probs) const {
|
||||||
|
Lattice lat1;
|
||||||
|
{
|
||||||
|
Lattice raw_lat;
|
||||||
|
this->GetRawLattice(&raw_lat, use_final_probs);
|
||||||
|
ShortestPath(raw_lat, &lat1);
|
||||||
|
}
|
||||||
|
Lattice lat2;
|
||||||
|
GetBestPath(&lat2, use_final_probs);
|
||||||
|
BaseFloat delta = 0.1;
|
||||||
|
int32 num_paths = 1;
|
||||||
|
if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
|
||||||
|
KALDI_WARN << "Best-path test failed";
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Outputs an FST corresponding to the single best path through the lattice.
|
||||||
|
template <typename FST>
|
||||||
|
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(Lattice *olat,
|
||||||
|
bool use_final_probs) const {
|
||||||
|
olat->DeleteStates();
|
||||||
|
BaseFloat final_graph_cost;
|
||||||
|
BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
|
||||||
|
if (iter.Done())
|
||||||
|
return false; // would have printed warning.
|
||||||
|
StateId state = olat->AddState();
|
||||||
|
olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
|
||||||
|
while (!iter.Done()) {
|
||||||
|
LatticeArc arc;
|
||||||
|
iter = TraceBackBestPath(iter, &arc);
|
||||||
|
arc.nextstate = state;
|
||||||
|
StateId new_state = olat->AddState();
|
||||||
|
olat->AddArc(new_state, arc);
|
||||||
|
state = new_state;
|
||||||
|
}
|
||||||
|
olat->SetStart(state);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename FST>
|
||||||
|
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
|
||||||
|
bool use_final_probs,
|
||||||
|
BaseFloat *final_cost_out) const {
|
||||||
|
if (this->decoding_finalized_ && !use_final_probs)
|
||||||
|
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
|
||||||
|
<< "BestPathEnd() with use_final_probs == false";
|
||||||
|
KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
|
||||||
|
"You cannot call BestPathEnd if no frames were decoded.");
|
||||||
|
|
||||||
|
unordered_map<Token*, BaseFloat> final_costs_local;
|
||||||
|
|
||||||
|
const unordered_map<Token*, BaseFloat> &final_costs =
|
||||||
|
(this->decoding_finalized_ ? this->final_costs_ :final_costs_local);
|
||||||
|
if (!this->decoding_finalized_ && use_final_probs)
|
||||||
|
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
|
||||||
|
|
||||||
|
// Singly linked list of tokens on last frame (access list through "next"
|
||||||
|
// pointer).
|
||||||
|
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
|
||||||
|
BaseFloat best_final_cost = 0;
|
||||||
|
Token *best_tok = NULL;
|
||||||
|
for (Token *tok = this->active_toks_.back().toks;
|
||||||
|
tok != NULL; tok = tok->next) {
|
||||||
|
BaseFloat cost = tok->tot_cost, final_cost = 0.0;
|
||||||
|
if (use_final_probs && !final_costs.empty()) {
|
||||||
|
// if we are instructed to use final-probs, and any final tokens were
|
||||||
|
// active on final frame, include the final-prob in the cost of the token.
|
||||||
|
typename unordered_map<Token*, BaseFloat>::const_iterator
|
||||||
|
iter = final_costs.find(tok);
|
||||||
|
if (iter != final_costs.end()) {
|
||||||
|
final_cost = iter->second;
|
||||||
|
cost += final_cost;
|
||||||
|
} else {
|
||||||
|
cost = std::numeric_limits<BaseFloat>::infinity();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cost < best_cost) {
|
||||||
|
best_cost = cost;
|
||||||
|
best_tok = tok;
|
||||||
|
best_final_cost = final_cost;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (best_tok == NULL) { // this should not happen, and is likely a code error or
|
||||||
|
// caused by infinities in likelihoods, but I'm not making
|
||||||
|
// it a fatal error for now.
|
||||||
|
KALDI_WARN << "No final token found.";
|
||||||
|
}
|
||||||
|
if (final_cost_out)
|
||||||
|
*final_cost_out = best_final_cost;
|
||||||
|
return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename FST>
|
||||||
|
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(
|
||||||
|
BestPathIterator iter, LatticeArc *oarc) const {
|
||||||
|
KALDI_ASSERT(!iter.Done() && oarc != NULL);
|
||||||
|
Token *tok = static_cast<Token*>(iter.tok);
|
||||||
|
int32 cur_t = iter.frame, step_t = 0;
|
||||||
|
if (tok->backpointer != NULL) {
|
||||||
|
// retrieve the correct forward link(with the best link cost)
|
||||||
|
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
|
||||||
|
ForwardLinkT *link;
|
||||||
|
for (link = tok->backpointer->links;
|
||||||
|
link != NULL; link = link->next) {
|
||||||
|
if (link->next_tok == tok) { // this is a link to "tok"
|
||||||
|
BaseFloat graph_cost = link->graph_cost,
|
||||||
|
acoustic_cost = link->acoustic_cost;
|
||||||
|
BaseFloat cost = graph_cost + acoustic_cost;
|
||||||
|
if (cost < best_cost) {
|
||||||
|
oarc->ilabel = link->ilabel;
|
||||||
|
oarc->olabel = link->olabel;
|
||||||
|
if (link->ilabel != 0) {
|
||||||
|
KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
|
||||||
|
acoustic_cost -= this->cost_offsets_[cur_t];
|
||||||
|
step_t = -1;
|
||||||
|
} else {
|
||||||
|
step_t = 0;
|
||||||
|
}
|
||||||
|
oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
|
||||||
|
best_cost = cost;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (link == NULL &&
|
||||||
|
best_cost == std::numeric_limits<BaseFloat>::infinity()) { // Did not find correct link.
|
||||||
|
KALDI_ERR << "Error tracing best-path back (likely "
|
||||||
|
<< "bug in token-pruning algorithm)";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
oarc->ilabel = 0;
|
||||||
|
oarc->olabel = 0;
|
||||||
|
oarc->weight = LatticeWeight::One(); // zero costs.
|
||||||
|
}
|
||||||
|
return BestPathIterator(tok->backpointer, cur_t + step_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename FST>
|
||||||
|
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
|
||||||
|
Lattice *ofst,
|
||||||
|
bool use_final_probs,
|
||||||
|
BaseFloat beam) const {
|
||||||
|
typedef LatticeArc Arc;
|
||||||
|
typedef Arc::StateId StateId;
|
||||||
|
typedef Arc::Weight Weight;
|
||||||
|
typedef Arc::Label Label;
|
||||||
|
|
||||||
|
// Note: you can't use the old interface (Decode()) if you want to
|
||||||
|
// get the lattice with use_final_probs = false. You'd have to do
|
||||||
|
// InitDecoding() and then AdvanceDecoding().
|
||||||
|
if (this->decoding_finalized_ && !use_final_probs)
|
||||||
|
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
|
||||||
|
<< "GetRawLattice() with use_final_probs == false";
|
||||||
|
|
||||||
|
unordered_map<Token*, BaseFloat> final_costs_local;
|
||||||
|
|
||||||
|
const unordered_map<Token*, BaseFloat> &final_costs =
|
||||||
|
(this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
|
||||||
|
if (!this->decoding_finalized_ && use_final_probs)
|
||||||
|
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
|
||||||
|
|
||||||
|
ofst->DeleteStates();
|
||||||
|
// num-frames plus one (since frames are one-based, and we have
|
||||||
|
// an extra frame for the start-state).
|
||||||
|
int32 num_frames = this->active_toks_.size() - 1;
|
||||||
|
KALDI_ASSERT(num_frames > 0);
|
||||||
|
for (int32 f = 0; f <= num_frames; f++) {
|
||||||
|
if (this->active_toks_[f].toks == NULL) {
|
||||||
|
KALDI_WARN << "No tokens active on frame " << f
|
||||||
|
<< ": not producing lattice.\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unordered_map<Token*, StateId> tok_map;
|
||||||
|
std::queue<std::pair<Token*, int32> > tok_queue;
|
||||||
|
// First initialize the queue and states. Put the initial state on the queue;
|
||||||
|
// this is the last token in the list active_toks_[0].toks.
|
||||||
|
for (Token *tok = this->active_toks_[0].toks;
|
||||||
|
tok != NULL; tok = tok->next) {
|
||||||
|
if (tok->next == NULL) {
|
||||||
|
tok_map[tok] = ofst->AddState();
|
||||||
|
ofst->SetStart(tok_map[tok]);
|
||||||
|
std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
|
||||||
|
tok_queue.push(tok_pair);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next create states for "good" tokens
|
||||||
|
while (!tok_queue.empty()) {
|
||||||
|
std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
|
||||||
|
tok_queue.pop();
|
||||||
|
Token *cur_tok = cur_tok_pair.first;
|
||||||
|
int32 cur_frame = cur_tok_pair.second;
|
||||||
|
KALDI_ASSERT(cur_frame >= 0 &&
|
||||||
|
cur_frame <= this->cost_offsets_.size());
|
||||||
|
|
||||||
|
typename unordered_map<Token*, StateId>::const_iterator iter =
|
||||||
|
tok_map.find(cur_tok);
|
||||||
|
KALDI_ASSERT(iter != tok_map.end());
|
||||||
|
StateId cur_state = iter->second;
|
||||||
|
|
||||||
|
for (ForwardLinkT *l = cur_tok->links;
|
||||||
|
l != NULL;
|
||||||
|
l = l->next) {
|
||||||
|
Token *next_tok = l->next_tok;
|
||||||
|
if (next_tok->extra_cost < beam) {
|
||||||
|
// so both the current and the next token are good; create the arc
|
||||||
|
int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
|
||||||
|
StateId nextstate;
|
||||||
|
if (tok_map.find(next_tok) == tok_map.end()) {
|
||||||
|
nextstate = tok_map[next_tok] = ofst->AddState();
|
||||||
|
tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
|
||||||
|
} else {
|
||||||
|
nextstate = tok_map[next_tok];
|
||||||
|
}
|
||||||
|
BaseFloat cost_offset = (l->ilabel != 0 ?
|
||||||
|
this->cost_offsets_[cur_frame] : 0);
|
||||||
|
Arc arc(l->ilabel, l->olabel,
|
||||||
|
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
|
||||||
|
nextstate);
|
||||||
|
ofst->AddArc(cur_state, arc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cur_frame == num_frames) {
|
||||||
|
if (use_final_probs && !final_costs.empty()) {
|
||||||
|
typename unordered_map<Token*, BaseFloat>::const_iterator iter =
|
||||||
|
final_costs.find(cur_tok);
|
||||||
|
if (iter != final_costs.end())
|
||||||
|
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
|
||||||
|
} else {
|
||||||
|
ofst->SetFinal(cur_state, LatticeWeight::One());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return (ofst->NumStates() != 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Instantiate the template for the FST types that we'll need.
|
||||||
|
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
|
||||||
|
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
|
||||||
|
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
|
||||||
|
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
|
||||||
|
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace kaldi.
|
@ -0,0 +1,147 @@
|
|||||||
|
// decoder/lattice-faster-online-decoder.h
|
||||||
|
|
||||||
|
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
|
||||||
|
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
|
||||||
|
// 2014 Guoguo Chen
|
||||||
|
// 2018 Zhehuai Chen
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// see note at the top of lattice-faster-decoder.h, about how to maintain this
|
||||||
|
// file in sync with lattice-faster-decoder.h
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
|
||||||
|
#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
|
||||||
|
|
||||||
|
#include "util/stl-utils.h"
|
||||||
|
#include "util/hash-list.h"
|
||||||
|
#include "fst/fstlib.h"
|
||||||
|
#include "itf/decodable-itf.h"
|
||||||
|
#include "fstext/fstext-lib.h"
|
||||||
|
#include "lat/determinize-lattice-pruned.h"
|
||||||
|
#include "lat/kaldi-lattice.h"
|
||||||
|
#include "decoder/lattice-faster-decoder.h"
|
||||||
|
|
||||||
|
namespace kaldi {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also
|
||||||
|
supports an efficient way to get the best path (see the function
|
||||||
|
BestPathEnd()), which is useful in endpointing and in situations where you
|
||||||
|
might want to frequently access the best path.
|
||||||
|
|
||||||
|
This is only templated on the FST type, since the Token type is required to
|
||||||
|
be BackpointerToken. Actually it only makes sense to instantiate
|
||||||
|
LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via
|
||||||
|
this child class.
|
||||||
|
*/
|
||||||
|
template <typename FST>
|
||||||
|
class LatticeFasterOnlineDecoderTpl:
|
||||||
|
public LatticeFasterDecoderTpl<FST, decoder::BackpointerToken> {
|
||||||
|
public:
|
||||||
|
using Arc = typename FST::Arc;
|
||||||
|
using Label = typename Arc::Label;
|
||||||
|
using StateId = typename Arc::StateId;
|
||||||
|
using Weight = typename Arc::Weight;
|
||||||
|
using Token = decoder::BackpointerToken;
|
||||||
|
using ForwardLinkT = decoder::ForwardLink<Token>;
|
||||||
|
|
||||||
|
// Instantiate this class once for each thing you have to decode.
|
||||||
|
// This version of the constructor does not take ownership of
|
||||||
|
// 'fst'.
|
||||||
|
LatticeFasterOnlineDecoderTpl(const FST &fst,
|
||||||
|
const LatticeFasterDecoderConfig &config):
|
||||||
|
LatticeFasterDecoderTpl<FST, Token>(fst, config) { }
|
||||||
|
|
||||||
|
// This version of the initializer takes ownership of 'fst', and will delete
|
||||||
|
// it when this object is destroyed.
|
||||||
|
LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config,
|
||||||
|
FST *fst):
|
||||||
|
LatticeFasterDecoderTpl<FST, Token>(config, fst) { }
|
||||||
|
|
||||||
|
|
||||||
|
struct BestPathIterator {
|
||||||
|
void *tok;
|
||||||
|
int32 frame;
|
||||||
|
// note, "frame" is the frame-index of the frame you'll get the
|
||||||
|
// transition-id for next time, if you call TraceBackBestPath on this
|
||||||
|
// iterator (assuming it's not an epsilon transition). Note that this
|
||||||
|
// is one less than you might reasonably expect, e.g. it's -1 for
|
||||||
|
// the nonemitting transitions before the first frame.
|
||||||
|
BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
|
||||||
|
bool Done() const { return tok == NULL; }
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/// Outputs an FST corresponding to the single best path through the lattice.
|
||||||
|
/// This is quite efficient because it doesn't get the entire raw lattice and find
|
||||||
|
/// the best path through it; instead, it uses the BestPathEnd and BestPathIterator
|
||||||
|
/// so it basically traces it back through the lattice.
|
||||||
|
/// Returns true if result is nonempty (using the return status is deprecated,
|
||||||
|
/// it will become void). If "use_final_probs" is true AND we reached the
|
||||||
|
/// final-state of the graph then it will include those as final-probs, else
|
||||||
|
/// it will treat all final-probs as one.
|
||||||
|
bool GetBestPath(Lattice *ofst,
|
||||||
|
bool use_final_probs = true) const;
|
||||||
|
|
||||||
|
|
||||||
|
/// This function does a self-test of GetBestPath(). Returns true on
|
||||||
|
/// success; returns false and prints a warning on failure.
|
||||||
|
bool TestGetBestPath(bool use_final_probs = true) const;
|
||||||
|
|
||||||
|
|
||||||
|
/// This function returns an iterator that can be used to trace back
|
||||||
|
/// the best path. If use_final_probs == true and at least one final state
|
||||||
|
/// survived till the end, it will use the final-probs in working out the best
|
||||||
|
/// final Token, and will output the final cost to *final_cost (if non-NULL),
|
||||||
|
/// else it will use only the forward likelihood, and will put zero in
|
||||||
|
/// *final_cost (if non-NULL).
|
||||||
|
/// Requires that NumFramesDecoded() > 0.
|
||||||
|
BestPathIterator BestPathEnd(bool use_final_probs,
|
||||||
|
BaseFloat *final_cost = NULL) const;
|
||||||
|
|
||||||
|
|
||||||
|
/// This function can be used in conjunction with BestPathEnd() to trace back
|
||||||
|
/// the best path one link at a time (e.g. this can be useful in endpoint
|
||||||
|
/// detection). By "link" we mean a link in the graph; not all links cross
|
||||||
|
/// frame boundaries, but each time you see a nonzero ilabel you can interpret
|
||||||
|
/// that as a frame. The return value is the updated iterator. It outputs
|
||||||
|
/// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer,
|
||||||
|
/// while leaving its "nextstate" variable unchanged.
|
||||||
|
BestPathIterator TraceBackBestPath(
|
||||||
|
BestPathIterator iter, LatticeArc *arc) const;
|
||||||
|
|
||||||
|
|
||||||
|
/// Behaves the same as GetRawLattice but only processes tokens whose
|
||||||
|
/// extra_cost is smaller than the best-cost plus the specified beam.
|
||||||
|
/// It is only worthwhile to call this function if beam is less than
|
||||||
|
/// the lattice_beam specified in the config; otherwise, it would
|
||||||
|
/// return essentially the same thing as GetRawLattice, but more slowly.
|
||||||
|
bool GetRawLatticePruned(Lattice *ofst,
|
||||||
|
bool use_final_probs,
|
||||||
|
BaseFloat beam) const;
|
||||||
|
|
||||||
|
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl);
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef LatticeFasterOnlineDecoderTpl<fst::StdFst> LatticeFasterOnlineDecoder;
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace kaldi.
|
||||||
|
|
||||||
|
#endif
|
@ -0,0 +1,147 @@
|
|||||||
|
// lat/determinize-lattice-pruned-test.cc
|
||||||
|
|
||||||
|
// Copyright 2009-2012 Microsoft Corporation
|
||||||
|
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "lat/determinize-lattice-pruned.h"
|
||||||
|
#include "fstext/lattice-utils.h"
|
||||||
|
#include "fstext/fst-test-utils.h"
|
||||||
|
#include "lat/kaldi-lattice.h"
|
||||||
|
#include "lat/lattice-functions.h"
|
||||||
|
|
||||||
|
namespace fst {
|
||||||
|
// Caution: these tests are not as generic as you might think from all the
|
||||||
|
// templates in the code. They are basically only valid for LatticeArc.
|
||||||
|
// This is partly due to the fact that certain templates need to be instantiated
|
||||||
|
// in other .cc files in this directory.
|
||||||
|
|
||||||
|
// test that determinization proceeds correctly on general
|
||||||
|
// FSTs (not guaranteed determinzable, but we use the
|
||||||
|
// max-states option to stop it getting out of control).
|
||||||
|
template<class Arc> void TestDeterminizeLatticePruned() {
|
||||||
|
typedef kaldi::int32 Int;
|
||||||
|
typedef typename Arc::Weight Weight;
|
||||||
|
typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
|
||||||
|
|
||||||
|
for(int i = 0; i < 100; i++) {
|
||||||
|
RandFstOptions opts;
|
||||||
|
opts.n_states = 4;
|
||||||
|
opts.n_arcs = 10;
|
||||||
|
opts.n_final = 2;
|
||||||
|
opts.allow_empty = false;
|
||||||
|
opts.weight_multiplier = 0.5; // impt for the randomly generated weights
|
||||||
|
opts.acyclic = true;
|
||||||
|
// to be exactly representable in float,
|
||||||
|
// or this test fails because numerical differences can cause symmetry in
|
||||||
|
// weights to be broken, which causes the wrong path to be chosen as far
|
||||||
|
// as the string part is concerned.
|
||||||
|
|
||||||
|
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
|
||||||
|
|
||||||
|
bool sorted = TopSort(fst);
|
||||||
|
KALDI_ASSERT(sorted);
|
||||||
|
|
||||||
|
ILabelCompare<Arc> ilabel_comp;
|
||||||
|
if (kaldi::Rand() % 2 == 0)
|
||||||
|
ArcSort(fst, ilabel_comp);
|
||||||
|
|
||||||
|
std::cout << "FST before lattice-determinizing is:\n";
|
||||||
|
{
|
||||||
|
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
|
||||||
|
fstprinter.Print(&std::cout, "standard output");
|
||||||
|
}
|
||||||
|
VectorFst<Arc> det_fst;
|
||||||
|
try {
|
||||||
|
DeterminizeLatticePrunedOptions lat_opts;
|
||||||
|
lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
|
||||||
|
lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
|
||||||
|
lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
|
||||||
|
bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
|
||||||
|
|
||||||
|
std::cout << "FST after lattice-determinizing is:\n";
|
||||||
|
{
|
||||||
|
FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
|
||||||
|
fstprinter.Print(&std::cout, "standard output");
|
||||||
|
}
|
||||||
|
KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
|
||||||
|
// OK, now determinize it a different way and check equivalence.
|
||||||
|
// [note: it's not normal determinization, it's taking the best path
|
||||||
|
// for any input-symbol sequence....
|
||||||
|
|
||||||
|
|
||||||
|
VectorFst<Arc> pruned_fst(*fst);
|
||||||
|
if (pruned_fst.NumStates() != 0)
|
||||||
|
kaldi::PruneLattice(10.0, &pruned_fst);
|
||||||
|
|
||||||
|
VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
|
||||||
|
ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
|
||||||
|
std::cout << "Compact pruned FST is:\n";
|
||||||
|
{
|
||||||
|
FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
|
||||||
|
fstprinter.Print(&std::cout, "standard output");
|
||||||
|
}
|
||||||
|
ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
|
||||||
|
|
||||||
|
std::cout << "Compact version of determinized FST is:\n";
|
||||||
|
{
|
||||||
|
FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
|
||||||
|
fstprinter.Print(&std::cout, "standard output");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ans)
|
||||||
|
KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
|
||||||
|
} catch (...) {
|
||||||
|
std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
|
||||||
|
}
|
||||||
|
delete fst;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// test that determinization proceeds without crash on acyclic FSTs
|
||||||
|
// (guaranteed determinizable in this sense).
|
||||||
|
template<class Arc> void TestDeterminizeLatticePruned2() {
|
||||||
|
typedef typename Arc::Weight Weight;
|
||||||
|
RandFstOptions opts;
|
||||||
|
opts.acyclic = true;
|
||||||
|
for(int i = 0; i < 100; i++) {
|
||||||
|
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
|
||||||
|
std::cout << "FST before lattice-determinizing is:\n";
|
||||||
|
{
|
||||||
|
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
|
||||||
|
fstprinter.Print(&std::cout, "standard output");
|
||||||
|
}
|
||||||
|
VectorFst<Arc> ofst;
|
||||||
|
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
|
||||||
|
std::cout << "FST after lattice-determinizing is:\n";
|
||||||
|
{
|
||||||
|
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
|
||||||
|
fstprinter.Print(&std::cout, "standard output");
|
||||||
|
}
|
||||||
|
delete fst;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace fst
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
using namespace fst;
|
||||||
|
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
|
||||||
|
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
|
||||||
|
std::cout << "Tests succeeded\n";
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,296 @@
|
|||||||
|
// lat/determinize-lattice-pruned.h
|
||||||
|
|
||||||
|
// Copyright 2009-2012 Microsoft Corporation
|
||||||
|
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
||||||
|
// 2014 Guoguo Chen
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
|
||||||
|
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
|
||||||
|
#include <fst/fstlib.h>
|
||||||
|
#include <fst/fst-decl.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include "fstext/lattice-weight.h"
|
||||||
|
#include "itf/transition-information.h"
|
||||||
|
#include "itf/options-itf.h"
|
||||||
|
#include "lat/kaldi-lattice.h"
|
||||||
|
|
||||||
|
namespace fst {
|
||||||
|
|
||||||
|
/// \addtogroup fst_extensions
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
|
||||||
|
// For example of usage, see test-determinize-lattice-pruned.cc
|
||||||
|
|
||||||
|
/*
|
||||||
|
DeterminizeLatticePruned implements a special form of determinization with
|
||||||
|
epsilon removal, optimized for a phase of lattice generation. This algorithm
|
||||||
|
also does pruning at the same time-- the combination is more efficient as it
|
||||||
|
somtimes prevents us from creating a lot of states that would later be pruned
|
||||||
|
away. This allows us to increase the lattice-beam and not have the algorithm
|
||||||
|
blow up. Also, because our algorithm processes states in order from those
|
||||||
|
that appear on high-scoring paths down to those that appear on low-scoring
|
||||||
|
paths, we can easily terminate the algorithm after a certain specified number
|
||||||
|
of states or arcs.
|
||||||
|
|
||||||
|
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
|
||||||
|
with a lexicographical type of order, such as LatticeWeightTpl<float>).
|
||||||
|
Typically this would be a state-level lattice, with input symbols equal to
|
||||||
|
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
|
||||||
|
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
|
||||||
|
symbols are words, and the weights contain the original weights together with
|
||||||
|
strings (with zero or one symbol in them) containing the original output labels
|
||||||
|
(the p.d.f.'s). We determinize this using acceptor determinization with
|
||||||
|
epsilon removal. Remember (from lattice-weight.h) that
|
||||||
|
CompactLatticeWeightTpl has a special kind of semiring where we always take
|
||||||
|
the string corresponding to the best cost (of type BaseWeightType), and
|
||||||
|
discard the other. This corresponds to taking the best output-label sequence
|
||||||
|
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
|
||||||
|
Gallic weight for this, or it would die as soon as it detected that the input
|
||||||
|
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
|
||||||
|
can be determinized.
|
||||||
|
We assume that there is a function
|
||||||
|
Compare(const BaseWeightType &a, const BaseWeightType &b)
|
||||||
|
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
|
||||||
|
total order on the BaseWeightType... this information should be the
|
||||||
|
same as NaturalLess would give, but it's more efficient to do it this way.
|
||||||
|
You can define this for things like TropicalWeight if you need to instantiate
|
||||||
|
this class for that weight type.
|
||||||
|
|
||||||
|
We implement this determinization in a special way to make it efficient for
|
||||||
|
the types of FSTs that we will apply it to. One issue is that if we
|
||||||
|
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
|
||||||
|
type vector<IntType>, the algorithm takes time quadratic in the length of
|
||||||
|
words (in states), because propagating each arc involves copying a whole
|
||||||
|
vector (of integers representing p.d.f.'s). Instead we use a hash structure
|
||||||
|
where each string is a pointer (Entry*), and uses a hash from (Entry*,
|
||||||
|
IntType), to the successor string (and a way to get the latest IntType and the
|
||||||
|
ancestor Entry*). [this is the class LatticeStringRepository].
|
||||||
|
|
||||||
|
Another issue is that rather than representing a determinized-state as a
|
||||||
|
collection of (state, weight), we represent it in a couple of reduced forms.
|
||||||
|
Suppose a determinized-state is a collection of (state, weight) pairs; call
|
||||||
|
this the "canonical representation". Note: these collections are always
|
||||||
|
normalized to remove any common weight and string part. Define end-states as
|
||||||
|
the subset of states that have an arc out of them with a label on, or are
|
||||||
|
final. If we represent a determinized-state a the set of just its (end-state,
|
||||||
|
weight) pairs, this will be a valid and more compact representation, and will
|
||||||
|
lead to a smaller set of determinized states (like early minimization). Call
|
||||||
|
this collection of (end-state, weight) pairs the "minimal representation". As
|
||||||
|
a mechanism to reduce compute, we can also consider another representation.
|
||||||
|
In the determinization algorithm, we start off with a set of (begin-state,
|
||||||
|
weight) pairs (where the "begin-states" are initial or have a label on the
|
||||||
|
transition into them), and the "canonical representation" consists of the
|
||||||
|
epsilon-closure of this set (i.e. follow epsilons). Call this set of
|
||||||
|
(begin-state, weight) pairs, appropriately normalized, the "initial
|
||||||
|
representation". If two initial representations are the same, the "canonical
|
||||||
|
representation" and hence the "minimal representation" will be the same. We
|
||||||
|
can use this to reduce compute. Note that if two initial representations are
|
||||||
|
different, this does not preclude the other representations from being the same.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
struct DeterminizeLatticePrunedOptions {
|
||||||
|
float delta; // A small offset used to measure equality of weights.
|
||||||
|
int max_mem; // If >0, determinization will fail and return false
|
||||||
|
// when the algorithm's (approximate) memory consumption crosses this threshold.
|
||||||
|
int max_loop; // If >0, can be used to detect non-determinizable input
|
||||||
|
// (a case that wouldn't be caught by max_mem).
|
||||||
|
int max_states;
|
||||||
|
int max_arcs;
|
||||||
|
float retry_cutoff;
|
||||||
|
DeterminizeLatticePrunedOptions(): delta(kDelta),
|
||||||
|
max_mem(-1),
|
||||||
|
max_loop(-1),
|
||||||
|
max_states(-1),
|
||||||
|
max_arcs(-1),
|
||||||
|
retry_cutoff(0.5) { }
|
||||||
|
void Register (kaldi::OptionsItf *opts) {
|
||||||
|
opts->Register("delta", &delta, "Tolerance used in determinization");
|
||||||
|
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
|
||||||
|
"determinization (real usage might be many times this)");
|
||||||
|
opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
|
||||||
|
"output FST (total, not per state");
|
||||||
|
opts->Register("max-states", &max_states, "Maximum number of arcs in output "
|
||||||
|
"FST (total, not per state");
|
||||||
|
opts->Register("max-loop", &max_loop, "Option used to detect a particular "
|
||||||
|
"type of determinization failure, typically due to invalid input "
|
||||||
|
"(e.g., negative-cost loops)");
|
||||||
|
opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
|
||||||
|
"lattice and retrying determinization: if effective-beam < "
|
||||||
|
"retry-cutoff * beam, we prune the raw lattice and retry. Avoids "
|
||||||
|
"ever getting empty output for long segments.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DeterminizeLatticePhonePrunedOptions {
|
||||||
|
// delta: a small offset used to measure equality of weights.
|
||||||
|
float delta;
|
||||||
|
// max_mem: if > 0, determinization will fail and return false when the
|
||||||
|
// algorithm's (approximate) memory consumption crosses this threshold.
|
||||||
|
int max_mem;
|
||||||
|
// phone_determinize: if true, do a first pass determinization on both phones
|
||||||
|
// and words.
|
||||||
|
bool phone_determinize;
|
||||||
|
// word_determinize: if true, do a second pass determinization on words only.
|
||||||
|
bool word_determinize;
|
||||||
|
// minimize: if true, push and minimize after determinization.
|
||||||
|
bool minimize;
|
||||||
|
DeterminizeLatticePhonePrunedOptions(): delta(kDelta),
|
||||||
|
max_mem(50000000),
|
||||||
|
phone_determinize(true),
|
||||||
|
word_determinize(true),
|
||||||
|
minimize(false) {}
|
||||||
|
void Register (kaldi::OptionsItf *opts) {
|
||||||
|
opts->Register("delta", &delta, "Tolerance used in determinization");
|
||||||
|
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
|
||||||
|
"determinization (real usage might be many times this).");
|
||||||
|
opts->Register("phone-determinize", &phone_determinize, "If true, do an "
|
||||||
|
"initial pass of determinization on both phones and words (see"
|
||||||
|
" also --word-determinize)");
|
||||||
|
opts->Register("word-determinize", &word_determinize, "If true, do a second "
|
||||||
|
"pass of determinization on words only (see also "
|
||||||
|
"--phone-determinize)");
|
||||||
|
opts->Register("minimize", &minimize, "If true, push and minimize after "
|
||||||
|
"determinization.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
This function implements the normal version of DeterminizeLattice, in which the
|
||||||
|
output strings are represented using sequences of arcs, where all but the
|
||||||
|
first one has an epsilon on the input side. It also prunes using the beam
|
||||||
|
in the "prune" parameter. The input FST must be topologically sorted in order
|
||||||
|
for the algorithm to work. For efficiency it is recommended to sort ilabel as well.
|
||||||
|
Returns true on success, and false if it had to terminate the determinization
|
||||||
|
earlier than specified by the "prune" beam-- that is, if it terminated because
|
||||||
|
of the max_mem, max_loop or max_arcs constraints in the options.
|
||||||
|
CAUTION: you may want to use the version below which outputs to CompactLattice.
|
||||||
|
*/
|
||||||
|
template<class Weight>
|
||||||
|
bool DeterminizeLatticePruned(
|
||||||
|
const ExpandedFst<ArcTpl<Weight> > &ifst,
|
||||||
|
double prune,
|
||||||
|
MutableFst<ArcTpl<Weight> > *ofst,
|
||||||
|
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
|
||||||
|
|
||||||
|
|
||||||
|
/* This is a version of DeterminizeLattice with a slightly more "natural" output format,
|
||||||
|
where the output sequences are encoded using the CompactLatticeArcTpl template
|
||||||
|
(i.e. the sequences of output symbols are represented directly as strings The input
|
||||||
|
FST must be topologically sorted in order for the algorithm to work. For efficiency
|
||||||
|
it is recommended to sort the ilabel for the input FST as well.
|
||||||
|
Returns true on normal success, and false if it had to terminate the determinization
|
||||||
|
earlier than specified by the "prune" beam-- that is, if it terminated because
|
||||||
|
of the max_mem, max_loop or max_arcs constraints in the options.
|
||||||
|
CAUTION: if Lattice is the input, you need to Invert() before calling this,
|
||||||
|
so words are on the input side.
|
||||||
|
*/
|
||||||
|
template<class Weight, class IntType>
|
||||||
|
bool DeterminizeLatticePruned(
|
||||||
|
const ExpandedFst<ArcTpl<Weight> >&ifst,
|
||||||
|
double prune,
|
||||||
|
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
||||||
|
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
|
||||||
|
|
||||||
|
/** This function takes in lattices and inserts phones at phone boundaries. It
|
||||||
|
uses the transition model to work out the transition_id to phone map. The
|
||||||
|
returning value is the starting index of the phone label. Typically we pick
|
||||||
|
(maximum_output_label_index + 1) as this value. The inserted phones are then
|
||||||
|
mapped to (returning_value + original_phone_label) in the new lattice. The
|
||||||
|
returning value will be used by DeterminizeLatticeDeletePhones() where it
|
||||||
|
works out the phones according to this value.
|
||||||
|
*/
|
||||||
|
template<class Weight>
|
||||||
|
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
|
||||||
|
const kaldi::TransitionInformation &trans_model,
|
||||||
|
MutableFst<ArcTpl<Weight> > *fst);
|
||||||
|
|
||||||
|
/** This function takes in lattices and deletes "phones" from them. The "phones"
|
||||||
|
here are actually any label that is larger than first_phone_label because
|
||||||
|
when we insert phones into the lattice, we map the original phone label to
|
||||||
|
(first_phone_label + original_phone_label). It is supposed to be used
|
||||||
|
together with DeterminizeLatticeInsertPhones()
|
||||||
|
*/
|
||||||
|
template<class Weight>
|
||||||
|
void DeterminizeLatticeDeletePhones(
|
||||||
|
typename ArcTpl<Weight>::Label first_phone_label,
|
||||||
|
MutableFst<ArcTpl<Weight> > *fst);
|
||||||
|
|
||||||
|
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
|
||||||
|
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
|
||||||
|
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
|
||||||
|
determinization on the phone + word lattices. If --word-determinize is set
|
||||||
|
true, it then does a second pass of determinization on the word lattices by
|
||||||
|
calling DeterminizeLatticePruned(). If both are set to false, then it gives
|
||||||
|
a warning and copying the lattices without determinization.
|
||||||
|
|
||||||
|
Note: the point of doing first a phone-level determinization pass and then
|
||||||
|
a word-level determinization pass is that it allows us to determinize
|
||||||
|
deeper lattices without "failing early" and returning a too-small lattice
|
||||||
|
due to the max-mem constraint. The result should be the same as word-level
|
||||||
|
determinization in general, but for deeper lattices it is a bit faster,
|
||||||
|
despite the fact that we now have two passes of determinization by default.
|
||||||
|
*/
|
||||||
|
template<class Weight, class IntType>
|
||||||
|
bool DeterminizeLatticePhonePruned(
|
||||||
|
const kaldi::TransitionInformation &trans_model,
|
||||||
|
const ExpandedFst<ArcTpl<Weight> > &ifst,
|
||||||
|
double prune,
|
||||||
|
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
||||||
|
DeterminizeLatticePhonePrunedOptions opts
|
||||||
|
= DeterminizeLatticePhonePrunedOptions());
|
||||||
|
|
||||||
|
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input
|
||||||
|
lattice might be changed.
|
||||||
|
*/
|
||||||
|
template<class Weight, class IntType>
|
||||||
|
bool DeterminizeLatticePhonePruned(
|
||||||
|
const kaldi::TransitionInformation &trans_model,
|
||||||
|
MutableFst<ArcTpl<Weight> > *ifst,
|
||||||
|
double prune,
|
||||||
|
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
||||||
|
DeterminizeLatticePhonePrunedOptions opts
|
||||||
|
= DeterminizeLatticePhonePrunedOptions());
|
||||||
|
|
||||||
|
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
|
||||||
|
Lattice type FSTs. It simplifies the calling process by calling
|
||||||
|
TopSort() Invert() and ArcSort() for you.
|
||||||
|
Unlike other determinization routines, the function
|
||||||
|
requires "ifst" to have transition-id's on the input side and words on the
|
||||||
|
output side.
|
||||||
|
This function can be used as the top-level interface to all the determinization
|
||||||
|
code.
|
||||||
|
*/
|
||||||
|
bool DeterminizeLatticePhonePrunedWrapper(
|
||||||
|
const kaldi::TransitionInformation &trans_model,
|
||||||
|
MutableFst<kaldi::LatticeArc> *ifst,
|
||||||
|
double prune,
|
||||||
|
MutableFst<kaldi::CompactLatticeArc> *ofst,
|
||||||
|
DeterminizeLatticePhonePrunedOptions opts
|
||||||
|
= DeterminizeLatticePhonePrunedOptions());
|
||||||
|
|
||||||
|
/// @} end "addtogroup fst_extensions"
|
||||||
|
|
||||||
|
} // end namespace fst
|
||||||
|
|
||||||
|
#endif
|
@ -0,0 +1,506 @@
|
|||||||
|
// lat/kaldi-lattice.cc
|
||||||
|
|
||||||
|
// Copyright 2009-2011 Microsoft Corporation
|
||||||
|
// 2013 Johns Hopkins University (author: Daniel Povey)
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#include "lat/kaldi-lattice.h"
|
||||||
|
#include "fst/script/print-impl.h"
|
||||||
|
|
||||||
|
namespace kaldi {
|
||||||
|
|
||||||
|
/// Converts lattice types if necessary, deleting its input.
|
||||||
|
template<class OrigWeightType>
|
||||||
|
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
|
||||||
|
if (!ifst) return NULL;
|
||||||
|
CompactLattice *ofst = new CompactLattice();
|
||||||
|
ConvertLattice(*ifst, ofst);
|
||||||
|
delete ifst;
|
||||||
|
return ofst;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This overrides the template if there is no type conversion going on
|
||||||
|
// (for efficiency).
|
||||||
|
template<>
|
||||||
|
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
|
||||||
|
return ifst;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts lattice types if necessary, deleting its input.
|
||||||
|
template<class OrigWeightType>
|
||||||
|
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
|
||||||
|
if (!ifst) return NULL;
|
||||||
|
Lattice *ofst = new Lattice();
|
||||||
|
ConvertLattice(*ifst, ofst);
|
||||||
|
delete ifst;
|
||||||
|
return ofst;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This overrides the template if there is no type conversion going on
|
||||||
|
// (for efficiency).
|
||||||
|
template<>
|
||||||
|
Lattice* ConvertToLattice(Lattice *ifst) {
|
||||||
|
return ifst;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool WriteCompactLattice(std::ostream &os, bool binary,
|
||||||
|
const CompactLattice &t) {
|
||||||
|
if (binary) {
|
||||||
|
fst::FstWriteOptions opts;
|
||||||
|
// Leave all the options default. Normally these lattices wouldn't have any
|
||||||
|
// osymbols/isymbols so no point directing it not to write them (who knows what
|
||||||
|
// we'd want to if we had them).
|
||||||
|
return t.Write(os, opts);
|
||||||
|
} else {
|
||||||
|
// Text-mode output. Note: we expect that t.InputSymbols() and
|
||||||
|
// t.OutputSymbols() would always return NULL. The corresponding input
|
||||||
|
// routine would not work if the FST actually had symbols attached.
|
||||||
|
// Write a newline after the key, so the first line of the FST appears
|
||||||
|
// on its own line.
|
||||||
|
os << '\n';
|
||||||
|
bool acceptor = true, write_one = false;
|
||||||
|
fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
|
||||||
|
t.OutputSymbols(),
|
||||||
|
NULL, acceptor, write_one, "\t");
|
||||||
|
printer.Print(&os, "<unknown>");
|
||||||
|
if (os.fail())
|
||||||
|
KALDI_WARN << "Stream failure detected.";
|
||||||
|
// Write another newline as a terminating character. The read routine will
|
||||||
|
// detect this [this is a Kaldi mechanism, not somethig in the original
|
||||||
|
// OpenFst code].
|
||||||
|
os << '\n';
|
||||||
|
return os.good();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LatticeReader provides (static) functions for reading both Lattice
/// and CompactLattice, in text form.
class LatticeReader {
  typedef LatticeArc Arc;
  typedef LatticeWeight Weight;
  typedef CompactLatticeArc CArc;
  typedef CompactLatticeWeight CWeight;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;
 public:
  // everything is static in this class.

  /** This function reads from the FST text format; it does not know in advance
      whether it's a Lattice or CompactLattice in the stream so it tries to
      read both formats until it becomes clear which is the correct one.
      Exactly one element of the returned pair is non-NULL on success (or
      both, if the input was so far consistent with both formats); both are
      NULL on failure.  The caller takes ownership of the returned objects.
  */
  static std::pair<Lattice*, CompactLattice*> ReadText(
      std::istream &is) {
    typedef std::pair<Lattice*, CompactLattice*> PairT;
    using std::string;
    using std::vector;
    // We speculatively build both representations; whichever one hits a
    // line inconsistent with its format gets deleted and set to NULL.
    Lattice *fst = new Lattice();
    CompactLattice *cfst = new CompactLattice();
    string line;
    size_t nline = 0;
    string separator = FLAGS_fst_field_separator + "\r\n";
    while (std::getline(is, line)) {
      nline++;
      vector<string> col;
      // on Windows we'll write in text and read in binary mode.
      SplitStringToVector(line, separator.c_str(), true, &col);
      if (col.size() == 0) break; // Empty line is a signal to stop, in our
      // archive format.
      if (col.size() > 5) {
        KALDI_WARN << "Reading lattice: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      // First column is always the (source) state id.
      StateId s;
      if (!ConvertStringToInteger(col[0], &s)) {
        KALDI_WARN << "FstCompiler: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      // Make sure the source state exists in whichever FSTs we still have.
      if (fst)
        while (s >= fst->NumStates())
          fst->AddState();
      if (cfst)
        while (s >= cfst->NumStates())
          cfst->AddState();
      // By convention the first state listed is the start state.
      if (nline == 1) {
        if (fst) fst->SetStart(s);
        if (cfst) cfst->SetStart(s);
      }

      if (fst) { // we still have fst; try to read that arc.
        bool ok = true;
        Arc arc;
        Weight w;
        StateId d = s;
        // Column counts: 1 = final state (weight One), 2 = final state with
        // weight, 4 = arc with weight One, 5 = arc with weight.  Lattice is
        // a transducer, so the 3-column (acceptor) form is invalid for it.
        switch (col.size()) {
          case 1 :
            fst->SetFinal(s, Weight::One());
            break;
          case 2:
            if (!StrToWeight(col[1], true, &w)) ok = false;
            else fst->SetFinal(s, w);
            break;
          case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
            ok = false;
            break;
          case 4:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel);
            if (ok) {
              d = arc.nextstate;
              arc.weight = Weight::One();
              fst->AddArc(s, arc);
            }
            break;
          case 5:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel) &&
                StrToWeight(col[4], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              fst->AddArc(s, arc);
            }
            break;
          default:
            ok = false;
        }
        // Ensure the destination state exists too.
        while (d >= fst->NumStates())
          fst->AddState();
        if (!ok) {
          // This line ruled out the Lattice interpretation.
          delete fst;
          fst = NULL;
        }
      }
      if (cfst) {
        bool ok = true;
        CArc arc;
        CWeight w;
        StateId d = s;
        // CompactLattice is an acceptor, so arcs have a single label:
        // 3 columns = arc with weight One, 4 columns = arc with weight.
        switch (col.size()) {
          case 1 :
            cfst->SetFinal(s, CWeight::One());
            break;
          case 2:
            if (!StrToCWeight(col[1], true, &w)) ok = false;
            else cfst->SetFinal(s, w);
            break;
          case 3: // compact-lattice is acceptor format: state, next-state, label.
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel; // acceptor: output label == input label.
              arc.weight = CWeight::One();
              cfst->AddArc(s, arc);
            }
            break;
          case 4:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                StrToCWeight(col[3], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel; // acceptor: output label == input label.
              cfst->AddArc(s, arc);
            }
            break;
          case 5: default:
            ok = false;
        }
        // Ensure the destination state exists too.
        while (d >= cfst->NumStates())
          cfst->AddState();
        if (!ok) {
          // This line ruled out the CompactLattice interpretation.
          delete cfst;
          cfst = NULL;
        }
      }
      if (!fst && !cfst) {
        KALDI_WARN << "Bad line in lattice text format: " << line;
        // read until we get an empty line, so at least we
        // have a chance to read the next one (although this might
        // be a bit futile since the calling code will get unhappy
        // about failing to read this one.
        while (std::getline(is, line)) {
          SplitStringToVector(line, separator.c_str(), true, &col);
          if (col.empty()) break;
        }
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
    }
    return PairT(fst, cfst);
  }

  // Parses a LatticeWeight from text.  If allow_zero is false, treats a
  // parsed Zero() weight as a failure (Zero cannot appear in valid input).
  static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == Weight::Zero())) {
      return false;
    }
    return true;
  }

  // As StrToWeight, but for CompactLatticeWeight.
  static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == CWeight::Zero())) {
      return false;
    }
    return true;
  }
};
|
||||||
|
|
||||||
|
|
||||||
|
// Reads a lattice in text form from 'is' and returns it as a CompactLattice,
// converting from the Lattice representation if that is what the stream
// held.  Returns NULL on failure; the caller owns the returned object.
CompactLattice *ReadCompactLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
  Lattice *lat = lat_pair.first;
  CompactLattice *clat = lat_pair.second;
  if (clat != NULL) {
    delete lat;  // discard the duplicate Lattice version, if any.
    return clat;
  }
  if (lat != NULL) {
    // note: ConvertToCompactLattice frees its input.
    return ConvertToCompactLattice(lat);
  }
  return NULL;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Reads a lattice in text form from 'is' and returns it as a Lattice,
// converting from the CompactLattice representation if that is what the
// stream held.  Returns NULL on failure; the caller owns the returned object.
Lattice *ReadLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
  Lattice *lat = lat_pair.first;
  CompactLattice *clat = lat_pair.second;
  if (lat != NULL) {
    delete clat;  // discard the duplicate CompactLattice version, if any.
    return lat;
  }
  if (clat != NULL) {
    // note: ConvertToLattice frees its input.
    return ConvertToLattice(clat);
  }
  return NULL;
}
|
||||||
|
|
||||||
|
// Reads a CompactLattice from 'is'.  *clat must be NULL at entry; on success
// it is set to a newly allocated object owned by the caller and true is
// returned.  Returns false (rather than throwing) on failure.
bool ReadCompactLattice(std::istream &is, bool binary,
                        CompactLattice **clat) {
  KALDI_ASSERT(*clat == NULL);
  if (binary) {
    // Read the OpenFst header ourselves so we can dispatch on the arc type:
    // any of the four (Lattice/CompactLattice) x (float/double) weight
    // combinations is accepted and converted as needed.
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading compact lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading compact lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>",
                              &hdr);

    // The four weight types a stored lattice may use, and their VectorFsts.
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;

    CompactLattice *ans = NULL;
    // note: ConvertToCompactLattice takes ownership of (frees) its input
    // and tolerates a NULL argument from a failed Read().
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToCompactLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToCompactLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToCompactLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToCompactLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to CompactLattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading compact lattice (after reading header).";
      return false;
    }
    *clat = ans;
    return true;
  } else {
    // Text mode: the writer put a newline after the key, so consume it here.
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get(); // consume the newline.
    else { // saw spaces but no newline.. this is not expected.
      KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *clat = ReadCompactLatticeText(is); // that routine will warn on error.
    return (*clat != NULL);
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
// Reads a CompactLattice into the holder.  Since we don't write binary
// headers for this type, we detect binary vs. text mode by peeking at the
// first character (see comments below).  Returns false on failure.
bool CompactLatticeHolder::Read(std::istream &is) {
  Clear(); // in case anything currently stored.
  int c = is.peek();
  if (c == -1) {
    KALDI_WARN << "End of stream detected reading CompactLattice.";
    return false;
  } else if (isspace(c)) { // The text form of the lattice begins
    // with space (normally, '\n'), so this means it's text (the binary form
    // cannot begin with space because it starts with the FST Type() which is not
    // space).
    return ReadCompactLattice(is, false, &t_);
  } else if (c != 214) { // 214 is first char of FST magic number,
    // on little-endian machines which is all we support (\326 octal)
    KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  } else {
    return ReadCompactLattice(is, true, &t_);
  }
}
|
||||||
|
|
||||||
|
// Writes a Lattice to 'os', in OpenFst binary format if binary == true,
// otherwise in FST text format framed by newlines (see comments below).
// Returns false (rather than throwing) on stream failure.
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
  if (binary) {
    fst::FstWriteOptions opts;
    // Leave all the options default.  Normally these lattices wouldn't have any
    // osymbols/isymbols so no point directing it not to write them (who knows what
    // we'd want to do if we had them).
    return t.Write(os, opts);
  } else {
    // Text-mode output.  Note: we expect that t.InputSymbols() and
    // t.OutputSymbols() would always return NULL.  The corresponding input
    // routine would not work if the FST actually had symbols attached.
    // Write a newline after the key, so the first line of the FST appears
    // on its own line.
    os << '\n';
    // Lattice is a transducer, so acceptor == false here (contrast with
    // WriteCompactLattice, which prints in acceptor format).
    bool acceptor = false, write_one = false;
    fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
                                        t.OutputSymbols(),
                                        NULL, acceptor, write_one, "\t");
    printer.Print(&os, "<unknown>");
    if (os.fail())
      KALDI_WARN << "Stream failure detected.";
    // Write another newline as a terminating character.  The read routine will
    // detect this [this is a Kaldi mechanism, not something in the original
    // OpenFst code].
    os << '\n';
    return os.good();
  }
}
|
||||||
|
|
||||||
|
bool ReadLattice(std::istream &is, bool binary,
|
||||||
|
Lattice **lat) {
|
||||||
|
KALDI_ASSERT(*lat == NULL);
|
||||||
|
if (binary) {
|
||||||
|
fst::FstHeader hdr;
|
||||||
|
if (!hdr.Read(is, "<unknown>")) {
|
||||||
|
KALDI_WARN << "Reading lattice: error reading FST header.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (hdr.FstType() != "vector") {
|
||||||
|
KALDI_WARN << "Reading lattice: unsupported FST type: "
|
||||||
|
<< hdr.FstType();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
fst::FstReadOptions ropts("<unspecified>",
|
||||||
|
&hdr);
|
||||||
|
|
||||||
|
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
|
||||||
|
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
|
||||||
|
typedef fst::LatticeWeightTpl<float> T3;
|
||||||
|
typedef fst::LatticeWeightTpl<double> T4;
|
||||||
|
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
|
||||||
|
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
|
||||||
|
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
|
||||||
|
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
|
||||||
|
|
||||||
|
Lattice *ans = NULL;
|
||||||
|
if (hdr.ArcType() == T1::Type()) {
|
||||||
|
ans = ConvertToLattice(F1::Read(is, ropts));
|
||||||
|
} else if (hdr.ArcType() == T2::Type()) {
|
||||||
|
ans = ConvertToLattice(F2::Read(is, ropts));
|
||||||
|
} else if (hdr.ArcType() == T3::Type()) {
|
||||||
|
ans = ConvertToLattice(F3::Read(is, ropts));
|
||||||
|
} else if (hdr.ArcType() == T4::Type()) {
|
||||||
|
ans = ConvertToLattice(F4::Read(is, ropts));
|
||||||
|
} else {
|
||||||
|
KALDI_WARN << "FST with arc type " << hdr.ArcType()
|
||||||
|
<< " cannot be converted to Lattice.\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (ans == NULL) {
|
||||||
|
KALDI_WARN << "Error reading lattice (after reading header).";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*lat = ans;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// The next line would normally consume the \r on Windows, plus any
|
||||||
|
// extra spaces that might have got in there somehow.
|
||||||
|
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
|
||||||
|
if (is.peek() == '\n') is.get(); // consume the newline.
|
||||||
|
else { // saw spaces but no newline.. this is not expected.
|
||||||
|
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
|
||||||
|
<< " at file position " << is.tellg();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*lat = ReadLatticeText(is); // that routine will warn on error.
|
||||||
|
return (*lat != NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Since we don't write the binary headers for this type of holder,
|
||||||
|
we use a different method to work out whether we're in binary mode.
|
||||||
|
*/
|
||||||
|
bool LatticeHolder::Read(std::istream &is) {
|
||||||
|
Clear(); // in case anything currently stored.
|
||||||
|
int c = is.peek();
|
||||||
|
if (c == -1) {
|
||||||
|
KALDI_WARN << "End of stream detected reading Lattice.";
|
||||||
|
return false;
|
||||||
|
} else if (isspace(c)) { // The text form of the lattice begins
|
||||||
|
// with space (normally, '\n'), so this means it's text (the binary form
|
||||||
|
// cannot begin with space because it starts with the FST Type() which is not
|
||||||
|
// space).
|
||||||
|
return ReadLattice(is, false, &t_);
|
||||||
|
} else if (c != 214) { // 214 is first char of FST magic number,
|
||||||
|
// on little-endian machines which is all we support (\326 octal)
|
||||||
|
KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
|
||||||
|
<< " [non-space but no magic number detected], file pos is "
|
||||||
|
<< is.tellg();
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
return ReadLattice(is, true, &t_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace kaldi
|
@ -0,0 +1,156 @@
|
|||||||
|
// lat/kaldi-lattice.h
|
||||||
|
|
||||||
|
// Copyright 2009-2011 Microsoft Corporation
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef KALDI_LAT_KALDI_LATTICE_H_
|
||||||
|
#define KALDI_LAT_KALDI_LATTICE_H_
|
||||||
|
|
||||||
|
#include "fstext/fstext-lib.h"
|
||||||
|
#include "base/kaldi-common.h"
|
||||||
|
#include "util/common-utils.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace kaldi {
|
||||||
|
// will import some things above...
|
||||||
|
|
||||||
|
typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;
|
||||||
|
|
||||||
|
// careful: kaldi::int32 is not always the same C type as fst::int32
|
||||||
|
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;
|
||||||
|
|
||||||
|
typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
|
||||||
|
CompactLatticeWeightCommonDivisor;
|
||||||
|
|
||||||
|
typedef fst::ArcTpl<LatticeWeight> LatticeArc;
|
||||||
|
|
||||||
|
typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;
|
||||||
|
|
||||||
|
typedef fst::VectorFst<LatticeArc> Lattice;
|
||||||
|
|
||||||
|
typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
|
||||||
|
|
||||||
|
// The following functions for writing and reading lattices in binary or text
|
||||||
|
// form are provided here in case you need to include lattices in larger,
|
||||||
|
// Kaldi-type objects with their own Read and Write functions. Caution: these
|
||||||
|
// functions return false on stream failure rather than throwing an exception as
|
||||||
|
// most similar Kaldi functions would do.
|
||||||
|
|
||||||
|
bool WriteCompactLattice(std::ostream &os, bool binary,
|
||||||
|
const CompactLattice &clat);
|
||||||
|
bool WriteLattice(std::ostream &os, bool binary,
|
||||||
|
const Lattice &lat);
|
||||||
|
|
||||||
|
// the following function requires that *clat be
|
||||||
|
// NULL when called.
|
||||||
|
bool ReadCompactLattice(std::istream &is, bool binary,
|
||||||
|
CompactLattice **clat);
|
||||||
|
// the following function requires that *lat be
|
||||||
|
// NULL when called.
|
||||||
|
bool ReadLattice(std::istream &is, bool binary,
|
||||||
|
Lattice **lat);
|
||||||
|
|
||||||
|
|
||||||
|
// Holder for CompactLattice objects, for use with Kaldi's Table (archive)
// I/O code (TableWriter / SequentialTableReader / RandomAccessTableReader).
// Owns the contained lattice and frees it in Clear()/the destructor.
class CompactLatticeHolder {
 public:
  typedef CompactLattice T;

  CompactLatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteCompactLattice(os, binary, t);
  }

  // Defined out-of-line (in the .cc file); auto-detects binary vs. text.
  bool Read(std::istream &is);

  // We always open the stream in binary mode; Read() itself distinguishes
  // binary from text content by peeking at the first character.
  static bool IsReadInBinary() { return true; }

  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(CompactLatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  // Range extraction (e.g. "foo[0:9]") is not supported for lattices;
  // KALDI_ERR throws, so the return is never reached.
  bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~CompactLatticeHolder() { Clear(); }
 private:
  T *t_;  // owned; NULL when the holder is empty.
};
|
||||||
|
|
||||||
|
// Holder for Lattice objects, for use with Kaldi's Table (archive) I/O code
// (TableWriter / SequentialTableReader / RandomAccessTableReader).
// Owns the contained lattice and frees it in Clear()/the destructor.
class LatticeHolder {
 public:
  typedef Lattice T;

  LatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteLattice(os, binary, t);
  }

  // Defined out-of-line (in the .cc file); auto-detects binary vs. text.
  bool Read(std::istream &is);

  // We always open the stream in binary mode; Read() itself distinguishes
  // binary from text content by peeking at the first character.
  static bool IsReadInBinary() { return true; }

  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(LatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  // Range extraction (e.g. "foo[0:9]") is not supported for lattices;
  // KALDI_ERR throws, so the return is never reached.
  bool ExtractRange(const LatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~LatticeHolder() { Clear(); }
 private:
  T *t_;  // owned; NULL when the holder is empty.
};
|
||||||
|
|
||||||
|
typedef TableWriter<LatticeHolder> LatticeWriter;
|
||||||
|
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
|
||||||
|
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
|
||||||
|
|
||||||
|
typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
|
||||||
|
typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
|
||||||
|
typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace kaldi
|
||||||
|
|
||||||
|
#endif // KALDI_LAT_KALDI_LATTICE_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,402 @@
|
|||||||
|
// lat/lattice-functions.h
|
||||||
|
|
||||||
|
// Copyright 2009-2012 Saarland University (author: Arnab Ghoshal)
|
||||||
|
// 2012-2013 Johns Hopkins University (Author: Daniel Povey);
|
||||||
|
// Bagher BabaAli
|
||||||
|
// 2014 Guoguo Chen
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef KALDI_LAT_LATTICE_FUNCTIONS_H_
|
||||||
|
#define KALDI_LAT_LATTICE_FUNCTIONS_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "base/kaldi-common.h"
|
||||||
|
#include "fstext/fstext-lib.h"
|
||||||
|
#include "itf/decodable-itf.h"
|
||||||
|
#include "itf/transition-information.h"
|
||||||
|
#include "lat/kaldi-lattice.h"
|
||||||
|
|
||||||
|
namespace kaldi {
|
||||||
|
|
||||||
|
// Redundant with the typedef in hmm/posterior.h. We want functions
|
||||||
|
// using the Posterior type to be usable without a dependency on the
|
||||||
|
// hmm library.
|
||||||
|
typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
|
||||||
|
|
||||||
|
/**
|
||||||
|
This function extracts the per-frame log likelihoods from a linear
|
||||||
|
lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code).
|
||||||
|
The dimension of *per_frame_loglikes will be set to the
|
||||||
|
number of input symbols in 'nbest'. The elements of
|
||||||
|
'*per_frame_loglikes' will be set to the .Value2() elements of the lattice
|
||||||
|
weights, which represent the acoustic costs; you may want to scale this
|
||||||
|
vector afterward by -1/acoustic_scale to get the original loglikes.
|
||||||
|
If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest'
|
||||||
|
(and this should not normally be the case in situations where it makes
|
||||||
|
sense to call this function), they will be included to the cost of the
|
||||||
|
preceding input symbol, or the following input symbol for input-epsilons
|
||||||
|
encountered prior to any input symbol. If 'nbest' has no input symbols,
|
||||||
|
'per_frame_loglikes' will be set to the empty vector.
|
||||||
|
**/
|
||||||
|
void GetPerFrameAcousticCosts(const Lattice &nbest,
|
||||||
|
Vector<BaseFloat> *per_frame_loglikes);
|
||||||
|
|
||||||
|
/// This function iterates over the states of a topologically sorted lattice and
|
||||||
|
/// counts the time instance corresponding to each state. The times are returned
|
||||||
|
/// in a vector of integers 'times' which is resized to have a size equal to the
|
||||||
|
/// number of states in the lattice. The function also returns the maximum time
|
||||||
|
/// in the lattice (this will equal the number of frames in the file).
|
||||||
|
int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
|
||||||
|
|
||||||
|
/// As LatticeStateTimes, but in the CompactLattice format. Note: must
|
||||||
|
/// be topologically sorted. Returns length of the utterance in frames, which
|
||||||
|
/// might not be the same as the maximum time in the lattice, due to frames
|
||||||
|
/// in the final-prob.
|
||||||
|
int32 CompactLatticeStateTimes(const CompactLattice &clat,
|
||||||
|
std::vector<int32> *times);
|
||||||
|
|
||||||
|
/// This function does the forward-backward over lattices and computes the
|
||||||
|
/// posterior probabilities of the arcs. It returns the total log-probability
|
||||||
|
/// of the lattice. The Posterior quantities contain pairs of (transition-id, weight)
|
||||||
|
/// on each frame.
|
||||||
|
/// If the pointer "acoustic_like_sum" is provided, this value is set to
|
||||||
|
/// the sum over the arcs, of the posterior of the arc times the
|
||||||
|
/// acoustic likelihood [i.e. negated acoustic score] on that link.
|
||||||
|
/// This is used in combination with other quantities to work out
|
||||||
|
/// the objective function in MMI discriminative training.
|
||||||
|
BaseFloat LatticeForwardBackward(const Lattice &lat,
|
||||||
|
Posterior *arc_post,
|
||||||
|
double *acoustic_like_sum = NULL);
|
||||||
|
|
||||||
|
// This function is something similar to LatticeForwardBackward(), but it is on
|
||||||
|
// the CompactLattice lattice format. Also we only need the alpha in the forward
|
||||||
|
// path, not the posteriors.
|
||||||
|
bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
|
||||||
|
std::vector<double> *alpha);
|
||||||
|
|
||||||
|
// A sibling of the function CompactLatticeAlphas()... We compute the beta from
|
||||||
|
// the backward path here.
|
||||||
|
bool ComputeCompactLatticeBetas(const CompactLattice &lat,
|
||||||
|
std::vector<double> *beta);
|
||||||
|
|
||||||
|
|
||||||
|
// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
|
||||||
|
// best-path negated cost) Note: in either case, the alphas and betas are
|
||||||
|
// negated costs. Requires that lat be topologically sorted. This code
|
||||||
|
// will work for either CompactLattice or Lattice.
|
||||||
|
template<typename LatticeType>
|
||||||
|
double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
|
||||||
|
bool viterbi,
|
||||||
|
std::vector<double> *alpha,
|
||||||
|
std::vector<double> *beta);
|
||||||
|
|
||||||
|
|
||||||
|
/// Topologically sort the compact lattice if not already topologically sorted.
|
||||||
|
/// Will crash if the lattice cannot be topologically sorted.
|
||||||
|
void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
|
||||||
|
|
||||||
|
|
||||||
|
/// Topologically sort the lattice if not already topologically sorted.
|
||||||
|
/// Will crash if lattice cannot be topologically sorted.
|
||||||
|
void TopSortLatticeIfNeeded(Lattice *clat);
|
||||||
|
|
||||||
|
/// Returns the depth of the lattice, defined as the average number of arcs (or
|
||||||
|
/// final-prob strings) crossing any given frame. Returns 1 for empty lattices.
|
||||||
|
/// Requires that clat is topologically sorted!
|
||||||
|
BaseFloat CompactLatticeDepth(const CompactLattice &clat,
|
||||||
|
int32 *num_frames = NULL);
|
||||||
|
|
||||||
|
/// This function returns, for each frame, the number of arcs crossing that
|
||||||
|
/// frame.
|
||||||
|
void CompactLatticeDepthPerFrame(const CompactLattice &clat,
|
||||||
|
std::vector<int32> *depth_per_frame);
|
||||||
|
|
||||||
|
|
||||||
|
/// This function limits the depth of the lattice, per frame: that means, it
|
||||||
|
/// does not allow more than a specified number of arcs active on any given
|
||||||
|
/// frame. This can be used to reduce the size of the "very deep" portions of
|
||||||
|
/// the lattice.
|
||||||
|
void CompactLatticeLimitDepth(int32 max_arcs_per_frame,
|
||||||
|
CompactLattice *clat);
|
||||||
|
|
||||||
|
|
||||||
|
/// Given a lattice, and a transition model to map pdf-ids to phones,
|
||||||
|
/// outputs for each frame the set of phones active on that frame. If
|
||||||
|
/// sil_phones (which must be sorted and uniq) is nonempty, it excludes
|
||||||
|
/// phones in this list.
|
||||||
|
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans,
|
||||||
|
const std::vector<int32> &sil_phones,
|
||||||
|
std::vector<std::set<int32> > *active_phones);
|
||||||
|
|
||||||
|
/// Given a lattice, and a transition model to map pdf-ids to phones,
|
||||||
|
/// replace the output symbols (presumably words), with phones; we
|
||||||
|
/// use the TransitionModel to work out the phone sequence. Note
|
||||||
|
/// that the phone labels are not exactly aligned with the phone
|
||||||
|
/// boundaries. We put a phone label to coincide with any transition
|
||||||
|
/// to the final, nonemitting state of a phone (this state always exists,
|
||||||
|
/// we ensure this in HmmTopology::Check()). This would be the last
|
||||||
|
/// transition-id in the phone if reordering is not done (but typically
|
||||||
|
/// we do reorder).
|
||||||
|
/// Also see PhoneAlignLattice, in phone-align-lattice.h.
|
||||||
|
void ConvertLatticeToPhones(const TransitionInformation &trans_model,
|
||||||
|
Lattice *lat);
|
||||||
|
|
||||||
|
/// Prunes a lattice or compact lattice. Returns true on success, false if
|
||||||
|
/// there was some kind of failure.
|
||||||
|
template<class LatticeType>
|
||||||
|
bool PruneLattice(BaseFloat beam, LatticeType *lat);
|
||||||
|
|
||||||
|
|
||||||
|
/// Given a lattice, and a transition model to map pdf-ids to phones,
|
||||||
|
/// replace the sequences of transition-ids with sequences of phones.
|
||||||
|
/// Note that this is different from ConvertLatticeToPhones, in that
|
||||||
|
/// we replace the transition-ids not the words.
|
||||||
|
void ConvertCompactLatticeToPhones(const TransitionInformation &trans_model,
|
||||||
|
CompactLattice *clat);
|
||||||
|
|
||||||
|
/// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds
|
||||||
|
/// -b*[number of frame errors] to the graph-component of the cost of each arc/path.
|
||||||
|
/// There is a frame error if a particular transition-id on a particular frame
|
||||||
|
/// corresponds to a phone not matching transcription's alignment for that frame.
|
||||||
|
/// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
|
||||||
|
/// The TransitionInformation is used to map transition-ids in the lattice
|
||||||
|
/// input-side to phones; the phones appearing in
|
||||||
|
/// "silence_phones" are treated specially in that we replace the frame error f
|
||||||
|
/// (either zero or 1) for a frame, with the minimum of f or max_silence_error.
|
||||||
|
/// For the normal recipe, max_silence_error would be zero.
|
||||||
|
/// Returns true on success, false if there was some kind of mismatch.
|
||||||
|
/// At input, silence_phones must be sorted and unique.
|
||||||
|
bool LatticeBoost(const TransitionInformation &trans,
|
||||||
|
const std::vector<int32> &alignment,
|
||||||
|
const std::vector<int32> &silence_phones,
|
||||||
|
BaseFloat b,
|
||||||
|
BaseFloat max_silence_error,
|
||||||
|
Lattice *lat);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
This function implements either the MPFE (minimum phone frame error) or SMBR
|
||||||
|
(state-level minimum bayes risk) forward-backward, depending on whether
|
||||||
|
"criterion" is "mpfe" or "smbr". It returns the MPFE
|
||||||
|
criterion of SMBR criterion for this utterance, and outputs the posteriors (which
|
||||||
|
may be positive or negative) into "post".
|
||||||
|
|
||||||
|
@param [in] trans The transition model. Used to map the
|
||||||
|
transition-ids to phones or pdfs.
|
||||||
|
@param [in] silence_phones A list of integer ids of silence phones. The
|
||||||
|
silence frames i.e. the frames where num_ali
|
||||||
|
corresponds to a silence phones are treated specially.
|
||||||
|
The behavior is determined by 'one_silence_class'
|
||||||
|
being false (traditional behavior) or true.
|
||||||
|
Usually in our setup, several phones including
|
||||||
|
the silence, vocalized noise, non-spoken noise
|
||||||
|
and unk are treated as "silence phones"
|
||||||
|
@param [in] lat The denominator lattice
|
||||||
|
@param [in] num_ali The numerator alignment
|
||||||
|
@param [in] criterion The objective function. Must be "mpfe" or "smbr"
|
||||||
|
for MPFE (minimum phone frame error) or sMBR
|
||||||
|
(state minimum bayes risk) training.
|
||||||
|
@param [in] one_silence_class Determines how the silence frames are treated.
|
||||||
|
Setting this to false gives the old traditional behavior,
|
||||||
|
where the silence frames (according to num_ali) are
|
||||||
|
treated as incorrect. However, this means that the
|
||||||
|
insertions are not penalized by the objective.
|
||||||
|
Setting this to true gives the new behaviour, where we
|
||||||
|
treat silence as any other phone, except that all pdfs
|
||||||
|
of silence phones are collapsed into a single class for
|
||||||
|
the frame-error computation. This can possible reduce
|
||||||
|
the insertions in the trained model. This is closer to
|
||||||
|
the WER metric that we actually care about, since WER is
|
||||||
|
generally computed after filtering out noises, but
|
||||||
|
does penalize insertions.
|
||||||
|
@param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the
|
||||||
|
pseudo log-likelihoods of states at each frame.
|
||||||
|
*/
|
||||||
|
BaseFloat LatticeForwardBackwardMpeVariants(
|
||||||
|
const TransitionInformation &trans,
|
||||||
|
const std::vector<int32> &silence_phones,
|
||||||
|
const Lattice &lat,
|
||||||
|
const std::vector<int32> &num_ali,
|
||||||
|
std::string criterion,
|
||||||
|
bool one_silence_class,
|
||||||
|
Posterior *post);
|
||||||
|
|
||||||
|
/// This function takes a CompactLattice that should only contain a single
|
||||||
|
/// linear sequence (e.g. derived from lattice-1best), and that should have been
|
||||||
|
/// processed so that the arcs in the CompactLattice align correctly with the
|
||||||
|
/// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the
|
||||||
|
/// same size, which give, for each word in the lattice (in sequence), the word
|
||||||
|
/// label and the begin time and length in frames. This is done even for zero
|
||||||
|
/// (epsilon) words, generally corresponding to optional silence-- if you don't
|
||||||
|
/// want them, just ignore them in the output.
|
||||||
|
/// This function will print a warning and return false, if the lattice
|
||||||
|
/// did not have the correct format (e.g. if it is empty or it is not
|
||||||
|
/// linear).
|
||||||
|
bool CompactLatticeToWordAlignment(const CompactLattice &clat,
|
||||||
|
std::vector<int32> *words,
|
||||||
|
std::vector<int32> *begin_times,
|
||||||
|
std::vector<int32> *lengths);
|
||||||
|
|
||||||
|
/// A form of the shortest-path/best-path algorithm that's specially coded for
|
||||||
|
/// CompactLattice. Requires that clat be acyclic.
|
||||||
|
void CompactLatticeShortestPath(const CompactLattice &clat,
|
||||||
|
CompactLattice *shortest_path);
|
||||||
|
|
||||||
|
/// This function expands a CompactLattice to ensure high-probability paths
|
||||||
|
/// have unique histories. Arcs with posteriors larger than epsilon get split.
|
||||||
|
void ExpandCompactLattice(const CompactLattice &clat,
|
||||||
|
double epsilon,
|
||||||
|
CompactLattice *expand_clat);
|
||||||
|
|
||||||
|
/// For each state, compute forward and backward best (viterbi) costs and its
|
||||||
|
/// traceback states (for generating best paths later). The forward best cost
|
||||||
|
/// for a state is the cost of the best path from the start state to the state.
|
||||||
|
/// The traceback state of this state is its predecessor state in the best path.
|
||||||
|
/// The backward best cost for a state is the cost of the best path from the
|
||||||
|
/// state to a final one. Its traceback state is the successor state in the best
|
||||||
|
/// path in the forward direction.
|
||||||
|
/// Note: final weights of states are in backward_best_cost_and_pred.
|
||||||
|
/// Requires the input CompactLattice clat be acyclic.
|
||||||
|
typedef std::vector<std::pair<double,
|
||||||
|
CompactLatticeArc::StateId> > CostTraceType;
|
||||||
|
void CompactLatticeBestCostsAndTracebacks(
|
||||||
|
const CompactLattice &clat,
|
||||||
|
CostTraceType *forward_best_cost_and_pred,
|
||||||
|
CostTraceType *backward_best_cost_and_pred);
|
||||||
|
|
||||||
|
/// This function adds estimated neural language model scores of words in a
|
||||||
|
/// minimal list of hypotheses that covers a lattice, to the graph scores on the
|
||||||
|
/// arcs. The list of hypotheses are generated by latbin/lattice-path-cover.
|
||||||
|
typedef unordered_map<std::pair<int32, int32>, double, PairHasher<int32> > MapT;
|
||||||
|
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
|
||||||
|
CompactLattice *clat);
|
||||||
|
|
||||||
|
/// This function add the word insertion penalty to graph score of each word
|
||||||
|
/// in the compact lattice
|
||||||
|
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
|
||||||
|
CompactLattice *clat);
|
||||||
|
|
||||||
|
/// This function *adds* the negated scores obtained from the Decodable object,
|
||||||
|
/// to the acoustic scores on the arcs. If you want to replace them, you should
|
||||||
|
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
|
||||||
|
/// true on success, false on error (typically some kind of mismatched inputs).
|
||||||
|
bool RescoreCompactLattice(DecodableInterface *decodable,
|
||||||
|
CompactLattice *clat);
|
||||||
|
|
||||||
|
|
||||||
|
/// This function returns the number of words in the longest sentence in a
|
||||||
|
/// Lattice (i.e. the maximum of any path, of the count of
|
||||||
|
/// olabels on that path).
|
||||||
|
int32 LongestSentenceLength(const Lattice &lat);
|
||||||
|
|
||||||
|
/// This function returns the number of words in the longest sentence in a
|
||||||
|
/// CompactLattice, i.e. the maximum of any path, of the count of
|
||||||
|
/// labels on that path... note, in CompactLattice, the ilabels and olabels
|
||||||
|
/// are identical because it is an acceptor.
|
||||||
|
int32 LongestSentenceLength(const CompactLattice &lat);
|
||||||
|
|
||||||
|
|
||||||
|
/// This function is like RescoreCompactLattice, but it is modified to avoid
|
||||||
|
/// computing probabilities on most frames where all the pdf-ids are the same.
|
||||||
|
/// (it needs the transition-model to work out whether two transition-ids map to
|
||||||
|
/// the same pdf-id, and it assumes that the lattice has transition-ids on it).
|
||||||
|
/// The naive thing would be to just set all probabilities to zero on frames
|
||||||
|
/// where all the pdf-ids are the same (because this value won't affect the
|
||||||
|
/// lattice posterior). But this would become confusing when we compute
|
||||||
|
/// corpus-level diagnostics such as the MMI objective function. Instead,
|
||||||
|
/// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 /
|
||||||
|
/// speedup_factor) we compute those likelihoods and multiply them by
|
||||||
|
/// speedup_factor; otherwise we set them to zero. This gives the right
|
||||||
|
/// expected probability so our corpus-level diagnostics will be about right.
|
||||||
|
bool RescoreCompactLatticeSpeedup(
|
||||||
|
const TransitionInformation &tmodel,
|
||||||
|
BaseFloat speedup_factor,
|
||||||
|
DecodableInterface *decodable,
|
||||||
|
CompactLattice *clat);
|
||||||
|
|
||||||
|
|
||||||
|
/// This function *adds* the negated scores obtained from the Decodable object,
|
||||||
|
/// to the acoustic scores on the arcs. If you want to replace them, you should
|
||||||
|
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
|
||||||
|
/// true on success, false on error (e.g. some kind of mismatched inputs).
|
||||||
|
/// The input labels, if nonzero, are interpreted as transition-ids or whatever
|
||||||
|
/// other index the Decodable object expects.
|
||||||
|
bool RescoreLattice(DecodableInterface *decodable,
|
||||||
|
Lattice *lat);
|
||||||
|
|
||||||
|
/// This function Composes a CompactLattice format lattice with a
|
||||||
|
/// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another
|
||||||
|
/// CompactLattice format lattice. The first element (the one that corresponds
|
||||||
|
/// to LM weight) in CompactLatticeWeight is used for composition.
|
||||||
|
///
|
||||||
|
/// Note that the DeterministicOnDemandFst interface is not "const", therefore
|
||||||
|
/// we cannot use "const" for <det_fst>.
|
||||||
|
void ComposeCompactLatticeDeterministic(
|
||||||
|
const CompactLattice& clat,
|
||||||
|
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
|
||||||
|
CompactLattice* composed_clat);
|
||||||
|
|
||||||
|
/// This function computes the mapping from the pair
|
||||||
|
/// (frame-index, transition-id) to the pair
|
||||||
|
/// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the
|
||||||
|
/// transition-id in that frame.
|
||||||
|
/// frame-index in the lattice.
|
||||||
|
/// This function is useful for retaining the acoustic scores in a
|
||||||
|
/// non-compact lattice after a process like determinization where the
|
||||||
|
/// frame-level acoustic scores are typically lost.
|
||||||
|
/// The function ReplaceAcousticScoresFromMap is used to restore the
|
||||||
|
/// acoustic scores computed by this function.
|
||||||
|
///
|
||||||
|
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
|
||||||
|
/// function will crash.
|
||||||
|
/// @param [out] acoustic_scores
|
||||||
|
/// Pointer to a map from the pair (frame-index,
|
||||||
|
/// transition-id) to a pair (sum-of-acoustic-scores,
|
||||||
|
/// num-of-occurences).
|
||||||
|
/// Usually the acoustic scores for a pdf-id (and hence
|
||||||
|
/// transition-id) on a frame will be the same for all the
|
||||||
|
/// occurences of the pdf-id in that frame.
|
||||||
|
/// But if not, we will take the average of the acoustic
|
||||||
|
/// scores. Hence, we store both the sum-of-acoustic-scores
|
||||||
|
/// and the num-of-occurences of the transition-id in that
|
||||||
|
/// frame.
|
||||||
|
void ComputeAcousticScoresMap(
|
||||||
|
const Lattice &lat,
|
||||||
|
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
|
||||||
|
PairHasher<int32> > *acoustic_scores);
|
||||||
|
|
||||||
|
/// This function restores acoustic scores computed using the function
|
||||||
|
/// ComputeAcousticScoresMap into the lattice.
|
||||||
|
///
|
||||||
|
/// @param [in] acoustic_scores
|
||||||
|
/// A map from the pair (frame-index, transition-id) to a
|
||||||
|
/// pair (sum-of-acoustic-scores, num-of-occurences) of
|
||||||
|
/// the occurences of the transition-id in that frame.
|
||||||
|
/// See the comments for ComputeAcousticScoresMap for
|
||||||
|
/// details.
|
||||||
|
/// @param [out] lat Pointer to the output lattice.
|
||||||
|
void ReplaceAcousticScoresFromMap(
|
||||||
|
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
|
||||||
|
PairHasher<int32> > &acoustic_scores,
|
||||||
|
Lattice *lat);
|
||||||
|
|
||||||
|
} // namespace kaldi
|
||||||
|
|
||||||
|
#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_
|
@ -0,0 +1,2 @@
|
|||||||
|
# Collect every source file in this directory into DIR_LIB_SRCS and build
# them into the static "nnet" library.
aux_source_directory(. DIR_LIB_SRCS)
add_library(nnet STATIC ${DIR_LIB_SRCS})
|
@ -0,0 +1,124 @@
|
|||||||
|
// itf/decodable-itf.h
|
||||||
|
|
||||||
|
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
|
||||||
|
// Mirko Hannemann; Go Vivace Inc.;
|
||||||
|
// 2013 Johns Hopkins University (author: Daniel Povey)
|
||||||
|
|
||||||
|
// See ../../COPYING for clarification regarding multiple authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||||
|
// See the Apache 2 License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef KALDI_ITF_DECODABLE_ITF_H_
|
||||||
|
#define KALDI_ITF_DECODABLE_ITF_H_ 1
|
||||||
|
#include "base/kaldi-common.h"
|
||||||
|
|
||||||
|
namespace kaldi {
|
||||||
|
/// @ingroup Interfaces
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
DecodableInterface provides a link between the (acoustic-modeling and
|
||||||
|
feature-processing) code and the decoder. The idea is to make this
|
||||||
|
interface as small as possible, and to make it as agnostic as possible about
|
||||||
|
the form of the acoustic model (e.g. don't assume the probabilities are a
|
||||||
|
function of just a vector of floats), and about the decoder (e.g. don't
|
||||||
|
assume it accesses frames in strict left-to-right order). For normal
|
||||||
|
models, without on-line operation, the "decodable" sub-class will just be a
|
||||||
|
wrapper around a matrix of features and an acoustic model, and it will
|
||||||
|
answer the question 'what is the acoustic likelihood for this index and this
|
||||||
|
frame?'.
|
||||||
|
|
||||||
|
For online decoding, where the features are coming in in real time, it is
|
||||||
|
important to understand the IsLastFrame() and NumFramesReady() functions.
|
||||||
|
There are two ways these are used: the old online-decoding code, in ../online/,
|
||||||
|
and the new online-decoding code, in ../online2/. In the old online-decoding
|
||||||
|
code, the decoder would do:
|
||||||
|
\code{.cc}
|
||||||
|
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
|
||||||
|
// Process this frame
|
||||||
|
}
|
||||||
|
\endcode
|
||||||
|
and the call to IsLastFrame would block if the features had not arrived yet.
|
||||||
|
The decodable object would have to know when to terminate the decoding. This
|
||||||
|
online-decoding mode is still supported, it is what happens when you call, for
|
||||||
|
example, LatticeFasterDecoder::Decode().
|
||||||
|
|
||||||
|
We realized that this "blocking" mode of decoding is not very convenient
|
||||||
|
because it forces the program to be multi-threaded and makes it complex to
|
||||||
|
control endpointing. In the "new" decoding code, you don't call (for example)
|
||||||
|
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
|
||||||
|
and then each time you get more features, you provide them to the decodable
|
||||||
|
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
|
||||||
|
something like this:
|
||||||
|
\code{.cc}
|
||||||
|
while (num_frames_decoded_ < decodable.NumFramesReady()) {
|
||||||
|
// Decode one more frame [increments num_frames_decoded_]
|
||||||
|
}
|
||||||
|
\endcode
|
||||||
|
So the decodable object never has IsLastFrame() called. For decoding where
|
||||||
|
you are starting with a matrix of features, the NumFramesReady() function will
|
||||||
|
always just return the number of frames in the file, and IsLastFrame() will
|
||||||
|
return true for the last frame.
|
||||||
|
|
||||||
|
For truly online decoding, the "old" online decodable objects in ../online/
|
||||||
|
have a "blocking" IsLastFrame() and will crash if you call NumFramesReady().
|
||||||
|
The "new" online decodable objects in ../online2/ return the number of frames
|
||||||
|
currently accessible if you call NumFramesReady(). You will likely not need
|
||||||
|
to call IsLastFrame(), but we implement it to only return true for the last
|
||||||
|
frame of the file once we've decided to terminate decoding.
|
||||||
|
*/
|
||||||
|
class DecodableInterface {
 public:
  /// Returns the log likelihood, which will be negated in the decoder.
  /// The "frame" starts from zero.  You should verify that NumFramesReady() > frame
  /// before calling this.
  virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;

  /// Returns true if this is the last frame.  Frames are zero-based, so the
  /// first frame is zero.  IsLastFrame(-1) will return false, unless the file
  /// is empty (which is a case that I'm not sure all the code will handle, so
  /// be careful).  Caution: the behavior of this function in an online setting
  /// is being changed somewhat.  In future it may return false in cases where
  /// we haven't yet decided to terminate decoding, but later true if we decide
  /// to terminate decoding.  The plan in future is to rely more on
  /// NumFramesReady(), and in future, IsLastFrame() would always return false
  /// in an online-decoding setting, and would only return true in a
  /// decoding-from-matrix setting where we want to allow the last delta or LDA
  /// features to be flushed out for compatibility with the baseline setup.
  virtual bool IsLastFrame(int32 frame) const = 0;

  /// The call NumFramesReady() will return the number of frames currently available
  /// for this decodable object.  This is for use in setups where you don't want the
  /// decoder to block while waiting for input.  This is newly added as of Jan 2014,
  /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
  /// know when to stop decoding.
  /// Base implementation errors out; online-capable subclasses must override.
  virtual int32 NumFramesReady() const {
    KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
    return -1;
  }

  /// Returns the number of states in the acoustic model
  /// (they will be indexed one-based, i.e. from 1 to NumIndices();
  /// this is for compatibility with OpenFst).
  virtual int32 NumIndices() const = 0;

  /// Returns the per-index log-likelihoods for one frame.
  /// NOTE(review): this is declared virtual but no definition is visible in
  /// this file and it is not pure — if no out-of-line definition exists, the
  /// class's vtable will fail to link; confirm a definition exists or make
  /// this pure virtual.
  virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);

  virtual ~DecodableInterface() {}
};
|
||||||
|
/// @}
|
||||||
|
} // namespace kaldi
|
||||||
|
|
||||||
|
#endif // KALDI_ITF_DECODABLE_ITF_H_
|
@ -0,0 +1,46 @@
|
|||||||
|
#include "nnet/decodable.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
using kaldi::BaseFloat;
|
||||||
|
using kaldi::Matrix;
|
||||||
|
|
||||||
|
// Constructs a Decodable that scores frames with the given acoustic model.
// The frontend is deliberately left unset (NULL) for now; features are
// pushed in explicitly through FeedFeatures() instead of being pulled from
// a frontend — TODO confirm this is wired up before frontend_ is used.
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet):
    frontend_(NULL),
    nnet_(nnet),
    finished_(false),
    frames_ready_(0) {
}
|
||||||
|
|
||||||
|
// Advances the ready-frame counter by the number of rows (frames) in
// `likelihood`.
// NOTE(review): the likelihood values themselves are discarded here — only
// the row count is used. The header marks this method "remove later";
// confirm no caller expects the scores to be retained.
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
    frames_ready_ += likelihood.NumRows();
}
|
||||||
|
|
||||||
|
//Decodable::Init(DecodableConfig config) {
|
||||||
|
//}
|
||||||
|
|
||||||
|
// Reports whether `frame` (zero-based) is the final frame of the utterance.
// A frame only counts as last once the input stream has been explicitly
// marked finished via InputFinished().
bool Decodable::IsLastFrame(int32 frame) const {
    // Asking about a frame beyond what has been received is a caller bug.
    CHECK_LE(frame, frames_ready_);
    if (!finished_) {
        return false;
    }
    return frame == frames_ready_ - 1;
}
|
||||||
|
|
||||||
|
// Number of scoring indices exposed by the acoustic model.
// Currently a stub returning 0; presumably the decoder paths exercised so
// far do not consult this — TODO confirm and return the real count.
int32 Decodable::NumIndices() const {
    return 0;
}
|
||||||
|
|
||||||
|
// Log-likelihood of `index` at `frame`.
// Currently a placeholder that returns 0 and ignores both arguments; the
// cached nnet output (nnet_cache_) is not consulted yet — TODO implement
// before this object is used by a real decoder.
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
    return 0;
}
|
||||||
|
|
||||||
|
// Runs the acoustic model over one block of input features, caching the
// per-frame outputs and advancing the ready-frame counter by the number of
// frames the model produced. Test-only entry point (see header).
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
    nnet_->FeedForward(features, &nnet_cache_);
    frames_ready_ += nnet_cache_.NumRows();
}
|
||||||
|
|
||||||
|
// Clears streaming state between utterances by resetting the acoustic
// model. The frontend reset is still commented out because frontend_ is
// never initialized in the constructor — TODO enable once wired up.
void Decodable::Reset() {
    // frontend_.Reset();
    nnet_->Reset();
}
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,31 @@
|
|||||||
|
#include "nnet/decodable-itf.h"
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "kaldi/matrix/kaldi-matrix.h"
|
||||||
|
#include "frontend/feature_extractor_interface.h"
|
||||||
|
#include "nnet/nnet_interface.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
struct DecodableOpts;
|
||||||
|
|
||||||
|
class Decodable : public kaldi::DecodableInterface {
|
||||||
|
public:
|
||||||
|
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
|
||||||
|
//void Init(DecodableOpts config);
|
||||||
|
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
|
||||||
|
virtual bool IsLastFrame(int32 frame) const;
|
||||||
|
virtual int32 NumIndices() const;
|
||||||
|
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
|
||||||
|
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
|
||||||
|
std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
|
||||||
|
void Reset();
|
||||||
|
void InputFinished() { finished_ = true; }
|
||||||
|
private:
|
||||||
|
std::shared_ptr<FeatureExtractorInterface> frontend_;
|
||||||
|
std::shared_ptr<NnetInterface> nnet_;
|
||||||
|
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
|
||||||
|
bool finished_;
|
||||||
|
int32 frames_ready_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,19 @@
|
|||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/basic_types.h"
|
||||||
|
#include "kaldi/base/kaldi-types.h"
|
||||||
|
#include "kaldi/matrix/kaldi-matrix.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class NnetInterface {
|
||||||
|
public:
|
||||||
|
virtual ~NnetInterface() {}
|
||||||
|
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
|
||||||
|
kaldi::Matrix<kaldi::BaseFloat>* inferences);
|
||||||
|
virtual void Reset();
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,182 @@
|
|||||||
|
#include "nnet/paddle_nnet.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
using std::vector;
|
||||||
|
using std::string;
|
||||||
|
using std::shared_ptr;
|
||||||
|
using kaldi::Matrix;
|
||||||
|
|
||||||
|
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
|
||||||
|
std::vector<std::string> cache_names;
|
||||||
|
cache_names = absl::StrSplit(opts.cache_names, ", ");
|
||||||
|
std::vector<std::string> cache_shapes;
|
||||||
|
cache_shapes = absl::StrSplit(opts.cache_shape, ", ");
|
||||||
|
assert(cache_shapes.size() == cache_names.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < cache_shapes.size(); i++) {
|
||||||
|
std::vector<std::string> tmp_shape;
|
||||||
|
tmp_shape = absl::StrSplit(cache_shapes[i], "- ");
|
||||||
|
std::vector<int> cur_shape;
|
||||||
|
std::transform(tmp_shape.begin(), tmp_shape.end(),
|
||||||
|
std::back_inserter(cur_shape),
|
||||||
|
[](const std::string& s) {
|
||||||
|
return atoi(s.c_str());
|
||||||
|
});
|
||||||
|
cache_names_idx_[cache_names[i]] = i;
|
||||||
|
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape);
|
||||||
|
cache_encouts_.push_back(cache_eout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the paddle-inference predictor pool from `opts`, verifies that the
// configured input/output tensor names match the loaded model exactly
// (position by position), and allocates the streaming cache tensors.
PaddleNnet::PaddleNnet(const ModelOptions& opts) {
    paddle_infer::Config config;
    config.SetModel(opts.model_path, opts.params_path);
    if (opts.use_gpu) {
        config.EnableUseGpu(500, 0);
    }
    config.SwitchIrOptim(opts.switch_ir_optim);
    // BUG FIX: the condition was inverted — it disabled FC padding exactly
    // when the option asked for it to be enabled. Padding must be disabled
    // only when enable_fc_padding is false.
    if (!opts.enable_fc_padding) {
        config.DisableFCPadding();
    }
    if (opts.enable_profile) {
        config.EnableProfile();
    }
    pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num));
    if (pool == nullptr) {
        LOG(ERROR) << "create the predictor pool failed";
    }
    // One usage flag per predictor slot; all slots start free.
    pool_usages.resize(opts.thread_num);
    std::fill(pool_usages.begin(), pool_usages.end(), false);
    LOG(INFO) << "load paddle model success";

    LOG(INFO) << "start to check the predictor input and output names";
    LOG(INFO) << "input names: " << opts.input_names;
    LOG(INFO) << "output names: " << opts.output_names;
    vector<string> input_names_vec = absl::StrSplit(opts.input_names, ", ");
    vector<string> output_names_vec = absl::StrSplit(opts.output_names, ", ");
    paddle_infer::Predictor* predictor = GetPredictor();

    std::vector<std::string> model_input_names = predictor->GetInputNames();
    assert(input_names_vec.size() == model_input_names.size());
    for (size_t i = 0; i < model_input_names.size(); i++) {
        assert(input_names_vec[i] == model_input_names[i]);
    }

    std::vector<std::string> model_output_names = predictor->GetOutputNames();
    assert(output_names_vec.size() == model_output_names.size());
    for (size_t i = 0; i < output_names_vec.size(); i++) {
        assert(output_names_vec[i] == model_output_names[i]);
    }
    ReleasePredictor(predictor);

    InitCacheEncouts(opts);
}
|
||||||
|
|
||||||
|
// Claims a free predictor slot from the pool (thread-safe) and returns its
// predictor, or nullptr if every slot is currently in use. The claimed slot
// is recorded in predictor_to_thread_id so ReleasePredictor() can free it.
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
    LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
    paddle_infer::Predictor* predictor = nullptr;
    std::lock_guard<std::mutex> guard(pool_mutex);

    // Scan for the first slot whose usage flag is clear.
    int slot = 0;
    for (; slot < pool_usages.size(); ++slot) {
        if (!pool_usages[slot]) {
            predictor = pool->Retrive(slot);
            break;
        }
    }

    if (predictor) {
        // Mark the slot busy and remember which slot this predictor owns.
        pool_usages[slot] = true;
        predictor_to_thread_id[predictor] = slot;
        LOG(INFO) << slot << " predictor create success";
    } else {
        LOG(INFO) << "Failed to get predictor from pool !!!";
    }

    return predictor;
}
|
||||||
|
|
||||||
|
// Returns a predictor obtained from GetPredictor() to the pool so its slot
// can be reused (thread-safe). Unknown predictors are ignored.
// Always returns 0.
int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
    // Fix: log message typo "releae" -> "release".
    LOG(INFO) << "attempt to release a predictor";
    std::lock_guard<std::mutex> guard(pool_mutex);
    auto iter = predictor_to_thread_id.find(predictor);

    if (iter == predictor_to_thread_id.end()) {
        LOG(INFO) << "there is no such predictor";
        return 0;
    }

    LOG(INFO) << iter->second << " predictor will be release";
    // Free the slot and drop the bookkeeping entry.
    pool_usages[iter->second] = false;
    predictor_to_thread_id.erase(predictor);
    LOG(INFO) << "release success";
    return 0;
}
|
||||||
|
|
||||||
|
// Looks up the streaming cache tensor registered under `name` during
// InitCacheEncouts(); returns nullptr when no such cache exists.
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
    auto pos = cache_names_idx_.find(name);
    if (pos == cache_names_idx_.end()) {
        return nullptr;
    }
    assert(pos->second < cache_encouts_.size());
    return cache_encouts_[pos->second];
}
|
||||||
|
|
||||||
|
// Runs one forward pass of the streaming acoustic model.
// Feeds the feature block, the frame count, and the cached h/c recurrent
// states into the predictor, runs it, copies the updated h/c states back
// into the local caches, and writes the per-frame outputs into *inferences
// (resized to the model's output shape).
// Expects the model to have inputs [audio, audio_len, h, c] and outputs
// [probs, _, h_out, c_out] in that order — assumed from the fixed indices
// used below; TODO confirm against the exported model.
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {

    paddle_infer::Predictor* predictor = GetPredictor();
    // 1. collect the model's input/output tensor names
    int row = features.NumRows();
    int col = features.NumCols();
    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();
    LOG(INFO) << "feat info: row=" << row << ", col=" << col;

    // 2. feed the feature block as a (1, frames, dims) tensor
    std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, row, col};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(features.Data());
    // 3. feed the number of frames in this audio chunk
    std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(row);
    input_len->CopyFromCpu(audio_len.data());
    // 4. feed the streaming cache state (recurrent h/c tensors)
    // NOTE(review): h_cache/c_cache are not null-checked, but GetCacheEncoder
    // returns nullptr for unknown names — confirm input_names[2]/[3] always
    // appear in opts.cache_names.
    std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
    shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
    h_box->Reshape(h_cache->get_shape());
    h_box->CopyFromCpu(h_cache->get_data().data());
    std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
    // NOTE(review): declared Tensor<float> while h_cache uses
    // Tensor<BaseFloat>; presumably BaseFloat == float here — confirm.
    shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
    c_box->Reshape(c_cache->get_shape());
    c_box->CopyFromCpu(c_cache->get_data().data());
    bool success = predictor->Run();

    if (success == false) {
        LOG(INFO) << "predictor run occurs error";
    }

    LOG(INFO) << "get the model success";
    // copy the updated recurrent states back into the local caches for the
    // next chunk
    std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
    assert(h_cache->get_shape() == h_out->shape());
    h_out->CopyToCpu(h_cache->get_data().data());
    std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
    assert(c_cache->get_shape() == c_out->shape());
    c_out->CopyToCpu(c_cache->get_data().data());
    // 5. fetch the final per-frame outputs and resize *inferences to match
    std::unique_ptr<paddle_infer::Tensor> output_tensor =
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    row = output_shape[1];
    col = output_shape[2];
    inferences->Resize(row, col);
    output_tensor->CopyToCpu(inferences->Data());
    ReleasePredictor(predictor);
}
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,108 @@
|
|||||||
|
|
||||||
|
#pragma once

#include <functional>
#include <numeric>

#include "paddle_inference_api.h"

#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"

#include "base/common.h"
#include "nnet/nnet_interface.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
struct ModelOptions {
|
||||||
|
std::string model_path;
|
||||||
|
std::string params_path;
|
||||||
|
int thread_num;
|
||||||
|
bool use_gpu;
|
||||||
|
bool switch_ir_optim;
|
||||||
|
std::string input_names;
|
||||||
|
std::string output_names;
|
||||||
|
std::string cache_names;
|
||||||
|
std::string cache_shape;
|
||||||
|
bool enable_fc_padding;
|
||||||
|
bool enable_profile;
|
||||||
|
ModelOptions() :
|
||||||
|
model_path("model/final.zip"),
|
||||||
|
params_path("model/avg_1.jit.pdmodel"),
|
||||||
|
thread_num(2),
|
||||||
|
use_gpu(false),
|
||||||
|
input_names("audio"),
|
||||||
|
output_names("probs"),
|
||||||
|
cache_names("enouts"),
|
||||||
|
cache_shape("1-1-1"),
|
||||||
|
switch_ir_optim(false),
|
||||||
|
enable_fc_padding(false),
|
||||||
|
enable_profile(false) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void Register(kaldi::OptionsItf* opts) {
|
||||||
|
opts->Register("model-path", &model_path, "model file path");
|
||||||
|
opts->Register("model-params", ¶ms_path, "params model file path");
|
||||||
|
opts->Register("thread-num", &thread_num, "thread num");
|
||||||
|
opts->Register("use-gpu", &use_gpu, "if use gpu");
|
||||||
|
opts->Register("input-names", &input_names, "paddle input names");
|
||||||
|
opts->Register("output-names", &output_names, "paddle output names");
|
||||||
|
opts->Register("cache-names", &cache_names, "cache names");
|
||||||
|
opts->Register("cache-shape", &cache_shape, "cache shape");
|
||||||
|
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option");
|
||||||
|
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option");
|
||||||
|
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
class Tensor {
|
||||||
|
public:
|
||||||
|
Tensor() {
|
||||||
|
}
|
||||||
|
Tensor(const std::vector<int>& shape) :
|
||||||
|
_shape(shape) {
|
||||||
|
int data_size = std::accumulate(_shape.begin(), _shape.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
LOG(INFO) << "data size: " << data_size;
|
||||||
|
_data.resize(data_size, 0);
|
||||||
|
}
|
||||||
|
void reshape(const std::vector<int>& shape) {
|
||||||
|
_shape = shape;
|
||||||
|
int data_size = std::accumulate(_shape.begin(), _shape.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
_data.resize(data_size, 0);
|
||||||
|
}
|
||||||
|
const std::vector<int>& get_shape() const {
|
||||||
|
return _shape;
|
||||||
|
}
|
||||||
|
std::vector<T>& get_data() {
|
||||||
|
return _data;
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
std::vector<int> _shape;
|
||||||
|
std::vector<T> _data;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Nnet implementation backed by Paddle-Inference. Owns a pool of predictors
// (sized by ModelOptions::thread_num) and a set of named cache tensors that
// carry encoder state between successive FeedForward calls.
class PaddleNnet : public NnetInterface {
  public:
    // Builds the predictor pool and cache tensors from `opts`.
    PaddleNnet(const ModelOptions& opts);

    // Runs one forward pass. `features` is a (num_frames x feat_dim) matrix;
    // the model output is written into `inferences` (resized by the callee).
    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
                             kaldi::Matrix<kaldi::BaseFloat>* inferences);

    // Returns the cache tensor registered under `name`, looked up via
    // cache_names_idx_; used to feed/read the recurrent state.
    std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);

    // Allocates cache_encouts_ from opts.cache_names / opts.cache_shape.
    void InitCacheEncouts(const ModelOptions& opts);

  private:
    // Borrows an idle predictor from the pool; returns it with
    // ReleasePredictor when done.
    paddle_infer::Predictor* GetPredictor();
    // Returns a borrowed predictor to the pool.
    int ReleasePredictor(paddle_infer::Predictor* predictor);

    std::unique_ptr<paddle_infer::services::PredictorPool> pool;
    std::vector<bool> pool_usages;  // true == predictor at that index is borrowed
    std::mutex pool_mutex;          // guards pool_usages / predictor_to_thread_id
    std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
    std::map<std::string, int> cache_names_idx_;  // cache name -> cache_encouts_ index
    std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;

  public:
    DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
# utils: small static library of generic helpers (file reading, etc.)
# linked by the other speechx targets.
add_library(utils
  file_utils.cc
)
|
@ -0,0 +1,17 @@
|
|||||||
|
#include "utils/file_utils.h"
|
||||||
|
|
||||||
|
namespace ppspeech {

// Reads `filename` line by line, appending each line to `vocabulary`.
// Returns false (after printing a diagnostic to stderr) if the file cannot
// be opened; existing contents of `vocabulary` are preserved.
//
// NOTE: the definition must live in namespace ppspeech to match the
// declaration in utils/file_utils.h; defining it at global scope (as the
// original did) leaves the declared symbol undefined at link time.
bool ReadFileToVector(const std::string& filename,
                      std::vector<std::string>* vocabulary) {
    std::ifstream file_in(filename);
    if (!file_in) {
        std::cerr << "please input a valid file" << std::endl;
        return false;
    }

    std::string line;
    while (std::getline(file_in, line)) {
        vocabulary->emplace_back(line);
    }

    return true;
}

}  // namespace ppspeech
|
@ -0,0 +1,8 @@
|
|||||||
|
#include "base/common.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
bool ReadFileToVector(const std::string& filename,
|
||||||
|
std::vector<std::string>* data);
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in new issue