Merge pull request #1400 from SmileGoat/feature_dev
[speechx]add linear spectrogram feature extractorpull/1495/head
commit
b584b9690f
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <deque>
|
||||
#include <iostream>
|
||||
#include <istream>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <mutex>
|
||||
|
||||
#include "base/log.h"
|
||||
#include "base/flags.h"
|
||||
#include "base/basic_types.h"
|
||||
#include "base/macros.h"
|
@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gflags/gflags.h"
|
@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "glog/logging.h"
|
@ -0,0 +1,120 @@
|
||||
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
|
||||
|
||||
// This software is provided 'as-is', without any express or implied
|
||||
// warranty. In no event will the authors be held liable for any damages
|
||||
// arising from the use of this software.
|
||||
|
||||
// Permission is granted to anyone to use this software for any purpose,
|
||||
// including commercial applications, and to alter it and redistribute it
|
||||
// freely, subject to the following restrictions:
|
||||
|
||||
// 1. The origin of this software must not be misrepresented; you must not
|
||||
// claim that you wrote the original software. If you use this software
|
||||
// in a product, an acknowledgment in the product documentation would be
|
||||
// appreciated but is not required.
|
||||
|
||||
// 2. Altered source versions must be plainly marked as such, and must not be
|
||||
// misrepresented as being the original software.
|
||||
|
||||
// 3. This notice may not be removed or altered from any source
|
||||
// distribution.
|
||||
// this code is from https://github.com/progschj/ThreadPool
|
||||
|
||||
#ifndef BASE_THREAD_POOL_H
|
||||
#define BASE_THREAD_POOL_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <future>
|
||||
#include <functional>
|
||||
#include <stdexcept>
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t);
|
||||
template<class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>;
|
||||
~ThreadPool();
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector< std::thread > workers;
|
||||
// the task queue
|
||||
std::queue< std::function<void()> > tasks;
|
||||
|
||||
// synchronization
|
||||
std::mutex queue_mutex;
|
||||
std::condition_variable condition;
|
||||
bool stop;
|
||||
};
|
||||
|
||||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads)
|
||||
: stop(false)
|
||||
{
|
||||
for(size_t i = 0;i<threads;++i)
|
||||
workers.emplace_back(
|
||||
[this]
|
||||
{
|
||||
for(;;)
|
||||
{
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// add new work item to the pool
|
||||
template<class F, class... Args>
|
||||
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
// the destructor joins all threads
|
||||
inline ThreadPool::~ThreadPool()
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
}
|
||||
|
||||
#endif
|
@ -0,0 +1,58 @@
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "base/log.h"
|
||||
#include "base/flags.h"
|
||||
|
||||
DEFINE_string(feature_respecifier, "", "test nnet prob");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
|
||||
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
|
||||
int32 chunk_size,
|
||||
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
|
||||
|
||||
// test nnet_output --> decoder result
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
CTCBeamSearchOptions opts;
|
||||
CTCBeamSearch decoder(opts);
|
||||
|
||||
ModelOptions model_opts;
|
||||
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
|
||||
|
||||
Decodable decodable();
|
||||
decodable.SetNnet(nnet);
|
||||
|
||||
int32 chunk_size = 0;
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
vector<Matrix<BaseFloat>> feature_chunks;
|
||||
SplitFeature(feature, chunk_size, &feature_chunks);
|
||||
for (auto feature_chunk : feature_chunks) {
|
||||
decodable.FeedFeatures(feature_chunk);
|
||||
decoder.InitDecoder();
|
||||
decoder.AdvanceDecode(decodable, chunk_size);
|
||||
}
|
||||
decodable.InputFinished();
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
decodable.Reset();
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "frontend/linear_spectrogram.h"
|
||||
#include "frontend/normalizer.h"
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "base/log.h"
|
||||
#include "base/flags.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test wav path");
|
||||
DEFINE_string(feature_wspecifier, "", "test wav ark");
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
|
||||
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
|
||||
|
||||
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
|
||||
int32 num_done = 0, num_err = 0;
|
||||
ppspeech::LinearSpectrogramOptions opt;
|
||||
ppspeech::DecibelNormalizerOptions db_norm_opt;
|
||||
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
|
||||
new ppspeech::DecibelNormalizer(db_norm_opt));
|
||||
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
|
||||
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
std::string utt = wav_reader.Key();
|
||||
const kaldi::WaveData &wave_data = wav_reader.Value();
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
|
||||
kaldi::Matrix<BaseFloat> features;
|
||||
linear_spectrogram.AcceptWaveform(waveform);
|
||||
linear_spectrogram.ReadFeats(&features);
|
||||
|
||||
feat_writer.Write(utt, features);
|
||||
if (num_done % 50 == 0 && num_done != 0)
|
||||
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
|
||||
num_done++;
|
||||
}
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -1,2 +1,10 @@
|
||||
aux_source_directory(. DIR_LIB_SRCS)
|
||||
add_library(decoder STATIC ${DIR_LIB_SRCS})
|
||||
project(decoder)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
|
||||
add_library(decoder
|
||||
ctc_beam_search_decoder.cc
|
||||
ctc_decoders/decoder_utils.cpp
|
||||
ctc_decoders/path_trie.cpp
|
||||
ctc_decoders/scorer.cpp
|
||||
)
|
||||
target_link_libraries(decoder kenlm)
|
@ -0,0 +1,7 @@
|
||||
#include "base/basic_types.h"
|
||||
|
||||
struct DecoderResult {
|
||||
BaseFloat acoustic_score;
|
||||
std::vector<int32> words_idx;
|
||||
std::vector<pair<int32, int32>> time_stamp;
|
||||
};
|
@ -0,0 +1,300 @@
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
|
||||
#include "base/basic_types.h"
|
||||
#include "decoder/ctc_decoders/decoder_utils.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using std::vector;
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
|
||||
opts_(opts),
|
||||
init_ext_scorer_(nullptr),
|
||||
blank_id(-1),
|
||||
space_id(-1),
|
||||
num_frame_decoded_(0),
|
||||
root(nullptr) {
|
||||
|
||||
LOG(INFO) << "dict path: " << opts_.dict_file;
|
||||
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
|
||||
LOG(INFO) << "load the dict failed";
|
||||
}
|
||||
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
|
||||
|
||||
LOG(INFO) << "language model path: " << opts_.lm_path;
|
||||
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
|
||||
opts_.beta,
|
||||
opts_.lm_path,
|
||||
vocabulary_);
|
||||
}
|
||||
|
||||
void CTCBeamSearch::Reset() {
|
||||
num_frame_decoded_ = 0;
|
||||
ResetPrefixes();
|
||||
}
|
||||
|
||||
void CTCBeamSearch::InitDecoder() {
|
||||
|
||||
blank_id = 0;
|
||||
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
|
||||
|
||||
space_id = it - vocabulary_.begin();
|
||||
// if no space in vocabulary
|
||||
if ((size_t)space_id >= vocabulary_.size()) {
|
||||
space_id = -2;
|
||||
}
|
||||
|
||||
ResetPrefixes();
|
||||
|
||||
root = std::make_shared<PathTrie>();
|
||||
root->score = root->log_prob_b_prev = 0.0;
|
||||
prefixes.push_back(root.get());
|
||||
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
|
||||
auto fst_dict =
|
||||
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
|
||||
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
|
||||
root->set_dictionary(dict_ptr);
|
||||
|
||||
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
||||
root->set_matcher(matcher);
|
||||
}
|
||||
}
|
||||
|
||||
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32 CTCBeamSearch::NumFrameDecoded() {
|
||||
return num_frame_decoded_;
|
||||
}
|
||||
|
||||
// todo rename, refactor
|
||||
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
|
||||
int max_frames) {
|
||||
while (max_frames > 0) {
|
||||
vector<vector<BaseFloat>> likelihood;
|
||||
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
|
||||
break;
|
||||
}
|
||||
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
|
||||
AdvanceDecoding(likelihood);
|
||||
max_frames--;
|
||||
}
|
||||
}
|
||||
|
||||
void CTCBeamSearch::ResetPrefixes() {
|
||||
for (size_t i = 0; i < prefixes.size(); i++) {
|
||||
if (prefixes[i] != nullptr) {
|
||||
delete prefixes[i];
|
||||
prefixes[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
|
||||
vector<string>& nbest_words) {
|
||||
kaldi::Timer timer;
|
||||
timer.Reset();
|
||||
AdvanceDecoding(probs);
|
||||
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
|
||||
return 0;
|
||||
}
|
||||
|
||||
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
|
||||
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
|
||||
}
|
||||
|
||||
string CTCBeamSearch::GetBestPath() {
|
||||
std::vector<std::pair<double, std::string>> result;
|
||||
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
|
||||
return result[0].second;
|
||||
}
|
||||
|
||||
string CTCBeamSearch::GetFinalBestPath() {
|
||||
CalculateApproxScore();
|
||||
LMRescore();
|
||||
return GetBestPath();
|
||||
}
|
||||
|
||||
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
|
||||
size_t num_time_steps = probs.size();
|
||||
size_t beam_size = opts_.beam_size;
|
||||
double cutoff_prob = opts_.cutoff_prob;
|
||||
size_t cutoff_top_n = opts_.cutoff_top_n;
|
||||
|
||||
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
|
||||
|
||||
int row = probs.size();
|
||||
int col = probs[0].size();
|
||||
for(int i = 0; i < row; i++) {
|
||||
for (int j = 0; j < col; j++){
|
||||
probs_seq[i][j] = static_cast<double>(probs[i][j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
|
||||
const auto& prob = probs_seq[time_step];
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
if (init_ext_scorer_ != nullptr) {
|
||||
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
||||
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
|
||||
prefix_compare);
|
||||
|
||||
if (num_prefixes == 0) {
|
||||
continue;
|
||||
}
|
||||
min_cutoff = prefixes[num_prefixes - 1]->score +
|
||||
std::log(prob[blank_id]) -
|
||||
std::max(0.0, init_ext_scorer_->beta);
|
||||
|
||||
full_beam = (num_prefixes == beam_size);
|
||||
}
|
||||
|
||||
vector<std::pair<size_t, float>> log_prob_idx =
|
||||
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
|
||||
|
||||
// loop over chars
|
||||
size_t log_prob_idx_len = log_prob_idx.size();
|
||||
for (size_t index = 0; index < log_prob_idx_len; index++) {
|
||||
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
|
||||
}
|
||||
|
||||
prefixes.clear();
|
||||
|
||||
// update log probs
|
||||
root->iterate_to_vec(prefixes);
|
||||
// only preserve top beam_size prefixes
|
||||
if (prefixes.size() >= beam_size) {
|
||||
std::nth_element(prefixes.begin(),
|
||||
prefixes.begin() + beam_size,
|
||||
prefixes.end(),
|
||||
prefix_compare);
|
||||
for (size_t i = beam_size; i < prefixes.size(); ++i) {
|
||||
prefixes[i]->remove();
|
||||
}
|
||||
} // if
|
||||
num_frame_decoded_++;
|
||||
} // for probs_seq
|
||||
}
|
||||
|
||||
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
|
||||
const std::pair<size_t, BaseFloat>& log_prob_idx,
|
||||
const BaseFloat& min_cutoff) {
|
||||
size_t beam_size = opts_.beam_size;
|
||||
const auto& c = log_prob_idx.first;
|
||||
const auto& log_prob_c = log_prob_idx.second;
|
||||
size_t prefixes_len = std::min(prefixes.size(), beam_size);
|
||||
|
||||
for (size_t i = 0; i < prefixes_len; ++i) {
|
||||
auto prefix = prefixes[i];
|
||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (c == blank_id) {
|
||||
prefix->log_prob_b_cur = log_sum_exp(
|
||||
prefix->log_prob_b_cur,
|
||||
log_prob_c +
|
||||
prefix->score);
|
||||
continue;
|
||||
}
|
||||
|
||||
// repeated character
|
||||
if (c == prefix->character) {
|
||||
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
|
||||
prefix->log_prob_nb_cur = log_sum_exp(
|
||||
prefix->log_prob_nb_cur,
|
||||
log_prob_c +
|
||||
prefix->log_prob_nb_prev);
|
||||
}
|
||||
|
||||
// get new prefix
|
||||
auto prefix_new = prefix->get_path_trie(c);
|
||||
if (prefix_new != nullptr) {
|
||||
float log_p = -NUM_FLT_INF;
|
||||
if (c == prefix->character &&
|
||||
prefix->log_prob_b_prev > -NUM_FLT_INF) {
|
||||
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
|
||||
log_p = log_prob_c + prefix->log_prob_b_prev;
|
||||
} else if (c != prefix->character) {
|
||||
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
|
||||
log_p = log_prob_c + prefix->score;
|
||||
}
|
||||
|
||||
// language model scoring
|
||||
if (init_ext_scorer_ != nullptr &&
|
||||
(c == space_id || init_ext_scorer_->is_character_based())) {
|
||||
PathTrie *prefix_to_score = nullptr;
|
||||
// skip scoring the space
|
||||
if (init_ext_scorer_->is_character_based()) {
|
||||
prefix_to_score = prefix_new;
|
||||
} else {
|
||||
prefix_to_score = prefix;
|
||||
}
|
||||
|
||||
float score = 0.0;
|
||||
vector<string> ngram;
|
||||
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
|
||||
// lm score: p_{lm}(W)^{\alpha} + \beta
|
||||
score = init_ext_scorer_->get_log_cond_prob(ngram) *
|
||||
init_ext_scorer_->alpha;
|
||||
log_p += score;
|
||||
log_p += init_ext_scorer_->beta;
|
||||
}
|
||||
// p_{nb}(l;x_{1:t})
|
||||
prefix_new->log_prob_nb_cur =
|
||||
log_sum_exp(prefix_new->log_prob_nb_cur,
|
||||
log_p);
|
||||
}
|
||||
} // end of loop over prefix
|
||||
return 0;
|
||||
}
|
||||
|
||||
void CTCBeamSearch::CalculateApproxScore() {
|
||||
size_t beam_size = opts_.beam_size;
|
||||
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
||||
std::sort(
|
||||
prefixes.begin(),
|
||||
prefixes.begin() + num_prefixes,
|
||||
prefix_compare);
|
||||
|
||||
// compute aproximate ctc score as the return score, without affecting the
|
||||
// return order of decoding result. To delete when decoder gets stable.
|
||||
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||
double approx_ctc = prefixes[i]->score;
|
||||
if (init_ext_scorer_ != nullptr) {
|
||||
vector<int> output;
|
||||
prefixes[i]->get_path_vec(output);
|
||||
auto prefix_length = output.size();
|
||||
auto words = init_ext_scorer_->split_labels(output);
|
||||
// remove word insert
|
||||
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
|
||||
// remove language model weight:
|
||||
approx_ctc -=
|
||||
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
|
||||
}
|
||||
prefixes[i]->approx_ctc = approx_ctc;
|
||||
}
|
||||
}
|
||||
|
||||
void CTCBeamSearch::LMRescore() {
|
||||
size_t beam_size = opts_.beam_size;
|
||||
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
|
||||
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||
auto prefix = prefixes[i];
|
||||
if (!prefix->is_empty() && prefix->character != space_id) {
|
||||
float score = 0.0;
|
||||
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
|
||||
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
|
||||
score += init_ext_scorer_->beta;
|
||||
prefix->score += score;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,79 @@
|
||||
#include "base/common.h"
|
||||
#include "nnet/decodable-itf.h"
|
||||
#include "util/parse-options.h"
|
||||
#include "decoder/ctc_decoders/scorer.h"
|
||||
#include "decoder/ctc_decoders/path_trie.h"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct CTCBeamSearchOptions {
|
||||
std::string dict_file;
|
||||
std::string lm_path;
|
||||
BaseFloat alpha;
|
||||
BaseFloat beta;
|
||||
BaseFloat cutoff_prob;
|
||||
int beam_size;
|
||||
int cutoff_top_n;
|
||||
int num_proc_bsearch;
|
||||
CTCBeamSearchOptions() :
|
||||
dict_file("./model/words.txt"),
|
||||
lm_path("./model/lm.arpa"),
|
||||
alpha(1.9f),
|
||||
beta(5.0),
|
||||
beam_size(300),
|
||||
cutoff_prob(0.99f),
|
||||
cutoff_top_n(40),
|
||||
num_proc_bsearch(0) {
|
||||
}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
opts->Register("dict", &dict_file, "dict file ");
|
||||
opts->Register("lm-path", &lm_path, "language model file");
|
||||
opts->Register("alpha", &alpha, "alpha");
|
||||
opts->Register("beta", &beta, "beta");
|
||||
opts->Register("beam-size", &beam_size, "beam size for beam search method");
|
||||
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
|
||||
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
|
||||
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
|
||||
}
|
||||
};
|
||||
|
||||
class CTCBeamSearch {
|
||||
public:
|
||||
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
|
||||
~CTCBeamSearch() {}
|
||||
void InitDecoder();
|
||||
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
|
||||
std::string GetBestPath();
|
||||
std::vector<std::pair<double, std::string>> GetNBestPath();
|
||||
std::string GetFinalBestPath();
|
||||
int NumFrameDecoded();
|
||||
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
|
||||
std::vector<std::string>& nbest_words);
|
||||
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
|
||||
int max_frames);
|
||||
void Reset();
|
||||
private:
|
||||
void ResetPrefixes();
|
||||
int32 SearchOneChar(const bool& full_beam,
|
||||
const std::pair<size_t, BaseFloat>& log_prob_idx,
|
||||
const BaseFloat& min_cutoff);
|
||||
void CalculateApproxScore();
|
||||
void LMRescore();
|
||||
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
|
||||
|
||||
CTCBeamSearchOptions opts_;
|
||||
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
|
||||
//std::vector<DecodeResult> decoder_results_;
|
||||
std::vector<std::string> vocabulary_; // todo remove later
|
||||
size_t blank_id;
|
||||
int space_id;
|
||||
std::shared_ptr<PathTrie> root;
|
||||
std::vector<PathTrie*> prefixes;
|
||||
int num_frame_decoded_;
|
||||
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
|
||||
};
|
||||
|
||||
} // namespace basr
|
@ -0,0 +1 @@
|
||||
../../../third_party/ctc_decoders
|
@ -0,0 +1,8 @@
|
||||
project(frontend)
|
||||
|
||||
add_library(frontend
|
||||
normalizer.cc
|
||||
linear_spectrogram.cc
|
||||
)
|
||||
|
||||
target_link_libraries(frontend kaldi-matrix)
|
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// wrap the fbank feat of kaldi, todo (SmileGoat)
|
||||
|
||||
#include "kaldi/feat/feature-mfcc.h"
|
||||
|
||||
#incldue "kaldi/matrix/kaldi-vector.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class FbankExtractor : FeatureExtractorInterface {
|
||||
public:
|
||||
explicit FbankExtractor(const FbankOptions& opts,
|
||||
share_ptr<FeatureExtractorInterface> pre_extractor);
|
||||
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
|
||||
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
|
||||
virtual size_t Dim() const = 0;
|
||||
|
||||
private:
|
||||
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
|
||||
kaldi::Vector<kaldi::BaseFloat>* feat) const;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/basic_types.h"
|
||||
#include "kaldi/matrix/kaldi-vector.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class FeatureExtractorInterface {
|
||||
public:
|
||||
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
|
||||
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
|
||||
virtual size_t Dim() const = 0;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,179 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "frontend/linear_spectrogram.h"
|
||||
#include "kaldi/base/kaldi-math.h"
|
||||
#include "kaldi/matrix/matrix-functions.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::int32;
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Vector;
|
||||
using kaldi::VectorBase;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
//todo remove later
|
||||
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
|
||||
vector<BaseFloat>* output) {
|
||||
if (input.Dim() == 0) return;
|
||||
output->resize(input.Dim());
|
||||
for (size_t idx = 0; idx < input.Dim(); ++idx) {
|
||||
(*output)[idx] = input(idx);
|
||||
}
|
||||
}
|
||||
|
||||
void CopyStdVector2Vector_(const vector<BaseFloat>& input,
|
||||
Vector<BaseFloat>* output) {
|
||||
if (input.empty()) return;
|
||||
output->Resize(input.size());
|
||||
for (size_t idx = 0; idx < input.size(); ++idx) {
|
||||
(*output)(idx) = input[idx];
|
||||
}
|
||||
}
|
||||
|
||||
LinearSpectrogram::LinearSpectrogram(
|
||||
const LinearSpectrogramOptions& opts,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
|
||||
base_extractor_ = std::move(base_extractor);
|
||||
int32 window_size = opts.frame_opts.WindowSize();
|
||||
int32 window_shift = opts.frame_opts.WindowShift();
|
||||
fft_points_ = window_size;
|
||||
hanning_window_.resize(window_size);
|
||||
|
||||
double a = M_2PI / (window_size - 1);
|
||||
hanning_window_energy_ = 0;
|
||||
for (int i = 0; i < window_size; ++i) {
|
||||
hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
|
||||
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
|
||||
}
|
||||
|
||||
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
|
||||
}
|
||||
|
||||
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
|
||||
base_extractor_->AcceptWaveform(input);
|
||||
}
|
||||
|
||||
void LinearSpectrogram::Hanning(vector<float>* data) const {
|
||||
CHECK_GE(data->size(), hanning_window_.size());
|
||||
|
||||
for (size_t i = 0; i < hanning_window_.size(); ++i) {
|
||||
data->at(i) *= hanning_window_[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
|
||||
vector<BaseFloat>* real,
|
||||
vector<BaseFloat>* img) const {
|
||||
Vector<BaseFloat> v_tmp;
|
||||
CopyStdVector2Vector_(*v, &v_tmp);
|
||||
RealFft(&v_tmp, true);
|
||||
CopyVector2StdVector_(v_tmp, v);
|
||||
real->push_back(v->at(0));
|
||||
img->push_back(0);
|
||||
for (int i = 1; i < v->size() / 2; i++) {
|
||||
real->push_back(v->at(2 * i));
|
||||
img->push_back(v->at(2 * i + 1));
|
||||
}
|
||||
real->push_back(v->at(1));
|
||||
img->push_back(0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// todo remove later
|
||||
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
|
||||
Vector<BaseFloat> tmp;
|
||||
waveform_.Resize(base_extractor_->Dim());
|
||||
Compute(tmp, &waveform_);
|
||||
vector<vector<BaseFloat>> result;
|
||||
vector<BaseFloat> feats_vec;
|
||||
CopyVector2StdVector_(waveform_, &feats_vec);
|
||||
Compute(feats_vec, result);
|
||||
feats->Resize(result.size(), result[0].size());
|
||||
for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
|
||||
for (int col_idx = 0; col_idx < result.size(); ++col_idx) {
|
||||
(*feats)(row_idx, col_idx) = result[row_idx][col_idx];
|
||||
}
|
||||
}
|
||||
waveform_.Resize(0);
|
||||
}
|
||||
|
||||
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
|
||||
// todo
|
||||
return;
|
||||
}
|
||||
|
||||
// only for test, remove later
|
||||
// todo: compute the feature frame by frame.
|
||||
void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
|
||||
VectorBase<kaldi::BaseFloat>* feature) {
|
||||
base_extractor_->Read(feature);
|
||||
}
|
||||
|
||||
// Compute spectrogram feat, only for test, remove later
|
||||
// todo: refactor later (SmileGoat)
|
||||
bool LinearSpectrogram::Compute(const vector<float>& wave,
|
||||
vector<vector<float>>& feat) {
|
||||
int num_samples = wave.size();
|
||||
const int& frame_length = opts_.frame_opts.WindowSize();
|
||||
const int& sample_rate = opts_.frame_opts.samp_freq;
|
||||
const int& frame_shift = opts_.frame_opts.WindowShift();
|
||||
const int& fft_points = fft_points_;
|
||||
const float scale = hanning_window_energy_ * frame_shift;
|
||||
|
||||
if (num_samples < frame_length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
|
||||
feat.resize(num_frames);
|
||||
vector<float> fft_real((fft_points_ / 2 + 1), 0);
|
||||
vector<float> fft_img((fft_points_ / 2 + 1), 0);
|
||||
vector<float> v(frame_length, 0);
|
||||
vector<float> power((fft_points / 2 + 1));
|
||||
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
vector<float> data(wave.data() + i * frame_shift,
|
||||
wave.data() + i * frame_shift + frame_length);
|
||||
Hanning(&data);
|
||||
fft_img.clear();
|
||||
fft_real.clear();
|
||||
v.assign(data.begin(), data.end());
|
||||
if (NumpyFft(&v, &fft_real, &fft_img)) {
|
||||
LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data";
|
||||
return false;
|
||||
}
|
||||
|
||||
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
|
||||
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
|
||||
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
|
||||
feat[i][j] = power[j];
|
||||
|
||||
if (j == 0 || j == feat[0].size() - 1) {
|
||||
feat[i][j] /= scale;
|
||||
} else {
|
||||
feat[i][j] *= (2.0 / scale);
|
||||
}
|
||||
|
||||
// log added eps=1e-14
|
||||
feat[i][j] = std::log(feat[i][j] + 1e-14);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,50 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
#include "kaldi/feat/feature-window.h"
|
||||
#include "base/common.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct LinearSpectrogramOptions {
|
||||
kaldi::FrameExtractionOptions frame_opts;
|
||||
LinearSpectrogramOptions():
|
||||
frame_opts() {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
frame_opts.Register(opts);
|
||||
}
|
||||
};
|
||||
|
||||
class LinearSpectrogram : public FeatureExtractorInterface {
|
||||
public:
|
||||
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor);
|
||||
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
|
||||
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
|
||||
virtual size_t Dim() const { return dim_; }
|
||||
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
|
||||
|
||||
private:
|
||||
void Hanning(std::vector<kaldi::BaseFloat>* data) const;
|
||||
bool Compute(const std::vector<kaldi::BaseFloat>& wave,
|
||||
std::vector<std::vector<kaldi::BaseFloat>>& feat);
|
||||
void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
|
||||
kaldi::VectorBase<kaldi::BaseFloat>* feature);
|
||||
bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
|
||||
std::vector<kaldi::BaseFloat>* real,
|
||||
std::vector<kaldi::BaseFloat>* img) const;
|
||||
|
||||
kaldi::int32 fft_points_;
|
||||
size_t dim_;
|
||||
std::vector<kaldi::BaseFloat> hanning_window_;
|
||||
kaldi::BaseFloat hanning_window_energy_;
|
||||
LinearSpectrogramOptions opts_;
|
||||
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
|
||||
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
|
||||
};
|
||||
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// wrap the mfcc feat of kaldi, todo (SmileGoat)
|
||||
#include "kaldi/feat/feature-mfcc.h"
|
@ -0,0 +1,130 @@
|
||||
|
||||
#include "frontend/normalizer.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::Vector;
|
||||
using kaldi::VectorBase;
|
||||
using kaldi::BaseFloat;
|
||||
using std::vector;
|
||||
|
||||
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
|
||||
opts_ = opts;
|
||||
dim_ = 0;
|
||||
}
|
||||
|
||||
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) {
|
||||
dim_ = input.Dim();
|
||||
waveform_.Resize(input.Dim());
|
||||
waveform_.CopyFromVec(input);
|
||||
}
|
||||
|
||||
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
|
||||
if (waveform_.Dim() == 0) return;
|
||||
Compute(waveform_, feat);
|
||||
}
|
||||
|
||||
//todo remove later
|
||||
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
|
||||
vector<BaseFloat>* output) {
|
||||
if (input.Dim() == 0) return;
|
||||
output->resize(input.Dim());
|
||||
for (size_t idx = 0; idx < input.Dim(); ++idx) {
|
||||
(*output)[idx] = input(idx);
|
||||
}
|
||||
}
|
||||
|
||||
void CopyStdVector2Vector(const vector<BaseFloat>& input,
|
||||
VectorBase<BaseFloat>* output) {
|
||||
if (input.empty()) return;
|
||||
assert(input.size() == output->Dim());
|
||||
for (size_t idx = 0; idx < input.size(); ++idx) {
|
||||
(*output)(idx) = input[idx];
|
||||
}
|
||||
}
|
||||
|
||||
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
|
||||
VectorBase<BaseFloat>* feat) const {
|
||||
// calculate db rms
|
||||
BaseFloat rms_db = 0.0;
|
||||
BaseFloat mean_square = 0.0;
|
||||
BaseFloat gain = 0.0;
|
||||
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
|
||||
|
||||
vector<BaseFloat> samples;
|
||||
samples.resize(input.Dim());
|
||||
for (int32 i = 0; i < samples.size(); ++i) {
|
||||
samples[i] = input(i);
|
||||
}
|
||||
|
||||
// square
|
||||
for (auto &d : samples) {
|
||||
if (opts_.convert_int_float) {
|
||||
d = d * wave_float_normlization;
|
||||
}
|
||||
mean_square += d * d;
|
||||
}
|
||||
|
||||
// mean
|
||||
mean_square /= samples.size();
|
||||
rms_db = 10 * std::log10(mean_square);
|
||||
gain = opts_.target_db - rms_db;
|
||||
|
||||
if (gain > opts_.max_gain_db) {
|
||||
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB,"
|
||||
<< "because the the probable gain have exceeds opts_.max_gain_db"
|
||||
<< opts_.max_gain_db << "dB.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note that this is an in-place transformation.
|
||||
for (auto &item : samples) {
|
||||
// python item *= 10.0 ** (gain / 20.0)
|
||||
item *= std::pow(10.0, gain / 20.0);
|
||||
}
|
||||
|
||||
CopyStdVector2Vector(samples, feat);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
PPNormalizer::PPNormalizer(
|
||||
const PPNormalizerOptions& opts,
|
||||
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor) {
|
||||
|
||||
}
|
||||
|
||||
void PPNormalizer::AcceptWavefrom(const Vector<BaseFloat>& input) {
|
||||
|
||||
}
|
||||
|
||||
void PPNormalizer::Read(Vector<BaseFloat>* feat) {
|
||||
|
||||
}
|
||||
|
||||
bool PPNormalizer::Compute(const Vector<BaseFloat>& input,
|
||||
Vector<BaseFloat>>* feat) {
|
||||
if ((input.Dim() % mean_.Dim()) == 0) {
|
||||
LOG(ERROR) << "CMVN dimension is wrong!";
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
int32 size = mean_.Dim();
|
||||
feat->Resize(input.Dim());
|
||||
for (int32 row_idx = 0; row_idx < j; ++row_idx) {
|
||||
int32 base_idx = row_idx * size;
|
||||
for (int32 idx = 0; idx < mean_.Dim(); ++idx) {
|
||||
(*feat)(base_idx + idx) = (input(base_dix + idx) - mean_(idx))* variance_(idx);
|
||||
}
|
||||
}
|
||||
|
||||
} catch(const std::exception& e) {
|
||||
std::cerr << e.what() << '\n';
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}*/
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,72 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
#include "kaldi/util/options-itf.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
|
||||
struct DecibelNormalizerOptions {
|
||||
float target_db;
|
||||
float max_gain_db;
|
||||
bool convert_int_float;
|
||||
DecibelNormalizerOptions() :
|
||||
target_db(-20),
|
||||
max_gain_db(300.0),
|
||||
convert_int_float(false) {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
opts->Register("target-db", &target_db, "target db for db normalization");
|
||||
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization");
|
||||
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float");
|
||||
}
|
||||
};
|
||||
|
||||
class DecibelNormalizer : public FeatureExtractorInterface {
|
||||
public:
|
||||
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
|
||||
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
|
||||
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
|
||||
virtual size_t Dim() const { return 0; }
|
||||
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
|
||||
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
|
||||
private:
|
||||
DecibelNormalizerOptions opts_;
|
||||
size_t dim_;
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
|
||||
kaldi::Vector<kaldi::BaseFloat> waveform_;
|
||||
};
|
||||
|
||||
/*
|
||||
struct NormalizerOptions {
|
||||
std::string mean_std_path;
|
||||
NormalizerOptions() :
|
||||
mean_std_path("") {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
opts->Register("mean-std", &mean_std_path, "mean std file");
|
||||
}
|
||||
};
|
||||
|
||||
// todo refactor later (SmileGoat)
|
||||
class PPNormalizer : public FeatureExtractorInterface {
|
||||
public:
|
||||
explicit PPNormalizer(const NormalizerOptions& opts,
|
||||
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor);
|
||||
~PPNormalizer() {}
|
||||
virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input);
|
||||
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat);
|
||||
virtual size_t Dim() const;
|
||||
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& input,
|
||||
kaldi::Vector<kaldi::BaseFloat>>& feat);
|
||||
|
||||
private:
|
||||
bool _initialized;
|
||||
kaldi::Vector<float> mean_;
|
||||
kaldi::Vector<float> variance_;
|
||||
NormalizerOptions _opts;
|
||||
};
|
||||
*/
|
||||
} // namespace ppspeech
|
@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// extract the window of kaldi feat.
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,549 @@
|
||||
// decoder/lattice-faster-decoder.h
|
||||
|
||||
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
|
||||
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2014 Guoguo Chen
|
||||
// 2018 Zhehuai Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
|
||||
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
|
||||
|
||||
#include "decoder/grammar-fst.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fst/memory.h"
|
||||
#include "fstext/fstext-lib.h"
|
||||
#include "itf/decodable-itf.h"
|
||||
#include "lat/determinize-lattice-pruned.h"
|
||||
#include "lat/kaldi-lattice.h"
|
||||
#include "util/hash-list.h"
|
||||
#include "util/stl-utils.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
struct LatticeFasterDecoderConfig {
|
||||
BaseFloat beam;
|
||||
int32 max_active;
|
||||
int32 min_active;
|
||||
BaseFloat lattice_beam;
|
||||
int32 prune_interval;
|
||||
bool determinize_lattice; // not inspected by this class... used in
|
||||
// command-line program.
|
||||
BaseFloat beam_delta;
|
||||
BaseFloat hash_ratio;
|
||||
// Note: we don't make prune_scale configurable on the command line, it's not
|
||||
// a very important parameter. It affects the algorithm that prunes the
|
||||
// tokens as we go.
|
||||
BaseFloat prune_scale;
|
||||
|
||||
// Number of elements in the block for Token and ForwardLink memory
|
||||
// pool allocation.
|
||||
int32 memory_pool_tokens_block_size;
|
||||
int32 memory_pool_links_block_size;
|
||||
|
||||
// Most of the options inside det_opts are not actually queried by the
|
||||
// LatticeFasterDecoder class itself, but by the code that calls it, for
|
||||
// example in the function DecodeUtteranceLatticeFaster.
|
||||
fst::DeterminizeLatticePhonePrunedOptions det_opts;
|
||||
|
||||
LatticeFasterDecoderConfig()
|
||||
: beam(16.0),
|
||||
max_active(std::numeric_limits<int32>::max()),
|
||||
min_active(200),
|
||||
lattice_beam(10.0),
|
||||
prune_interval(25),
|
||||
determinize_lattice(true),
|
||||
beam_delta(0.5),
|
||||
hash_ratio(2.0),
|
||||
prune_scale(0.1),
|
||||
memory_pool_tokens_block_size(1 << 8),
|
||||
memory_pool_links_block_size(1 << 8) {}
|
||||
void Register(OptionsItf *opts) {
|
||||
det_opts.Register(opts);
|
||||
opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate.");
|
||||
opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; "
|
||||
"more accurate");
|
||||
opts->Register("min-active", &min_active, "Decoder minimum #active states.");
|
||||
opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, "
|
||||
"and deeper lattices");
|
||||
opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
|
||||
"which to prune tokens");
|
||||
opts->Register("determinize-lattice", &determinize_lattice, "If true, "
|
||||
"determinize the lattice (lattice-determinization, keeping only "
|
||||
"best pdf-sequence for each word-sequence).");
|
||||
opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
|
||||
"parameter is obscure and relates to a speedup in the way the "
|
||||
"max-active constraint is applied. Larger is more accurate.");
|
||||
opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
|
||||
"control hash behavior");
|
||||
opts->Register("memory-pool-tokens-block-size", &memory_pool_tokens_block_size,
|
||||
"Memory pool block size suggestion for storing tokens (in elements). "
|
||||
"Smaller uses less memory but increases cache misses.");
|
||||
opts->Register("memory-pool-links-block-size", &memory_pool_links_block_size,
|
||||
"Memory pool block size suggestion for storing links (in elements). "
|
||||
"Smaller uses less memory but increases cache misses.");
|
||||
}
|
||||
void Check() const {
|
||||
KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
|
||||
&& min_active <= max_active
|
||||
&& prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
|
||||
&& prune_scale > 0.0 && prune_scale < 1.0);
|
||||
}
|
||||
};
|
||||
|
||||
namespace decoder {
|
||||
// We will template the decoder on the token type as well as the FST type; this
|
||||
// is a mechanism so that we can use the same underlying decoder code for
|
||||
// versions of the decoder that support quickly getting the best path
|
||||
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
|
||||
// those that do not (LatticeFasterDecoder).
|
||||
|
||||
|
||||
// ForwardLinks are the links from a token to a token on the next frame.
|
||||
// or sometimes on the current frame (for input-epsilon links).
|
||||
template <typename Token>
|
||||
struct ForwardLink {
|
||||
using Label = fst::StdArc::Label;
|
||||
|
||||
Token *next_tok; // the next token [or NULL if represents final-state]
|
||||
Label ilabel; // ilabel on arc
|
||||
Label olabel; // olabel on arc
|
||||
BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.)
|
||||
BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc
|
||||
ForwardLink *next; // next in singly-linked list of forward arcs (arcs
|
||||
// in the state-level lattice) from a token.
|
||||
inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
|
||||
BaseFloat graph_cost, BaseFloat acoustic_cost,
|
||||
ForwardLink *next):
|
||||
next_tok(next_tok), ilabel(ilabel), olabel(olabel),
|
||||
graph_cost(graph_cost), acoustic_cost(acoustic_cost),
|
||||
next(next) { }
|
||||
};
|
||||
|
||||
|
||||
struct StdToken {
|
||||
using ForwardLinkT = ForwardLink<StdToken>;
|
||||
using Token = StdToken;
|
||||
|
||||
// Standard token type for LatticeFasterDecoder. Each active HCLG
|
||||
// (decoding-graph) state on each frame has one token.
|
||||
|
||||
// tot_cost is the total (LM + acoustic) cost from the beginning of the
|
||||
// utterance up to this point. (but see cost_offset_, which is subtracted
|
||||
// to keep it in a good numerical range).
|
||||
BaseFloat tot_cost;
|
||||
|
||||
// exta_cost is >= 0. After calling PruneForwardLinks, this equals the
|
||||
// minimum difference between the cost of the best path that this link is a
|
||||
// part of, and the cost of the absolute best path, under the assumption that
|
||||
// any of the currently active states at the decoding front may eventually
|
||||
// succeed (e.g. if you were to take the currently active states one by one
|
||||
// and compute this difference, and then take the minimum).
|
||||
BaseFloat extra_cost;
|
||||
|
||||
// 'links' is the head of singly-linked list of ForwardLinks, which is what we
|
||||
// use for lattice generation.
|
||||
ForwardLinkT *links;
|
||||
|
||||
//'next' is the next in the singly-linked list of tokens for this frame.
|
||||
Token *next;
|
||||
|
||||
// This function does nothing and should be optimized out; it's needed
|
||||
// so we can share the regular LatticeFasterDecoderTpl code and the code
|
||||
// for LatticeFasterOnlineDecoder that supports fast traceback.
|
||||
inline void SetBackpointer (Token *backpointer) { }
|
||||
|
||||
// This constructor just ignores the 'backpointer' argument. That argument is
|
||||
// needed so that we can use the same decoder code for LatticeFasterDecoderTpl
|
||||
// and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
|
||||
// fast way to obtain the best path).
|
||||
inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
|
||||
Token *next, Token *backpointer):
|
||||
tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
|
||||
};
|
||||
|
||||
struct BackpointerToken {
|
||||
using ForwardLinkT = ForwardLink<BackpointerToken>;
|
||||
using Token = BackpointerToken;
|
||||
|
||||
// BackpointerToken is like Token but also
|
||||
// Standard token type for LatticeFasterDecoder. Each active HCLG
|
||||
// (decoding-graph) state on each frame has one token.
|
||||
|
||||
// tot_cost is the total (LM + acoustic) cost from the beginning of the
|
||||
// utterance up to this point. (but see cost_offset_, which is subtracted
|
||||
// to keep it in a good numerical range).
|
||||
BaseFloat tot_cost;
|
||||
|
||||
// exta_cost is >= 0. After calling PruneForwardLinks, this equals
|
||||
// the minimum difference between the cost of the best path, and the cost of
|
||||
// this is on, and the cost of the absolute best path, under the assumption
|
||||
// that any of the currently active states at the decoding front may
|
||||
// eventually succeed (e.g. if you were to take the currently active states
|
||||
// one by one and compute this difference, and then take the minimum).
|
||||
BaseFloat extra_cost;
|
||||
|
||||
// 'links' is the head of singly-linked list of ForwardLinks, which is what we
|
||||
// use for lattice generation.
|
||||
ForwardLinkT *links;
|
||||
|
||||
//'next' is the next in the singly-linked list of tokens for this frame.
|
||||
BackpointerToken *next;
|
||||
|
||||
// Best preceding BackpointerToken (could be a on this frame, connected to
|
||||
// this via an epsilon transition, or on a previous frame). This is only
|
||||
// required for an efficient GetBestPath function in
|
||||
// LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
|
||||
// (the "links" list is what stores the forward links, for that).
|
||||
Token *backpointer;
|
||||
|
||||
inline void SetBackpointer (Token *backpointer) {
|
||||
this->backpointer = backpointer;
|
||||
}
|
||||
|
||||
inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
|
||||
Token *next, Token *backpointer):
|
||||
tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
|
||||
backpointer(backpointer) { }
|
||||
};
|
||||
|
||||
} // namespace decoder
|
||||
|
||||
|
||||
/** This is the "normal" lattice-generating decoder.
|
||||
See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
|
||||
for more information.
|
||||
|
||||
The decoder is templated on the FST type and the token type. The token type
|
||||
will normally be StdToken, but also may be BackpointerToken which is to support
|
||||
quick lookup of the current best path (see lattice-faster-online-decoder.h)
|
||||
|
||||
The FST you invoke this decoder which is expected to equal
|
||||
Fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst. If you invoke it with
|
||||
FST == StdFst and it notices that the actual FST type is
|
||||
fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder object
|
||||
will internally cast itself to one that is templated on those more specific
|
||||
types; this is an optimization for speed.
|
||||
*/
|
||||
template <typename FST, typename Token = decoder::StdToken>
|
||||
class LatticeFasterDecoderTpl {
|
||||
public:
|
||||
using Arc = typename FST::Arc;
|
||||
using Label = typename Arc::Label;
|
||||
using StateId = typename Arc::StateId;
|
||||
using Weight = typename Arc::Weight;
|
||||
using ForwardLinkT = decoder::ForwardLink<Token>;
|
||||
|
||||
// Instantiate this class once for each thing you have to decode.
|
||||
// This version of the constructor does not take ownership of
|
||||
// 'fst'.
|
||||
LatticeFasterDecoderTpl(const FST &fst,
|
||||
const LatticeFasterDecoderConfig &config);
|
||||
|
||||
// This version of the constructor takes ownership of the fst, and will delete
|
||||
// it when this object is destroyed.
|
||||
LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config,
|
||||
FST *fst);
|
||||
|
||||
void SetOptions(const LatticeFasterDecoderConfig &config) {
|
||||
config_ = config;
|
||||
}
|
||||
|
||||
const LatticeFasterDecoderConfig &GetOptions() const {
|
||||
return config_;
|
||||
}
|
||||
|
||||
~LatticeFasterDecoderTpl();
|
||||
|
||||
/// Decodes until there are no more frames left in the "decodable" object..
|
||||
/// note, this may block waiting for input if the "decodable" object blocks.
|
||||
/// Returns true if any kind of traceback is available (not necessarily from a
|
||||
/// final state).
|
||||
bool Decode(DecodableInterface *decodable);
|
||||
|
||||
|
||||
/// says whether a final-state was active on the last frame. If it was not, the
|
||||
/// lattice (or traceback) will end with states that are not final-states.
|
||||
bool ReachedFinal() const {
|
||||
return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
|
||||
}
|
||||
|
||||
/// Outputs an FST corresponding to the single best path through the lattice.
|
||||
/// Returns true if result is nonempty (using the return status is deprecated,
|
||||
/// it will become void). If "use_final_probs" is true AND we reached the
|
||||
/// final-state of the graph then it will include those as final-probs, else
|
||||
/// it will treat all final-probs as one. Note: this just calls GetRawLattice()
|
||||
/// and figures out the shortest path.
|
||||
bool GetBestPath(Lattice *ofst,
|
||||
bool use_final_probs = true) const;
|
||||
|
||||
/// Outputs an FST corresponding to the raw, state-level
|
||||
/// tracebacks. Returns true if result is nonempty.
|
||||
/// If "use_final_probs" is true AND we reached the final-state
|
||||
/// of the graph then it will include those as final-probs, else
|
||||
/// it will treat all final-probs as one.
|
||||
/// The raw lattice will be topologically sorted.
|
||||
///
|
||||
/// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
|
||||
/// which also supports a pruning beam, in case for some reason
|
||||
/// you want it pruned tighter than the regular lattice beam.
|
||||
/// We could put that here in future needed.
|
||||
bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;
|
||||
|
||||
|
||||
|
||||
/// [Deprecated, users should now use GetRawLattice and determinize it
|
||||
/// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
|
||||
/// Outputs an FST corresponding to the lattice-determinized
|
||||
/// lattice (one path per word sequence). Returns true if result is nonempty.
|
||||
/// If "use_final_probs" is true AND we reached the final-state of the graph
|
||||
/// then it will include those as final-probs, else it will treat all
|
||||
/// final-probs as one.
|
||||
bool GetLattice(CompactLattice *ofst,
|
||||
bool use_final_probs = true) const;
|
||||
|
||||
/// InitDecoding initializes the decoding, and should only be used if you
|
||||
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
|
||||
/// call this. You can also call InitDecoding if you have already decoded an
|
||||
/// utterance and want to start with a new utterance.
|
||||
void InitDecoding();
|
||||
|
||||
/// This will decode until there are no more frames ready in the decodable
|
||||
/// object. You can keep calling it each time more frames become available.
|
||||
/// If max_num_frames is specified, it specifies the maximum number of frames
|
||||
/// the function will decode before returning.
|
||||
void AdvanceDecoding(DecodableInterface *decodable,
|
||||
int32 max_num_frames = -1);
|
||||
|
||||
/// This function may be optionally called after AdvanceDecoding(), when you
|
||||
/// do not plan to decode any further. It does an extra pruning step that
|
||||
/// will help to prune the lattices output by GetLattice and (particularly)
|
||||
/// GetRawLattice more completely, particularly toward the end of the
|
||||
/// utterance. If you call this, you cannot call AdvanceDecoding again (it
|
||||
/// will fail), and you cannot call GetLattice() and related functions with
|
||||
/// use_final_probs = false. Used to be called PruneActiveTokensFinal().
|
||||
void FinalizeDecoding();
|
||||
|
||||
/// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
|
||||
/// more information. It returns the difference between the best (final-cost
|
||||
/// plus cost) of any token on the final frame, and the best cost of any token
|
||||
/// on the final frame. If it is infinity it means no final-states were
|
||||
/// present on the final frame. It will usually be nonnegative. If it not
|
||||
/// too positive (e.g. < 5 is my first guess, but this is not tested) you can
|
||||
/// take it as a good indication that we reached the final-state with
|
||||
/// reasonable likelihood.
|
||||
BaseFloat FinalRelativeCost() const;
|
||||
|
||||
|
||||
// Returns the number of frames decoded so far. The value returned changes
|
||||
// whenever we call ProcessEmitting().
|
||||
inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
|
||||
|
||||
protected:
|
||||
// we make things protected instead of private, as code in
|
||||
// LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
|
||||
// internals.
|
||||
|
||||
// Deletes the elements of the singly linked list tok->links.
|
||||
void DeleteForwardLinks(Token *tok);
|
||||
|
||||
// head of per-frame list of Tokens (list is in topological order),
|
||||
// and something saying whether we ever pruned it using PruneForwardLinks.
|
||||
struct TokenList {
|
||||
Token *toks;
|
||||
bool must_prune_forward_links;
|
||||
bool must_prune_tokens;
|
||||
TokenList(): toks(NULL), must_prune_forward_links(true),
|
||||
must_prune_tokens(true) { }
|
||||
};
|
||||
|
||||
using Elem = typename HashList<StateId, Token*>::Elem;
|
||||
// Equivalent to:
|
||||
// struct Elem {
|
||||
// StateId key;
|
||||
// Token *val;
|
||||
// Elem *tail;
|
||||
// };
|
||||
|
||||
void PossiblyResizeHash(size_t num_toks);
|
||||
|
||||
// FindOrAddToken either locates a token in hash of toks_, or if necessary
|
||||
// inserts a new, empty token (i.e. with no forward links) for the current
|
||||
// frame. [note: it's inserted if necessary into hash toks_ and also into the
|
||||
// singly linked list of tokens active on this frame (whose head is at
|
||||
// active_toks_[frame]). The frame_plus_one argument is the acoustic frame
|
||||
// index plus one, which is used to index into the active_toks_ array.
|
||||
// Returns the Token pointer. Sets "changed" (if non-NULL) to true if the
|
||||
// token was newly created or the cost changed.
|
||||
// If Token == StdToken, the 'backpointer' argument has no purpose (and will
|
||||
// hopefully be optimized out).
|
||||
inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
|
||||
BaseFloat tot_cost, Token *backpointer,
|
||||
bool *changed);
|
||||
|
||||
// prunes outgoing links for all tokens in active_toks_[frame]
|
||||
// it's called by PruneActiveTokens
|
||||
// all links, that have link_extra_cost > lattice_beam are pruned
|
||||
// delta is the amount by which the extra_costs must change
|
||||
// before we set *extra_costs_changed = true.
|
||||
// If delta is larger, we'll tend to go back less far
|
||||
// toward the beginning of the file.
|
||||
// extra_costs_changed is set to true if extra_cost was changed for any token
|
||||
// links_pruned is set to true if any link in any token was pruned
|
||||
void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
|
||||
bool *links_pruned,
|
||||
BaseFloat delta);
|
||||
|
||||
// This function computes the final-costs for tokens active on the final
|
||||
// frame. It outputs to final-costs, if non-NULL, a map from the Token*
|
||||
// pointer to the final-prob of the corresponding state, for all Tokens
|
||||
// that correspond to states that have final-probs. This map will be
|
||||
// empty if there were no final-probs. It outputs to
|
||||
// final_relative_cost, if non-NULL, the difference between the best
|
||||
// forward-cost including the final-prob cost, and the best forward-cost
|
||||
// without including the final-prob cost (this will usually be positive), or
|
||||
// infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
|
||||
// outputs this quanitity]. It outputs to final_best_cost, if
|
||||
// non-NULL, the lowest for any token t active on the final frame, of
|
||||
// forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
|
||||
// the graph of the state corresponding to token t, or the best of
|
||||
// forward-cost[t] if there were no final-probs active on the final frame.
|
||||
// You cannot call this after FinalizeDecoding() has been called; in that
|
||||
// case you should get the answer from class-member variables.
|
||||
void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
|
||||
BaseFloat *final_relative_cost,
|
||||
BaseFloat *final_best_cost) const;
|
||||
|
||||
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
|
||||
// on the final frame. If there are final tokens active, it uses
|
||||
// the final-probs for pruning, otherwise it treats all tokens as final.
|
||||
void PruneForwardLinksFinal();
|
||||
|
||||
// Prune away any tokens on this frame that have no forward links.
|
||||
// [we don't do this in PruneForwardLinks because it would give us
|
||||
// a problem with dangling pointers].
|
||||
// It's called by PruneActiveTokens if any forward links have been pruned
|
||||
void PruneTokensForFrame(int32 frame_plus_one);
|
||||
|
||||
|
||||
// Go backwards through still-alive tokens, pruning them if the
|
||||
// forward+backward cost is more than lat_beam away from the best path. It's
|
||||
// possible to prove that this is "correct" in the sense that we won't lose
|
||||
// anything outside of lat_beam, regardless of what happens in the future.
|
||||
// delta controls when it considers a cost to have changed enough to continue
|
||||
// going backward and propagating the change. larger delta -> will recurse
|
||||
// less far.
|
||||
void PruneActiveTokens(BaseFloat delta);
|
||||
|
||||
/// Gets the weight cutoff. Also counts the active tokens.
|
||||
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
|
||||
BaseFloat *adaptive_beam, Elem **best_elem);
|
||||
|
||||
/// Processes emitting arcs for one frame. Propagates from prev_toks_ to
|
||||
/// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to
|
||||
/// use.
|
||||
BaseFloat ProcessEmitting(DecodableInterface *decodable);
|
||||
|
||||
/// Processes nonemitting (epsilon) arcs for one frame. Called after
|
||||
/// ProcessEmitting() on each frame. The cost cutoff is computed by the
|
||||
/// preceding ProcessEmitting().
|
||||
void ProcessNonemitting(BaseFloat cost_cutoff);
|
||||
|
||||
// HashList defined in ../util/hash-list.h. It actually allows us to maintain
|
||||
// more than one list (e.g. for current and previous frames), but only one of
|
||||
// them at a time can be indexed by StateId. It is indexed by frame-index
|
||||
// plus one, where the frame-index is zero-based, as used in decodable object.
|
||||
// That is, the emitting probs of frame t are accounted for in tokens at
|
||||
// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
|
||||
// the graph.
|
||||
HashList<StateId, Token*> toks_;
|
||||
|
||||
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
|
||||
// frame (members of TokenList are toks, must_prune_forward_links,
|
||||
// must_prune_tokens).
|
||||
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
|
||||
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
|
||||
|
||||
// fst_ is a pointer to the FST we are decoding from.
|
||||
const FST *fst_;
|
||||
// delete_fst_ is true if the pointer fst_ needs to be deleted when this
|
||||
// object is destroyed.
|
||||
bool delete_fst_;
|
||||
|
||||
std::vector<BaseFloat> cost_offsets_; // This contains, for each
|
||||
// frame, an offset that was added to the acoustic log-likelihoods on that
|
||||
// frame in order to keep everything in a nice dynamic range i.e. close to
|
||||
// zero, to reduce roundoff errors.
|
||||
LatticeFasterDecoderConfig config_;
|
||||
int32 num_toks_; // current total #toks allocated...
|
||||
bool warned_;
|
||||
|
||||
/// decoding_finalized_ is true if someone called FinalizeDecoding(). [note,
|
||||
/// calling this is optional]. If true, it's forbidden to decode more. Also,
|
||||
/// if this is set, then the output of ComputeFinalCosts() is in the next
|
||||
/// three variables. The reason we need to do this is that after
|
||||
/// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
|
||||
/// of the tokens on the last frame are freed, so we free the list from toks_
|
||||
/// to avoid having dangling pointers hanging around.
|
||||
bool decoding_finalized_;
|
||||
/// For the meaning of the next 3 variables, see the comment for
|
||||
/// decoding_finalized_ above., and ComputeFinalCosts().
|
||||
unordered_map<Token*, BaseFloat> final_costs_;
|
||||
BaseFloat final_relative_cost_;
|
||||
BaseFloat final_best_cost_;
|
||||
|
||||
// Memory pools for storing tokens and forward links.
|
||||
// We use it to decrease the work put on allocator and to move some of data
|
||||
// together. Too small block sizes will result in more work to allocator but
|
||||
// bigger ones increase the memory usage.
|
||||
fst::MemoryPool<Token> token_pool_;
|
||||
fst::MemoryPool<ForwardLinkT> forward_link_pool_;
|
||||
|
||||
// There are various cleanup tasks... the toks_ structure contains
|
||||
// singly linked lists of Token pointers, where Elem is the list type.
|
||||
// It also indexes them in a hash, indexed by state (this hash is only
|
||||
// maintained for the most recent frame). toks_.Clear()
|
||||
// deletes them from the hash and returns the list of Elems. The
|
||||
// function DeleteElems calls toks_.Delete(elem) for each elem in
|
||||
// the list, which returns ownership of the Elem to the toks_ structure
|
||||
// for reuse, but does not delete the Token pointer. The Token pointers
|
||||
// are reference-counted and are ultimately deleted in PruneTokensForFrame,
|
||||
// but are also linked together on each frame by their own linked-list,
|
||||
// using the "next" pointer. We delete them manually.
|
||||
void DeleteElems(Elem *list);
|
||||
|
||||
// This function takes a singly linked list of tokens for a single frame, and
|
||||
// outputs a list of them in topological order (it will crash if no such order
|
||||
// can be found, which will typically be due to decoding graphs with epsilon
|
||||
// cycles, which are not allowed). Note: the output list may contain NULLs,
|
||||
// which the caller should pass over; it just happens to be more efficient for
|
||||
// the algorithm to output a list that contains NULLs.
|
||||
static void TopSortTokens(Token *tok_list,
|
||||
std::vector<Token*> *topsorted_list);
|
||||
|
||||
void ClearActiveTokens();
|
||||
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
|
||||
};
|
||||
|
||||
typedef LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken> LatticeFasterDecoder;
|
||||
|
||||
|
||||
|
||||
} // end namespace kaldi.
|
||||
|
||||
#endif
|
@ -0,0 +1,285 @@
|
||||
// decoder/lattice-faster-online-decoder.cc
|
||||
|
||||
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
|
||||
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2014 Guoguo Chen
|
||||
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
|
||||
// 2018 Zhehuai Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
|
||||
// file in sync with lattice-faster-decoder.cc
|
||||
|
||||
#include "decoder/lattice-faster-online-decoder.h"
|
||||
#include "lat/lattice-functions.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
template <typename FST>
|
||||
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
|
||||
bool use_final_probs) const {
|
||||
Lattice lat1;
|
||||
{
|
||||
Lattice raw_lat;
|
||||
this->GetRawLattice(&raw_lat, use_final_probs);
|
||||
ShortestPath(raw_lat, &lat1);
|
||||
}
|
||||
Lattice lat2;
|
||||
GetBestPath(&lat2, use_final_probs);
|
||||
BaseFloat delta = 0.1;
|
||||
int32 num_paths = 1;
|
||||
if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
|
||||
KALDI_WARN << "Best-path test failed";
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Outputs an FST corresponding to the single best path through the lattice.
|
||||
template <typename FST>
|
||||
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(Lattice *olat,
|
||||
bool use_final_probs) const {
|
||||
olat->DeleteStates();
|
||||
BaseFloat final_graph_cost;
|
||||
BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
|
||||
if (iter.Done())
|
||||
return false; // would have printed warning.
|
||||
StateId state = olat->AddState();
|
||||
olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
|
||||
while (!iter.Done()) {
|
||||
LatticeArc arc;
|
||||
iter = TraceBackBestPath(iter, &arc);
|
||||
arc.nextstate = state;
|
||||
StateId new_state = olat->AddState();
|
||||
olat->AddArc(new_state, arc);
|
||||
state = new_state;
|
||||
}
|
||||
olat->SetStart(state);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename FST>
|
||||
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
|
||||
bool use_final_probs,
|
||||
BaseFloat *final_cost_out) const {
|
||||
if (this->decoding_finalized_ && !use_final_probs)
|
||||
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
|
||||
<< "BestPathEnd() with use_final_probs == false";
|
||||
KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
|
||||
"You cannot call BestPathEnd if no frames were decoded.");
|
||||
|
||||
unordered_map<Token*, BaseFloat> final_costs_local;
|
||||
|
||||
const unordered_map<Token*, BaseFloat> &final_costs =
|
||||
(this->decoding_finalized_ ? this->final_costs_ :final_costs_local);
|
||||
if (!this->decoding_finalized_ && use_final_probs)
|
||||
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
|
||||
|
||||
// Singly linked list of tokens on last frame (access list through "next"
|
||||
// pointer).
|
||||
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
|
||||
BaseFloat best_final_cost = 0;
|
||||
Token *best_tok = NULL;
|
||||
for (Token *tok = this->active_toks_.back().toks;
|
||||
tok != NULL; tok = tok->next) {
|
||||
BaseFloat cost = tok->tot_cost, final_cost = 0.0;
|
||||
if (use_final_probs && !final_costs.empty()) {
|
||||
// if we are instructed to use final-probs, and any final tokens were
|
||||
// active on final frame, include the final-prob in the cost of the token.
|
||||
typename unordered_map<Token*, BaseFloat>::const_iterator
|
||||
iter = final_costs.find(tok);
|
||||
if (iter != final_costs.end()) {
|
||||
final_cost = iter->second;
|
||||
cost += final_cost;
|
||||
} else {
|
||||
cost = std::numeric_limits<BaseFloat>::infinity();
|
||||
}
|
||||
}
|
||||
if (cost < best_cost) {
|
||||
best_cost = cost;
|
||||
best_tok = tok;
|
||||
best_final_cost = final_cost;
|
||||
}
|
||||
}
|
||||
if (best_tok == NULL) { // this should not happen, and is likely a code error or
|
||||
// caused by infinities in likelihoods, but I'm not making
|
||||
// it a fatal error for now.
|
||||
KALDI_WARN << "No final token found.";
|
||||
}
|
||||
if (final_cost_out)
|
||||
*final_cost_out = best_final_cost;
|
||||
return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
|
||||
}
|
||||
|
||||
|
||||
template <typename FST>
|
||||
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(
|
||||
BestPathIterator iter, LatticeArc *oarc) const {
|
||||
KALDI_ASSERT(!iter.Done() && oarc != NULL);
|
||||
Token *tok = static_cast<Token*>(iter.tok);
|
||||
int32 cur_t = iter.frame, step_t = 0;
|
||||
if (tok->backpointer != NULL) {
|
||||
// retrieve the correct forward link(with the best link cost)
|
||||
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
|
||||
ForwardLinkT *link;
|
||||
for (link = tok->backpointer->links;
|
||||
link != NULL; link = link->next) {
|
||||
if (link->next_tok == tok) { // this is a link to "tok"
|
||||
BaseFloat graph_cost = link->graph_cost,
|
||||
acoustic_cost = link->acoustic_cost;
|
||||
BaseFloat cost = graph_cost + acoustic_cost;
|
||||
if (cost < best_cost) {
|
||||
oarc->ilabel = link->ilabel;
|
||||
oarc->olabel = link->olabel;
|
||||
if (link->ilabel != 0) {
|
||||
KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
|
||||
acoustic_cost -= this->cost_offsets_[cur_t];
|
||||
step_t = -1;
|
||||
} else {
|
||||
step_t = 0;
|
||||
}
|
||||
oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
|
||||
best_cost = cost;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (link == NULL &&
|
||||
best_cost == std::numeric_limits<BaseFloat>::infinity()) { // Did not find correct link.
|
||||
KALDI_ERR << "Error tracing best-path back (likely "
|
||||
<< "bug in token-pruning algorithm)";
|
||||
}
|
||||
} else {
|
||||
oarc->ilabel = 0;
|
||||
oarc->olabel = 0;
|
||||
oarc->weight = LatticeWeight::One(); // zero costs.
|
||||
}
|
||||
return BestPathIterator(tok->backpointer, cur_t + step_t);
|
||||
}
|
||||
|
||||
template <typename FST>
|
||||
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
|
||||
Lattice *ofst,
|
||||
bool use_final_probs,
|
||||
BaseFloat beam) const {
|
||||
typedef LatticeArc Arc;
|
||||
typedef Arc::StateId StateId;
|
||||
typedef Arc::Weight Weight;
|
||||
typedef Arc::Label Label;
|
||||
|
||||
// Note: you can't use the old interface (Decode()) if you want to
|
||||
// get the lattice with use_final_probs = false. You'd have to do
|
||||
// InitDecoding() and then AdvanceDecoding().
|
||||
if (this->decoding_finalized_ && !use_final_probs)
|
||||
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
|
||||
<< "GetRawLattice() with use_final_probs == false";
|
||||
|
||||
unordered_map<Token*, BaseFloat> final_costs_local;
|
||||
|
||||
const unordered_map<Token*, BaseFloat> &final_costs =
|
||||
(this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
|
||||
if (!this->decoding_finalized_ && use_final_probs)
|
||||
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
|
||||
|
||||
ofst->DeleteStates();
|
||||
// num-frames plus one (since frames are one-based, and we have
|
||||
// an extra frame for the start-state).
|
||||
int32 num_frames = this->active_toks_.size() - 1;
|
||||
KALDI_ASSERT(num_frames > 0);
|
||||
for (int32 f = 0; f <= num_frames; f++) {
|
||||
if (this->active_toks_[f].toks == NULL) {
|
||||
KALDI_WARN << "No tokens active on frame " << f
|
||||
<< ": not producing lattice.\n";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
unordered_map<Token*, StateId> tok_map;
|
||||
std::queue<std::pair<Token*, int32> > tok_queue;
|
||||
// First initialize the queue and states. Put the initial state on the queue;
|
||||
// this is the last token in the list active_toks_[0].toks.
|
||||
for (Token *tok = this->active_toks_[0].toks;
|
||||
tok != NULL; tok = tok->next) {
|
||||
if (tok->next == NULL) {
|
||||
tok_map[tok] = ofst->AddState();
|
||||
ofst->SetStart(tok_map[tok]);
|
||||
std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
|
||||
tok_queue.push(tok_pair);
|
||||
}
|
||||
}
|
||||
|
||||
// Next create states for "good" tokens
|
||||
while (!tok_queue.empty()) {
|
||||
std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
|
||||
tok_queue.pop();
|
||||
Token *cur_tok = cur_tok_pair.first;
|
||||
int32 cur_frame = cur_tok_pair.second;
|
||||
KALDI_ASSERT(cur_frame >= 0 &&
|
||||
cur_frame <= this->cost_offsets_.size());
|
||||
|
||||
typename unordered_map<Token*, StateId>::const_iterator iter =
|
||||
tok_map.find(cur_tok);
|
||||
KALDI_ASSERT(iter != tok_map.end());
|
||||
StateId cur_state = iter->second;
|
||||
|
||||
for (ForwardLinkT *l = cur_tok->links;
|
||||
l != NULL;
|
||||
l = l->next) {
|
||||
Token *next_tok = l->next_tok;
|
||||
if (next_tok->extra_cost < beam) {
|
||||
// so both the current and the next token are good; create the arc
|
||||
int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
|
||||
StateId nextstate;
|
||||
if (tok_map.find(next_tok) == tok_map.end()) {
|
||||
nextstate = tok_map[next_tok] = ofst->AddState();
|
||||
tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
|
||||
} else {
|
||||
nextstate = tok_map[next_tok];
|
||||
}
|
||||
BaseFloat cost_offset = (l->ilabel != 0 ?
|
||||
this->cost_offsets_[cur_frame] : 0);
|
||||
Arc arc(l->ilabel, l->olabel,
|
||||
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
|
||||
nextstate);
|
||||
ofst->AddArc(cur_state, arc);
|
||||
}
|
||||
}
|
||||
if (cur_frame == num_frames) {
|
||||
if (use_final_probs && !final_costs.empty()) {
|
||||
typename unordered_map<Token*, BaseFloat>::const_iterator iter =
|
||||
final_costs.find(cur_tok);
|
||||
if (iter != final_costs.end())
|
||||
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
|
||||
} else {
|
||||
ofst->SetFinal(cur_state, LatticeWeight::One());
|
||||
}
|
||||
}
|
||||
}
|
||||
return (ofst->NumStates() != 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Instantiate the template for the FST types that we'll need.
|
||||
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
|
||||
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
|
||||
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
|
||||
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
|
||||
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
|
||||
|
||||
|
||||
} // end namespace kaldi.
|
@ -0,0 +1,147 @@
|
||||
// decoder/lattice-faster-online-decoder.h
|
||||
|
||||
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
|
||||
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2014 Guoguo Chen
|
||||
// 2018 Zhehuai Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// see note at the top of lattice-faster-decoder.h, about how to maintain this
|
||||
// file in sync with lattice-faster-decoder.h
|
||||
|
||||
|
||||
#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
|
||||
#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
|
||||
|
||||
#include "util/stl-utils.h"
|
||||
#include "util/hash-list.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "itf/decodable-itf.h"
|
||||
#include "fstext/fstext-lib.h"
|
||||
#include "lat/determinize-lattice-pruned.h"
|
||||
#include "lat/kaldi-lattice.h"
|
||||
#include "decoder/lattice-faster-decoder.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
|
||||
|
||||
/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also
|
||||
supports an efficient way to get the best path (see the function
|
||||
BestPathEnd()), which is useful in endpointing and in situations where you
|
||||
might want to frequently access the best path.
|
||||
|
||||
This is only templated on the FST type, since the Token type is required to
|
||||
be BackpointerToken. Actually it only makes sense to instantiate
|
||||
LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via
|
||||
this child class.
|
||||
*/
|
||||
template <typename FST>
|
||||
class LatticeFasterOnlineDecoderTpl:
|
||||
public LatticeFasterDecoderTpl<FST, decoder::BackpointerToken> {
|
||||
public:
|
||||
using Arc = typename FST::Arc;
|
||||
using Label = typename Arc::Label;
|
||||
using StateId = typename Arc::StateId;
|
||||
using Weight = typename Arc::Weight;
|
||||
using Token = decoder::BackpointerToken;
|
||||
using ForwardLinkT = decoder::ForwardLink<Token>;
|
||||
|
||||
// Instantiate this class once for each thing you have to decode.
|
||||
// This version of the constructor does not take ownership of
|
||||
// 'fst'.
|
||||
LatticeFasterOnlineDecoderTpl(const FST &fst,
|
||||
const LatticeFasterDecoderConfig &config):
|
||||
LatticeFasterDecoderTpl<FST, Token>(fst, config) { }
|
||||
|
||||
// This version of the initializer takes ownership of 'fst', and will delete
|
||||
// it when this object is destroyed.
|
||||
LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config,
|
||||
FST *fst):
|
||||
LatticeFasterDecoderTpl<FST, Token>(config, fst) { }
|
||||
|
||||
|
||||
struct BestPathIterator {
|
||||
void *tok;
|
||||
int32 frame;
|
||||
// note, "frame" is the frame-index of the frame you'll get the
|
||||
// transition-id for next time, if you call TraceBackBestPath on this
|
||||
// iterator (assuming it's not an epsilon transition). Note that this
|
||||
// is one less than you might reasonably expect, e.g. it's -1 for
|
||||
// the nonemitting transitions before the first frame.
|
||||
BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
|
||||
bool Done() const { return tok == NULL; }
|
||||
};
|
||||
|
||||
|
||||
/// Outputs an FST corresponding to the single best path through the lattice.
|
||||
/// This is quite efficient because it doesn't get the entire raw lattice and find
|
||||
/// the best path through it; instead, it uses the BestPathEnd and BestPathIterator
|
||||
/// so it basically traces it back through the lattice.
|
||||
/// Returns true if result is nonempty (using the return status is deprecated,
|
||||
/// it will become void). If "use_final_probs" is true AND we reached the
|
||||
/// final-state of the graph then it will include those as final-probs, else
|
||||
/// it will treat all final-probs as one.
|
||||
bool GetBestPath(Lattice *ofst,
|
||||
bool use_final_probs = true) const;
|
||||
|
||||
|
||||
/// This function does a self-test of GetBestPath(). Returns true on
|
||||
/// success; returns false and prints a warning on failure.
|
||||
bool TestGetBestPath(bool use_final_probs = true) const;
|
||||
|
||||
|
||||
/// This function returns an iterator that can be used to trace back
|
||||
/// the best path. If use_final_probs == true and at least one final state
|
||||
/// survived till the end, it will use the final-probs in working out the best
|
||||
/// final Token, and will output the final cost to *final_cost (if non-NULL),
|
||||
/// else it will use only the forward likelihood, and will put zero in
|
||||
/// *final_cost (if non-NULL).
|
||||
/// Requires that NumFramesDecoded() > 0.
|
||||
BestPathIterator BestPathEnd(bool use_final_probs,
|
||||
BaseFloat *final_cost = NULL) const;
|
||||
|
||||
|
||||
/// This function can be used in conjunction with BestPathEnd() to trace back
|
||||
/// the best path one link at a time (e.g. this can be useful in endpoint
|
||||
/// detection). By "link" we mean a link in the graph; not all links cross
|
||||
/// frame boundaries, but each time you see a nonzero ilabel you can interpret
|
||||
/// that as a frame. The return value is the updated iterator. It outputs
|
||||
/// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer,
|
||||
/// while leaving its "nextstate" variable unchanged.
|
||||
BestPathIterator TraceBackBestPath(
|
||||
BestPathIterator iter, LatticeArc *arc) const;
|
||||
|
||||
|
||||
/// Behaves the same as GetRawLattice but only processes tokens whose
|
||||
/// extra_cost is smaller than the best-cost plus the specified beam.
|
||||
/// It is only worthwhile to call this function if beam is less than
|
||||
/// the lattice_beam specified in the config; otherwise, it would
|
||||
/// return essentially the same thing as GetRawLattice, but more slowly.
|
||||
bool GetRawLatticePruned(Lattice *ofst,
|
||||
bool use_final_probs,
|
||||
BaseFloat beam) const;
|
||||
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl);
|
||||
};
|
||||
|
||||
typedef LatticeFasterOnlineDecoderTpl<fst::StdFst> LatticeFasterOnlineDecoder;
|
||||
|
||||
|
||||
} // end namespace kaldi.
|
||||
|
||||
#endif
|
@ -0,0 +1,147 @@
|
||||
// lat/determinize-lattice-pruned-test.cc
|
||||
|
||||
// Copyright 2009-2012 Microsoft Corporation
|
||||
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "lat/determinize-lattice-pruned.h"
|
||||
#include "fstext/lattice-utils.h"
|
||||
#include "fstext/fst-test-utils.h"
|
||||
#include "lat/kaldi-lattice.h"
|
||||
#include "lat/lattice-functions.h"
|
||||
|
||||
namespace fst {
|
||||
// Caution: these tests are not as generic as you might think from all the
|
||||
// templates in the code. They are basically only valid for LatticeArc.
|
||||
// This is partly due to the fact that certain templates need to be instantiated
|
||||
// in other .cc files in this directory.
|
||||
|
||||
// test that determinization proceeds correctly on general
|
||||
// FSTs (not guaranteed determinzable, but we use the
|
||||
// max-states option to stop it getting out of control).
|
||||
template<class Arc> void TestDeterminizeLatticePruned() {
|
||||
typedef kaldi::int32 Int;
|
||||
typedef typename Arc::Weight Weight;
|
||||
typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
RandFstOptions opts;
|
||||
opts.n_states = 4;
|
||||
opts.n_arcs = 10;
|
||||
opts.n_final = 2;
|
||||
opts.allow_empty = false;
|
||||
opts.weight_multiplier = 0.5; // impt for the randomly generated weights
|
||||
opts.acyclic = true;
|
||||
// to be exactly representable in float,
|
||||
// or this test fails because numerical differences can cause symmetry in
|
||||
// weights to be broken, which causes the wrong path to be chosen as far
|
||||
// as the string part is concerned.
|
||||
|
||||
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
|
||||
|
||||
bool sorted = TopSort(fst);
|
||||
KALDI_ASSERT(sorted);
|
||||
|
||||
ILabelCompare<Arc> ilabel_comp;
|
||||
if (kaldi::Rand() % 2 == 0)
|
||||
ArcSort(fst, ilabel_comp);
|
||||
|
||||
std::cout << "FST before lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
VectorFst<Arc> det_fst;
|
||||
try {
|
||||
DeterminizeLatticePrunedOptions lat_opts;
|
||||
lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
|
||||
lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
|
||||
lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
|
||||
bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
|
||||
|
||||
std::cout << "FST after lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
|
||||
// OK, now determinize it a different way and check equivalence.
|
||||
// [note: it's not normal determinization, it's taking the best path
|
||||
// for any input-symbol sequence....
|
||||
|
||||
|
||||
VectorFst<Arc> pruned_fst(*fst);
|
||||
if (pruned_fst.NumStates() != 0)
|
||||
kaldi::PruneLattice(10.0, &pruned_fst);
|
||||
|
||||
VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
|
||||
ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
|
||||
std::cout << "Compact pruned FST is:\n";
|
||||
{
|
||||
FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
|
||||
|
||||
std::cout << "Compact version of determinized FST is:\n";
|
||||
{
|
||||
FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
|
||||
if (ans)
|
||||
KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
|
||||
} catch (...) {
|
||||
std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
|
||||
}
|
||||
delete fst;
|
||||
}
|
||||
}
|
||||
|
||||
// test that determinization proceeds without crash on acyclic FSTs
|
||||
// (guaranteed determinizable in this sense).
|
||||
template<class Arc> void TestDeterminizeLatticePruned2() {
|
||||
typedef typename Arc::Weight Weight;
|
||||
RandFstOptions opts;
|
||||
opts.acyclic = true;
|
||||
for(int i = 0; i < 100; i++) {
|
||||
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
|
||||
std::cout << "FST before lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
VectorFst<Arc> ofst;
|
||||
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
|
||||
std::cout << "FST after lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
delete fst;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
int main() {
|
||||
using namespace fst;
|
||||
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
|
||||
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
|
||||
std::cout << "Tests succeeded\n";
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,296 @@
|
||||
// lat/determinize-lattice-pruned.h
|
||||
|
||||
// Copyright 2009-2012 Microsoft Corporation
|
||||
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2014 Guoguo Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
|
||||
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
|
||||
#include <fst/fstlib.h>
|
||||
#include <fst/fst-decl.h>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "fstext/lattice-weight.h"
|
||||
#include "itf/transition-information.h"
|
||||
#include "itf/options-itf.h"
|
||||
#include "lat/kaldi-lattice.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
/// \addtogroup fst_extensions
|
||||
/// @{
|
||||
|
||||
|
||||
// For example of usage, see test-determinize-lattice-pruned.cc
|
||||
|
||||
/*
|
||||
DeterminizeLatticePruned implements a special form of determinization with
|
||||
epsilon removal, optimized for a phase of lattice generation. This algorithm
|
||||
also does pruning at the same time-- the combination is more efficient as it
|
||||
somtimes prevents us from creating a lot of states that would later be pruned
|
||||
away. This allows us to increase the lattice-beam and not have the algorithm
|
||||
blow up. Also, because our algorithm processes states in order from those
|
||||
that appear on high-scoring paths down to those that appear on low-scoring
|
||||
paths, we can easily terminate the algorithm after a certain specified number
|
||||
of states or arcs.
|
||||
|
||||
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
|
||||
with a lexicographical type of order, such as LatticeWeightTpl<float>).
|
||||
Typically this would be a state-level lattice, with input symbols equal to
|
||||
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
|
||||
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
|
||||
symbols are words, and the weights contain the original weights together with
|
||||
strings (with zero or one symbol in them) containing the original output labels
|
||||
(the p.d.f.'s). We determinize this using acceptor determinization with
|
||||
epsilon removal. Remember (from lattice-weight.h) that
|
||||
CompactLatticeWeightTpl has a special kind of semiring where we always take
|
||||
the string corresponding to the best cost (of type BaseWeightType), and
|
||||
discard the other. This corresponds to taking the best output-label sequence
|
||||
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
|
||||
Gallic weight for this, or it would die as soon as it detected that the input
|
||||
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
|
||||
can be determinized.
|
||||
We assume that there is a function
|
||||
Compare(const BaseWeightType &a, const BaseWeightType &b)
|
||||
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
|
||||
total order on the BaseWeightType... this information should be the
|
||||
same as NaturalLess would give, but it's more efficient to do it this way.
|
||||
You can define this for things like TropicalWeight if you need to instantiate
|
||||
this class for that weight type.
|
||||
|
||||
We implement this determinization in a special way to make it efficient for
|
||||
the types of FSTs that we will apply it to. One issue is that if we
|
||||
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
|
||||
type vector<IntType>, the algorithm takes time quadratic in the length of
|
||||
words (in states), because propagating each arc involves copying a whole
|
||||
vector (of integers representing p.d.f.'s). Instead we use a hash structure
|
||||
where each string is a pointer (Entry*), and uses a hash from (Entry*,
|
||||
IntType), to the successor string (and a way to get the latest IntType and the
|
||||
ancestor Entry*). [this is the class LatticeStringRepository].
|
||||
|
||||
Another issue is that rather than representing a determinized-state as a
|
||||
collection of (state, weight), we represent it in a couple of reduced forms.
|
||||
Suppose a determinized-state is a collection of (state, weight) pairs; call
|
||||
this the "canonical representation". Note: these collections are always
|
||||
normalized to remove any common weight and string part. Define end-states as
|
||||
the subset of states that have an arc out of them with a label on, or are
|
||||
final. If we represent a determinized-state a the set of just its (end-state,
|
||||
weight) pairs, this will be a valid and more compact representation, and will
|
||||
lead to a smaller set of determinized states (like early minimization). Call
|
||||
this collection of (end-state, weight) pairs the "minimal representation". As
|
||||
a mechanism to reduce compute, we can also consider another representation.
|
||||
In the determinization algorithm, we start off with a set of (begin-state,
|
||||
weight) pairs (where the "begin-states" are initial or have a label on the
|
||||
transition into them), and the "canonical representation" consists of the
|
||||
epsilon-closure of this set (i.e. follow epsilons). Call this set of
|
||||
(begin-state, weight) pairs, appropriately normalized, the "initial
|
||||
representation". If two initial representations are the same, the "canonical
|
||||
representation" and hence the "minimal representation" will be the same. We
|
||||
can use this to reduce compute. Note that if two initial representations are
|
||||
different, this does not preclude the other representations from being the same.
|
||||
|
||||
*/
|
||||
|
||||
|
||||
struct DeterminizeLatticePrunedOptions {
|
||||
float delta; // A small offset used to measure equality of weights.
|
||||
int max_mem; // If >0, determinization will fail and return false
|
||||
// when the algorithm's (approximate) memory consumption crosses this threshold.
|
||||
int max_loop; // If >0, can be used to detect non-determinizable input
|
||||
// (a case that wouldn't be caught by max_mem).
|
||||
int max_states;
|
||||
int max_arcs;
|
||||
float retry_cutoff;
|
||||
DeterminizeLatticePrunedOptions(): delta(kDelta),
|
||||
max_mem(-1),
|
||||
max_loop(-1),
|
||||
max_states(-1),
|
||||
max_arcs(-1),
|
||||
retry_cutoff(0.5) { }
|
||||
void Register (kaldi::OptionsItf *opts) {
|
||||
opts->Register("delta", &delta, "Tolerance used in determinization");
|
||||
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
|
||||
"determinization (real usage might be many times this)");
|
||||
opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
|
||||
"output FST (total, not per state");
|
||||
opts->Register("max-states", &max_states, "Maximum number of arcs in output "
|
||||
"FST (total, not per state");
|
||||
opts->Register("max-loop", &max_loop, "Option used to detect a particular "
|
||||
"type of determinization failure, typically due to invalid input "
|
||||
"(e.g., negative-cost loops)");
|
||||
opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
|
||||
"lattice and retrying determinization: if effective-beam < "
|
||||
"retry-cutoff * beam, we prune the raw lattice and retry. Avoids "
|
||||
"ever getting empty output for long segments.");
|
||||
}
|
||||
};
|
||||
|
||||
struct DeterminizeLatticePhonePrunedOptions {
|
||||
// delta: a small offset used to measure equality of weights.
|
||||
float delta;
|
||||
// max_mem: if > 0, determinization will fail and return false when the
|
||||
// algorithm's (approximate) memory consumption crosses this threshold.
|
||||
int max_mem;
|
||||
// phone_determinize: if true, do a first pass determinization on both phones
|
||||
// and words.
|
||||
bool phone_determinize;
|
||||
// word_determinize: if true, do a second pass determinization on words only.
|
||||
bool word_determinize;
|
||||
// minimize: if true, push and minimize after determinization.
|
||||
bool minimize;
|
||||
DeterminizeLatticePhonePrunedOptions(): delta(kDelta),
|
||||
max_mem(50000000),
|
||||
phone_determinize(true),
|
||||
word_determinize(true),
|
||||
minimize(false) {}
|
||||
void Register (kaldi::OptionsItf *opts) {
|
||||
opts->Register("delta", &delta, "Tolerance used in determinization");
|
||||
opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
|
||||
"determinization (real usage might be many times this).");
|
||||
opts->Register("phone-determinize", &phone_determinize, "If true, do an "
|
||||
"initial pass of determinization on both phones and words (see"
|
||||
" also --word-determinize)");
|
||||
opts->Register("word-determinize", &word_determinize, "If true, do a second "
|
||||
"pass of determinization on words only (see also "
|
||||
"--phone-determinize)");
|
||||
opts->Register("minimize", &minimize, "If true, push and minimize after "
|
||||
"determinization.");
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
This function implements the normal version of DeterminizeLattice, in which the
|
||||
output strings are represented using sequences of arcs, where all but the
|
||||
first one has an epsilon on the input side. It also prunes using the beam
|
||||
in the "prune" parameter. The input FST must be topologically sorted in order
|
||||
for the algorithm to work. For efficiency it is recommended to sort ilabel as well.
|
||||
Returns true on success, and false if it had to terminate the determinization
|
||||
earlier than specified by the "prune" beam-- that is, if it terminated because
|
||||
of the max_mem, max_loop or max_arcs constraints in the options.
|
||||
CAUTION: you may want to use the version below which outputs to CompactLattice.
|
||||
*/
|
||||
template<class Weight>
|
||||
bool DeterminizeLatticePruned(
|
||||
const ExpandedFst<ArcTpl<Weight> > &ifst,
|
||||
double prune,
|
||||
MutableFst<ArcTpl<Weight> > *ofst,
|
||||
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
|
||||
|
||||
|
||||
/* This is a version of DeterminizeLattice with a slightly more "natural" output format,
|
||||
where the output sequences are encoded using the CompactLatticeArcTpl template
|
||||
(i.e. the sequences of output symbols are represented directly as strings The input
|
||||
FST must be topologically sorted in order for the algorithm to work. For efficiency
|
||||
it is recommended to sort the ilabel for the input FST as well.
|
||||
Returns true on normal success, and false if it had to terminate the determinization
|
||||
earlier than specified by the "prune" beam-- that is, if it terminated because
|
||||
of the max_mem, max_loop or max_arcs constraints in the options.
|
||||
CAUTION: if Lattice is the input, you need to Invert() before calling this,
|
||||
so words are on the input side.
|
||||
*/
|
||||
template<class Weight, class IntType>
|
||||
bool DeterminizeLatticePruned(
|
||||
const ExpandedFst<ArcTpl<Weight> >&ifst,
|
||||
double prune,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
||||
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
|
||||
|
||||
/** This function takes in lattices and inserts phones at phone boundaries. It
|
||||
uses the transition model to work out the transition_id to phone map. The
|
||||
returning value is the starting index of the phone label. Typically we pick
|
||||
(maximum_output_label_index + 1) as this value. The inserted phones are then
|
||||
mapped to (returning_value + original_phone_label) in the new lattice. The
|
||||
returning value will be used by DeterminizeLatticeDeletePhones() where it
|
||||
works out the phones according to this value.
|
||||
*/
|
||||
template<class Weight>
|
||||
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
|
||||
const kaldi::TransitionInformation &trans_model,
|
||||
MutableFst<ArcTpl<Weight> > *fst);
|
||||
|
||||
/** This function takes in lattices and deletes "phones" from them. The "phones"
|
||||
here are actually any label that is larger than first_phone_label because
|
||||
when we insert phones into the lattice, we map the original phone label to
|
||||
(first_phone_label + original_phone_label). It is supposed to be used
|
||||
together with DeterminizeLatticeInsertPhones()
|
||||
*/
|
||||
template<class Weight>
|
||||
void DeterminizeLatticeDeletePhones(
|
||||
typename ArcTpl<Weight>::Label first_phone_label,
|
||||
MutableFst<ArcTpl<Weight> > *fst);
|
||||
|
||||
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
|
||||
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
|
||||
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
|
||||
determinization on the phone + word lattices. If --word-determinize is set
|
||||
true, it then does a second pass of determinization on the word lattices by
|
||||
calling DeterminizeLatticePruned(). If both are set to false, then it gives
|
||||
a warning and copying the lattices without determinization.
|
||||
|
||||
Note: the point of doing first a phone-level determinization pass and then
|
||||
a word-level determinization pass is that it allows us to determinize
|
||||
deeper lattices without "failing early" and returning a too-small lattice
|
||||
due to the max-mem constraint. The result should be the same as word-level
|
||||
determinization in general, but for deeper lattices it is a bit faster,
|
||||
despite the fact that we now have two passes of determinization by default.
|
||||
*/
|
||||
template<class Weight, class IntType>
|
||||
bool DeterminizeLatticePhonePruned(
|
||||
const kaldi::TransitionInformation &trans_model,
|
||||
const ExpandedFst<ArcTpl<Weight> > &ifst,
|
||||
double prune,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
||||
DeterminizeLatticePhonePrunedOptions opts
|
||||
= DeterminizeLatticePhonePrunedOptions());
|
||||
|
||||
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input
|
||||
lattice might be changed.
|
||||
*/
|
||||
template<class Weight, class IntType>
|
||||
bool DeterminizeLatticePhonePruned(
|
||||
const kaldi::TransitionInformation &trans_model,
|
||||
MutableFst<ArcTpl<Weight> > *ifst,
|
||||
double prune,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
||||
DeterminizeLatticePhonePrunedOptions opts
|
||||
= DeterminizeLatticePhonePrunedOptions());
|
||||
|
||||
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
|
||||
Lattice type FSTs. It simplifies the calling process by calling
|
||||
TopSort() Invert() and ArcSort() for you.
|
||||
Unlike other determinization routines, the function
|
||||
requires "ifst" to have transition-id's on the input side and words on the
|
||||
output side.
|
||||
This function can be used as the top-level interface to all the determinization
|
||||
code.
|
||||
*/
|
||||
bool DeterminizeLatticePhonePrunedWrapper(
|
||||
const kaldi::TransitionInformation &trans_model,
|
||||
MutableFst<kaldi::LatticeArc> *ifst,
|
||||
double prune,
|
||||
MutableFst<kaldi::CompactLatticeArc> *ofst,
|
||||
DeterminizeLatticePhonePrunedOptions opts
|
||||
= DeterminizeLatticePhonePrunedOptions());
|
||||
|
||||
/// @} end "addtogroup fst_extensions"
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
#endif
|
@ -0,0 +1,506 @@
|
||||
// lat/kaldi-lattice.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2013 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "lat/kaldi-lattice.h"
|
||||
#include "fst/script/print-impl.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
/// Converts lattice types if necessary, deleting its input.
|
||||
template<class OrigWeightType>
|
||||
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
|
||||
if (!ifst) return NULL;
|
||||
CompactLattice *ofst = new CompactLattice();
|
||||
ConvertLattice(*ifst, ofst);
|
||||
delete ifst;
|
||||
return ofst;
|
||||
}
|
||||
|
||||
// This overrides the template if there is no type conversion going on
|
||||
// (for efficiency).
|
||||
template<>
|
||||
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
|
||||
return ifst;
|
||||
}
|
||||
|
||||
/// Converts lattice types if necessary, deleting its input.
|
||||
template<class OrigWeightType>
|
||||
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
|
||||
if (!ifst) return NULL;
|
||||
Lattice *ofst = new Lattice();
|
||||
ConvertLattice(*ifst, ofst);
|
||||
delete ifst;
|
||||
return ofst;
|
||||
}
|
||||
|
||||
// This overrides the template if there is no type conversion going on
|
||||
// (for efficiency).
|
||||
template<>
|
||||
Lattice* ConvertToLattice(Lattice *ifst) {
|
||||
return ifst;
|
||||
}
|
||||
|
||||
|
||||
bool WriteCompactLattice(std::ostream &os, bool binary,
|
||||
const CompactLattice &t) {
|
||||
if (binary) {
|
||||
fst::FstWriteOptions opts;
|
||||
// Leave all the options default. Normally these lattices wouldn't have any
|
||||
// osymbols/isymbols so no point directing it not to write them (who knows what
|
||||
// we'd want to if we had them).
|
||||
return t.Write(os, opts);
|
||||
} else {
|
||||
// Text-mode output. Note: we expect that t.InputSymbols() and
|
||||
// t.OutputSymbols() would always return NULL. The corresponding input
|
||||
// routine would not work if the FST actually had symbols attached.
|
||||
// Write a newline after the key, so the first line of the FST appears
|
||||
// on its own line.
|
||||
os << '\n';
|
||||
bool acceptor = true, write_one = false;
|
||||
fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
|
||||
t.OutputSymbols(),
|
||||
NULL, acceptor, write_one, "\t");
|
||||
printer.Print(&os, "<unknown>");
|
||||
if (os.fail())
|
||||
KALDI_WARN << "Stream failure detected.";
|
||||
// Write another newline as a terminating character. The read routine will
|
||||
// detect this [this is a Kaldi mechanism, not somethig in the original
|
||||
// OpenFst code].
|
||||
os << '\n';
|
||||
return os.good();
|
||||
}
|
||||
}
|
||||
|
||||
/// LatticeReader provides (static) functions for reading both Lattice
|
||||
/// and CompactLattice, in text form.
|
||||
class LatticeReader {
|
||||
typedef LatticeArc Arc;
|
||||
typedef LatticeWeight Weight;
|
||||
typedef CompactLatticeArc CArc;
|
||||
typedef CompactLatticeWeight CWeight;
|
||||
typedef Arc::Label Label;
|
||||
typedef Arc::StateId StateId;
|
||||
public:
|
||||
// everything is static in this class.
|
||||
|
||||
/** This function reads from the FST text format; it does not know in advance
|
||||
whether it's a Lattice or CompactLattice in the stream so it tries to
|
||||
read both formats until it becomes clear which is the correct one.
|
||||
*/
|
||||
static std::pair<Lattice*, CompactLattice*> ReadText(
|
||||
std::istream &is) {
|
||||
typedef std::pair<Lattice*, CompactLattice*> PairT;
|
||||
using std::string;
|
||||
using std::vector;
|
||||
Lattice *fst = new Lattice();
|
||||
CompactLattice *cfst = new CompactLattice();
|
||||
string line;
|
||||
size_t nline = 0;
|
||||
string separator = FLAGS_fst_field_separator + "\r\n";
|
||||
while (std::getline(is, line)) {
|
||||
nline++;
|
||||
vector<string> col;
|
||||
// on Windows we'll write in text and read in binary mode.
|
||||
SplitStringToVector(line, separator.c_str(), true, &col);
|
||||
if (col.size() == 0) break; // Empty line is a signal to stop, in our
|
||||
// archive format.
|
||||
if (col.size() > 5) {
|
||||
KALDI_WARN << "Reading lattice: bad line in FST: " << line;
|
||||
delete fst;
|
||||
delete cfst;
|
||||
return PairT(static_cast<Lattice*>(NULL),
|
||||
static_cast<CompactLattice*>(NULL));
|
||||
}
|
||||
StateId s;
|
||||
if (!ConvertStringToInteger(col[0], &s)) {
|
||||
KALDI_WARN << "FstCompiler: bad line in FST: " << line;
|
||||
delete fst;
|
||||
delete cfst;
|
||||
return PairT(static_cast<Lattice*>(NULL),
|
||||
static_cast<CompactLattice*>(NULL));
|
||||
}
|
||||
if (fst)
|
||||
while (s >= fst->NumStates())
|
||||
fst->AddState();
|
||||
if (cfst)
|
||||
while (s >= cfst->NumStates())
|
||||
cfst->AddState();
|
||||
if (nline == 1) {
|
||||
if (fst) fst->SetStart(s);
|
||||
if (cfst) cfst->SetStart(s);
|
||||
}
|
||||
|
||||
if (fst) { // we still have fst; try to read that arc.
|
||||
bool ok = true;
|
||||
Arc arc;
|
||||
Weight w;
|
||||
StateId d = s;
|
||||
switch (col.size()) {
|
||||
case 1 :
|
||||
fst->SetFinal(s, Weight::One());
|
||||
break;
|
||||
case 2:
|
||||
if (!StrToWeight(col[1], true, &w)) ok = false;
|
||||
else fst->SetFinal(s, w);
|
||||
break;
|
||||
case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
|
||||
ok = false;
|
||||
break;
|
||||
case 4:
|
||||
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
|
||||
ConvertStringToInteger(col[2], &arc.ilabel) &&
|
||||
ConvertStringToInteger(col[3], &arc.olabel);
|
||||
if (ok) {
|
||||
d = arc.nextstate;
|
||||
arc.weight = Weight::One();
|
||||
fst->AddArc(s, arc);
|
||||
}
|
||||
break;
|
||||
case 5:
|
||||
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
|
||||
ConvertStringToInteger(col[2], &arc.ilabel) &&
|
||||
ConvertStringToInteger(col[3], &arc.olabel) &&
|
||||
StrToWeight(col[4], false, &arc.weight);
|
||||
if (ok) {
|
||||
d = arc.nextstate;
|
||||
fst->AddArc(s, arc);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
ok = false;
|
||||
}
|
||||
while (d >= fst->NumStates())
|
||||
fst->AddState();
|
||||
if (!ok) {
|
||||
delete fst;
|
||||
fst = NULL;
|
||||
}
|
||||
}
|
||||
if (cfst) {
|
||||
bool ok = true;
|
||||
CArc arc;
|
||||
CWeight w;
|
||||
StateId d = s;
|
||||
switch (col.size()) {
|
||||
case 1 :
|
||||
cfst->SetFinal(s, CWeight::One());
|
||||
break;
|
||||
case 2:
|
||||
if (!StrToCWeight(col[1], true, &w)) ok = false;
|
||||
else cfst->SetFinal(s, w);
|
||||
break;
|
||||
case 3: // compact-lattice is acceptor format: state, next-state, label.
|
||||
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
|
||||
ConvertStringToInteger(col[2], &arc.ilabel);
|
||||
if (ok) {
|
||||
d = arc.nextstate;
|
||||
arc.olabel = arc.ilabel;
|
||||
arc.weight = CWeight::One();
|
||||
cfst->AddArc(s, arc);
|
||||
}
|
||||
break;
|
||||
case 4:
|
||||
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
|
||||
ConvertStringToInteger(col[2], &arc.ilabel) &&
|
||||
StrToCWeight(col[3], false, &arc.weight);
|
||||
if (ok) {
|
||||
d = arc.nextstate;
|
||||
arc.olabel = arc.ilabel;
|
||||
cfst->AddArc(s, arc);
|
||||
}
|
||||
break;
|
||||
case 5: default:
|
||||
ok = false;
|
||||
}
|
||||
while (d >= cfst->NumStates())
|
||||
cfst->AddState();
|
||||
if (!ok) {
|
||||
delete cfst;
|
||||
cfst = NULL;
|
||||
}
|
||||
}
|
||||
if (!fst && !cfst) {
|
||||
KALDI_WARN << "Bad line in lattice text format: " << line;
|
||||
// read until we get an empty line, so at least we
|
||||
// have a chance to read the next one (although this might
|
||||
// be a bit futile since the calling code will get unhappy
|
||||
// about failing to read this one.
|
||||
while (std::getline(is, line)) {
|
||||
SplitStringToVector(line, separator.c_str(), true, &col);
|
||||
if (col.empty()) break;
|
||||
}
|
||||
return PairT(static_cast<Lattice*>(NULL),
|
||||
static_cast<CompactLattice*>(NULL));
|
||||
}
|
||||
}
|
||||
return PairT(fst, cfst);
|
||||
}
|
||||
|
||||
static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
|
||||
std::istringstream strm(s);
|
||||
strm >> *w;
|
||||
if (!strm || (!allow_zero && *w == Weight::Zero())) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
|
||||
std::istringstream strm(s);
|
||||
strm >> *w;
|
||||
if (!strm || (!allow_zero && *w == CWeight::Zero())) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
CompactLattice *ReadCompactLatticeText(std::istream &is) {
|
||||
std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
|
||||
if (lat_pair.second != NULL) {
|
||||
delete lat_pair.first;
|
||||
return lat_pair.second;
|
||||
} else if (lat_pair.first != NULL) {
|
||||
// note: ConvertToCompactLattice frees its input.
|
||||
return ConvertToCompactLattice(lat_pair.first);
|
||||
} else {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Lattice *ReadLatticeText(std::istream &is) {
|
||||
std::pair<Lattice*, CompactLattice*> lat_pair = LatticeReader::ReadText(is);
|
||||
if (lat_pair.first != NULL) {
|
||||
delete lat_pair.second;
|
||||
return lat_pair.first;
|
||||
} else if (lat_pair.second != NULL) {
|
||||
// note: ConvertToLattice frees its input.
|
||||
return ConvertToLattice(lat_pair.second);
|
||||
} else {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
bool ReadCompactLattice(std::istream &is, bool binary,
|
||||
CompactLattice **clat) {
|
||||
KALDI_ASSERT(*clat == NULL);
|
||||
if (binary) {
|
||||
fst::FstHeader hdr;
|
||||
if (!hdr.Read(is, "<unknown>")) {
|
||||
KALDI_WARN << "Reading compact lattice: error reading FST header.";
|
||||
return false;
|
||||
}
|
||||
if (hdr.FstType() != "vector") {
|
||||
KALDI_WARN << "Reading compact lattice: unsupported FST type: "
|
||||
<< hdr.FstType();
|
||||
return false;
|
||||
}
|
||||
fst::FstReadOptions ropts("<unspecified>",
|
||||
&hdr);
|
||||
|
||||
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
|
||||
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
|
||||
typedef fst::LatticeWeightTpl<float> T3;
|
||||
typedef fst::LatticeWeightTpl<double> T4;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
|
||||
|
||||
CompactLattice *ans = NULL;
|
||||
if (hdr.ArcType() == T1::Type()) {
|
||||
ans = ConvertToCompactLattice(F1::Read(is, ropts));
|
||||
} else if (hdr.ArcType() == T2::Type()) {
|
||||
ans = ConvertToCompactLattice(F2::Read(is, ropts));
|
||||
} else if (hdr.ArcType() == T3::Type()) {
|
||||
ans = ConvertToCompactLattice(F3::Read(is, ropts));
|
||||
} else if (hdr.ArcType() == T4::Type()) {
|
||||
ans = ConvertToCompactLattice(F4::Read(is, ropts));
|
||||
} else {
|
||||
KALDI_WARN << "FST with arc type " << hdr.ArcType()
|
||||
<< " cannot be converted to CompactLattice.\n";
|
||||
return false;
|
||||
}
|
||||
if (ans == NULL) {
|
||||
KALDI_WARN << "Error reading compact lattice (after reading header).";
|
||||
return false;
|
||||
}
|
||||
*clat = ans;
|
||||
return true;
|
||||
} else {
|
||||
// The next line would normally consume the \r on Windows, plus any
|
||||
// extra spaces that might have got in there somehow.
|
||||
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
|
||||
if (is.peek() == '\n') is.get(); // consume the newline.
|
||||
else { // saw spaces but no newline.. this is not expected.
|
||||
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
|
||||
<< " at file position " << is.tellg();
|
||||
return false;
|
||||
}
|
||||
*clat = ReadCompactLatticeText(is); // that routine will warn on error.
|
||||
return (*clat != NULL);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool CompactLatticeHolder::Read(std::istream &is) {
|
||||
Clear(); // in case anything currently stored.
|
||||
int c = is.peek();
|
||||
if (c == -1) {
|
||||
KALDI_WARN << "End of stream detected reading CompactLattice.";
|
||||
return false;
|
||||
} else if (isspace(c)) { // The text form of the lattice begins
|
||||
// with space (normally, '\n'), so this means it's text (the binary form
|
||||
// cannot begin with space because it starts with the FST Type() which is not
|
||||
// space).
|
||||
return ReadCompactLattice(is, false, &t_);
|
||||
} else if (c != 214) { // 214 is first char of FST magic number,
|
||||
// on little-endian machines which is all we support (\326 octal)
|
||||
KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
|
||||
<< " [non-space but no magic number detected], file pos is "
|
||||
<< is.tellg();
|
||||
return false;
|
||||
} else {
|
||||
return ReadCompactLattice(is, true, &t_);
|
||||
}
|
||||
}
|
||||
|
||||
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
|
||||
if (binary) {
|
||||
fst::FstWriteOptions opts;
|
||||
// Leave all the options default. Normally these lattices wouldn't have any
|
||||
// osymbols/isymbols so no point directing it not to write them (who knows what
|
||||
// we'd want to do if we had them).
|
||||
return t.Write(os, opts);
|
||||
} else {
|
||||
// Text-mode output. Note: we expect that t.InputSymbols() and
|
||||
// t.OutputSymbols() would always return NULL. The corresponding input
|
||||
// routine would not work if the FST actually had symbols attached.
|
||||
// Write a newline after the key, so the first line of the FST appears
|
||||
// on its own line.
|
||||
os << '\n';
|
||||
bool acceptor = false, write_one = false;
|
||||
fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
|
||||
t.OutputSymbols(),
|
||||
NULL, acceptor, write_one, "\t");
|
||||
printer.Print(&os, "<unknown>");
|
||||
if (os.fail())
|
||||
KALDI_WARN << "Stream failure detected.";
|
||||
// Write another newline as a terminating character. The read routine will
|
||||
// detect this [this is a Kaldi mechanism, not somethig in the original
|
||||
// OpenFst code].
|
||||
os << '\n';
|
||||
return os.good();
|
||||
}
|
||||
}
|
||||
|
||||
bool ReadLattice(std::istream &is, bool binary,
|
||||
Lattice **lat) {
|
||||
KALDI_ASSERT(*lat == NULL);
|
||||
if (binary) {
|
||||
fst::FstHeader hdr;
|
||||
if (!hdr.Read(is, "<unknown>")) {
|
||||
KALDI_WARN << "Reading lattice: error reading FST header.";
|
||||
return false;
|
||||
}
|
||||
if (hdr.FstType() != "vector") {
|
||||
KALDI_WARN << "Reading lattice: unsupported FST type: "
|
||||
<< hdr.FstType();
|
||||
return false;
|
||||
}
|
||||
fst::FstReadOptions ropts("<unspecified>",
|
||||
&hdr);
|
||||
|
||||
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
|
||||
typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
|
||||
typedef fst::LatticeWeightTpl<float> T3;
|
||||
typedef fst::LatticeWeightTpl<double> T4;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
|
||||
typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
|
||||
|
||||
Lattice *ans = NULL;
|
||||
if (hdr.ArcType() == T1::Type()) {
|
||||
ans = ConvertToLattice(F1::Read(is, ropts));
|
||||
} else if (hdr.ArcType() == T2::Type()) {
|
||||
ans = ConvertToLattice(F2::Read(is, ropts));
|
||||
} else if (hdr.ArcType() == T3::Type()) {
|
||||
ans = ConvertToLattice(F3::Read(is, ropts));
|
||||
} else if (hdr.ArcType() == T4::Type()) {
|
||||
ans = ConvertToLattice(F4::Read(is, ropts));
|
||||
} else {
|
||||
KALDI_WARN << "FST with arc type " << hdr.ArcType()
|
||||
<< " cannot be converted to Lattice.\n";
|
||||
return false;
|
||||
}
|
||||
if (ans == NULL) {
|
||||
KALDI_WARN << "Error reading lattice (after reading header).";
|
||||
return false;
|
||||
}
|
||||
*lat = ans;
|
||||
return true;
|
||||
} else {
|
||||
// The next line would normally consume the \r on Windows, plus any
|
||||
// extra spaces that might have got in there somehow.
|
||||
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
|
||||
if (is.peek() == '\n') is.get(); // consume the newline.
|
||||
else { // saw spaces but no newline.. this is not expected.
|
||||
KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
|
||||
<< " at file position " << is.tellg();
|
||||
return false;
|
||||
}
|
||||
*lat = ReadLatticeText(is); // that routine will warn on error.
|
||||
return (*lat != NULL);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Since we don't write the binary headers for this type of holder,
|
||||
we use a different method to work out whether we're in binary mode.
|
||||
*/
|
||||
bool LatticeHolder::Read(std::istream &is) {
|
||||
Clear(); // in case anything currently stored.
|
||||
int c = is.peek();
|
||||
if (c == -1) {
|
||||
KALDI_WARN << "End of stream detected reading Lattice.";
|
||||
return false;
|
||||
} else if (isspace(c)) { // The text form of the lattice begins
|
||||
// with space (normally, '\n'), so this means it's text (the binary form
|
||||
// cannot begin with space because it starts with the FST Type() which is not
|
||||
// space).
|
||||
return ReadLattice(is, false, &t_);
|
||||
} else if (c != 214) { // 214 is first char of FST magic number,
|
||||
// on little-endian machines which is all we support (\326 octal)
|
||||
KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
|
||||
<< " [non-space but no magic number detected], file pos is "
|
||||
<< is.tellg();
|
||||
return false;
|
||||
} else {
|
||||
return ReadLattice(is, true, &t_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // end namespace kaldi
|
@ -0,0 +1,156 @@
|
||||
// lat/kaldi-lattice.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef KALDI_LAT_KALDI_LATTICE_H_
|
||||
#define KALDI_LAT_KALDI_LATTICE_H_
|
||||
|
||||
#include "fstext/fstext-lib.h"
|
||||
#include "base/kaldi-common.h"
|
||||
#include "util/common-utils.h"
|
||||
|
||||
|
||||
namespace kaldi {
|
||||
// will import some things above...
|
||||
|
||||
typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;
|
||||
|
||||
// careful: kaldi::int32 is not always the same C type as fst::int32
|
||||
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;
|
||||
|
||||
typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
|
||||
CompactLatticeWeightCommonDivisor;
|
||||
|
||||
typedef fst::ArcTpl<LatticeWeight> LatticeArc;
|
||||
|
||||
typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;
|
||||
|
||||
typedef fst::VectorFst<LatticeArc> Lattice;
|
||||
|
||||
typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
|
||||
|
||||
// The following functions for writing and reading lattices in binary or text
|
||||
// form are provided here in case you need to include lattices in larger,
|
||||
// Kaldi-type objects with their own Read and Write functions. Caution: these
|
||||
// functions return false on stream failure rather than throwing an exception as
|
||||
// most similar Kaldi functions would do.
|
||||
|
||||
bool WriteCompactLattice(std::ostream &os, bool binary,
|
||||
const CompactLattice &clat);
|
||||
bool WriteLattice(std::ostream &os, bool binary,
|
||||
const Lattice &lat);
|
||||
|
||||
// the following function requires that *clat be
|
||||
// NULL when called.
|
||||
bool ReadCompactLattice(std::istream &is, bool binary,
|
||||
CompactLattice **clat);
|
||||
// the following function requires that *lat be
|
||||
// NULL when called.
|
||||
bool ReadLattice(std::istream &is, bool binary,
|
||||
Lattice **lat);
|
||||
|
||||
|
||||
class CompactLatticeHolder {
|
||||
public:
|
||||
typedef CompactLattice T;
|
||||
|
||||
CompactLatticeHolder() { t_ = NULL; }
|
||||
|
||||
static bool Write(std::ostream &os, bool binary, const T &t) {
|
||||
// Note: we don't include the binary-mode header when writing
|
||||
// this object to disk; this ensures that if we write to single
|
||||
// files, the result can be read by OpenFst.
|
||||
return WriteCompactLattice(os, binary, t);
|
||||
}
|
||||
|
||||
bool Read(std::istream &is);
|
||||
|
||||
static bool IsReadInBinary() { return true; }
|
||||
|
||||
T &Value() {
|
||||
KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
|
||||
return *t_;
|
||||
}
|
||||
|
||||
void Clear() { delete t_; t_ = NULL; }
|
||||
|
||||
void Swap(CompactLatticeHolder *other) {
|
||||
std::swap(t_, other->t_);
|
||||
}
|
||||
|
||||
bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
|
||||
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
|
||||
return false;
|
||||
}
|
||||
|
||||
~CompactLatticeHolder() { Clear(); }
|
||||
private:
|
||||
T *t_;
|
||||
};
|
||||
|
||||
class LatticeHolder {
|
||||
public:
|
||||
typedef Lattice T;
|
||||
|
||||
LatticeHolder() { t_ = NULL; }
|
||||
|
||||
static bool Write(std::ostream &os, bool binary, const T &t) {
|
||||
// Note: we don't include the binary-mode header when writing
|
||||
// this object to disk; this ensures that if we write to single
|
||||
// files, the result can be read by OpenFst.
|
||||
return WriteLattice(os, binary, t);
|
||||
}
|
||||
|
||||
bool Read(std::istream &is);
|
||||
|
||||
static bool IsReadInBinary() { return true; }
|
||||
|
||||
T &Value() {
|
||||
KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
|
||||
return *t_;
|
||||
}
|
||||
|
||||
void Clear() { delete t_; t_ = NULL; }
|
||||
|
||||
void Swap(LatticeHolder *other) {
|
||||
std::swap(t_, other->t_);
|
||||
}
|
||||
|
||||
bool ExtractRange(const LatticeHolder &other, const std::string &range) {
|
||||
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
|
||||
return false;
|
||||
}
|
||||
|
||||
~LatticeHolder() { Clear(); }
|
||||
private:
|
||||
T *t_;
|
||||
};
|
||||
|
||||
typedef TableWriter<LatticeHolder> LatticeWriter;
|
||||
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
|
||||
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
|
||||
|
||||
typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
|
||||
typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
|
||||
typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
|
||||
|
||||
|
||||
} // namespace kaldi
|
||||
|
||||
#endif // KALDI_LAT_KALDI_LATTICE_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,402 @@
|
||||
// lat/lattice-functions.h
|
||||
|
||||
// Copyright 2009-2012 Saarland University (author: Arnab Ghoshal)
|
||||
// 2012-2013 Johns Hopkins University (Author: Daniel Povey);
|
||||
// Bagher BabaAli
|
||||
// 2014 Guoguo Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef KALDI_LAT_LATTICE_FUNCTIONS_H_
|
||||
#define KALDI_LAT_LATTICE_FUNCTIONS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "fstext/fstext-lib.h"
|
||||
#include "itf/decodable-itf.h"
|
||||
#include "itf/transition-information.h"
|
||||
#include "lat/kaldi-lattice.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
// Redundant with the typedef in hmm/posterior.h. We want functions
|
||||
// using the Posterior type to be usable without a dependency on the
|
||||
// hmm library.
|
||||
typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
|
||||
|
||||
/**
|
||||
This function extracts the per-frame log likelihoods from a linear
|
||||
lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code).
|
||||
The dimension of *per_frame_loglikes will be set to the
|
||||
number of input symbols in 'nbest'. The elements of
|
||||
'*per_frame_loglikes' will be set to the .Value2() elements of the lattice
|
||||
weights, which represent the acoustic costs; you may want to scale this
|
||||
vector afterward by -1/acoustic_scale to get the original loglikes.
|
||||
If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest'
|
||||
(and this should not normally be the case in situations where it makes
|
||||
sense to call this function), they will be included to the cost of the
|
||||
preceding input symbol, or the following input symbol for input-epsilons
|
||||
encountered prior to any input symbol. If 'nbest' has no input symbols,
|
||||
'per_frame_loglikes' will be set to the empty vector.
|
||||
**/
|
||||
void GetPerFrameAcousticCosts(const Lattice &nbest,
|
||||
Vector<BaseFloat> *per_frame_loglikes);
|
||||
|
||||
/// This function iterates over the states of a topologically sorted lattice and
|
||||
/// counts the time instance corresponding to each state. The times are returned
|
||||
/// in a vector of integers 'times' which is resized to have a size equal to the
|
||||
/// number of states in the lattice. The function also returns the maximum time
|
||||
/// in the lattice (this will equal the number of frames in the file).
|
||||
int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
|
||||
|
||||
/// As LatticeStateTimes, but in the CompactLattice format. Note: must
|
||||
/// be topologically sorted. Returns length of the utterance in frames, which
|
||||
/// might not be the same as the maximum time in the lattice, due to frames
|
||||
/// in the final-prob.
|
||||
int32 CompactLatticeStateTimes(const CompactLattice &clat,
|
||||
std::vector<int32> *times);
|
||||
|
||||
/// This function does the forward-backward over lattices and computes the
|
||||
/// posterior probabilities of the arcs. It returns the total log-probability
|
||||
/// of the lattice. The Posterior quantities contain pairs of (transition-id, weight)
|
||||
/// on each frame.
|
||||
/// If the pointer "acoustic_like_sum" is provided, this value is set to
|
||||
/// the sum over the arcs, of the posterior of the arc times the
|
||||
/// acoustic likelihood [i.e. negated acoustic score] on that link.
|
||||
/// This is used in combination with other quantities to work out
|
||||
/// the objective function in MMI discriminative training.
|
||||
BaseFloat LatticeForwardBackward(const Lattice &lat,
|
||||
Posterior *arc_post,
|
||||
double *acoustic_like_sum = NULL);
|
||||
|
||||
// This function is something similar to LatticeForwardBackward(), but it is on
|
||||
// the CompactLattice lattice format. Also we only need the alpha in the forward
|
||||
// path, not the posteriors.
|
||||
bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
|
||||
std::vector<double> *alpha);
|
||||
|
||||
// A sibling of the function CompactLatticeAlphas()... We compute the beta from
|
||||
// the backward path here.
|
||||
bool ComputeCompactLatticeBetas(const CompactLattice &lat,
|
||||
std::vector<double> *beta);
|
||||
|
||||
|
||||
// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
|
||||
// best-path negated cost) Note: in either case, the alphas and betas are
|
||||
// negated costs. Requires that lat be topologically sorted. This code
|
||||
// will work for either CompactLattice or Lattice.
|
||||
template<typename LatticeType>
|
||||
double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
|
||||
bool viterbi,
|
||||
std::vector<double> *alpha,
|
||||
std::vector<double> *beta);
|
||||
|
||||
|
||||
/// Topologically sort the compact lattice if not already topologically sorted.
|
||||
/// Will crash if the lattice cannot be topologically sorted.
|
||||
void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
|
||||
|
||||
|
||||
/// Topologically sort the lattice if not already topologically sorted.
|
||||
/// Will crash if lattice cannot be topologically sorted.
|
||||
void TopSortLatticeIfNeeded(Lattice *clat);
|
||||
|
||||
/// Returns the depth of the lattice, defined as the average number of arcs (or
|
||||
/// final-prob strings) crossing any given frame. Returns 1 for empty lattices.
|
||||
/// Requires that clat is topologically sorted!
|
||||
BaseFloat CompactLatticeDepth(const CompactLattice &clat,
|
||||
int32 *num_frames = NULL);
|
||||
|
||||
/// This function returns, for each frame, the number of arcs crossing that
|
||||
/// frame.
|
||||
void CompactLatticeDepthPerFrame(const CompactLattice &clat,
|
||||
std::vector<int32> *depth_per_frame);
|
||||
|
||||
|
||||
/// This function limits the depth of the lattice, per frame: that means, it
|
||||
/// does not allow more than a specified number of arcs active on any given
|
||||
/// frame. This can be used to reduce the size of the "very deep" portions of
|
||||
/// the lattice.
|
||||
void CompactLatticeLimitDepth(int32 max_arcs_per_frame,
|
||||
CompactLattice *clat);
|
||||
|
||||
|
||||
/// Given a lattice, and a transition model to map pdf-ids to phones,
|
||||
/// outputs for each frame the set of phones active on that frame. If
|
||||
/// sil_phones (which must be sorted and uniq) is nonempty, it excludes
|
||||
/// phones in this list.
|
||||
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans,
|
||||
const std::vector<int32> &sil_phones,
|
||||
std::vector<std::set<int32> > *active_phones);
|
||||
|
||||
/// Given a lattice, and a transition model to map pdf-ids to phones,
|
||||
/// replace the output symbols (presumably words), with phones; we
|
||||
/// use the TransitionModel to work out the phone sequence. Note
|
||||
/// that the phone labels are not exactly aligned with the phone
|
||||
/// boundaries. We put a phone label to coincide with any transition
|
||||
/// to the final, nonemitting state of a phone (this state always exists,
|
||||
/// we ensure this in HmmTopology::Check()). This would be the last
|
||||
/// transition-id in the phone if reordering is not done (but typically
|
||||
/// we do reorder).
|
||||
/// Also see PhoneAlignLattice, in phone-align-lattice.h.
|
||||
void ConvertLatticeToPhones(const TransitionInformation &trans_model,
|
||||
Lattice *lat);
|
||||
|
||||
/// Prunes a lattice or compact lattice. Returns true on success, false if
|
||||
/// there was some kind of failure.
|
||||
template<class LatticeType>
|
||||
bool PruneLattice(BaseFloat beam, LatticeType *lat);
|
||||
|
||||
|
||||
/// Given a lattice, and a transition model to map pdf-ids to phones,
|
||||
/// replace the sequences of transition-ids with sequences of phones.
|
||||
/// Note that this is different from ConvertLatticeToPhones, in that
|
||||
/// we replace the transition-ids not the words.
|
||||
void ConvertCompactLatticeToPhones(const TransitionInformation &trans_model,
|
||||
CompactLattice *clat);
|
||||
|
||||
/// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds
|
||||
/// -b*[number of frame errors] to the graph-component of the cost of each arc/path.
|
||||
/// There is a frame error if a particular transition-id on a particular frame
|
||||
/// corresponds to a phone not matching transcription's alignment for that frame.
|
||||
/// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
|
||||
/// The TransitionInformation is used to map transition-ids in the lattice
|
||||
/// input-side to phones; the phones appearing in
|
||||
/// "silence_phones" are treated specially in that we replace the frame error f
|
||||
/// (either zero or 1) for a frame, with the minimum of f or max_silence_error.
|
||||
/// For the normal recipe, max_silence_error would be zero.
|
||||
/// Returns true on success, false if there was some kind of mismatch.
|
||||
/// At input, silence_phones must be sorted and unique.
|
||||
bool LatticeBoost(const TransitionInformation &trans,
|
||||
const std::vector<int32> &alignment,
|
||||
const std::vector<int32> &silence_phones,
|
||||
BaseFloat b,
|
||||
BaseFloat max_silence_error,
|
||||
Lattice *lat);
|
||||
|
||||
|
||||
/**
|
||||
This function implements either the MPFE (minimum phone frame error) or SMBR
|
||||
(state-level minimum bayes risk) forward-backward, depending on whether
|
||||
"criterion" is "mpfe" or "smbr". It returns the MPFE
|
||||
criterion of SMBR criterion for this utterance, and outputs the posteriors (which
|
||||
may be positive or negative) into "post".
|
||||
|
||||
@param [in] trans The transition model. Used to map the
|
||||
transition-ids to phones or pdfs.
|
||||
@param [in] silence_phones A list of integer ids of silence phones. The
|
||||
silence frames i.e. the frames where num_ali
|
||||
corresponds to a silence phones are treated specially.
|
||||
The behavior is determined by 'one_silence_class'
|
||||
being false (traditional behavior) or true.
|
||||
Usually in our setup, several phones including
|
||||
the silence, vocalized noise, non-spoken noise
|
||||
and unk are treated as "silence phones"
|
||||
@param [in] lat The denominator lattice
|
||||
@param [in] num_ali The numerator alignment
|
||||
@param [in] criterion The objective function. Must be "mpfe" or "smbr"
|
||||
for MPFE (minimum phone frame error) or sMBR
|
||||
(state minimum bayes risk) training.
|
||||
@param [in] one_silence_class Determines how the silence frames are treated.
|
||||
Setting this to false gives the old traditional behavior,
|
||||
where the silence frames (according to num_ali) are
|
||||
treated as incorrect. However, this means that the
|
||||
insertions are not penalized by the objective.
|
||||
Setting this to true gives the new behaviour, where we
|
||||
treat silence as any other phone, except that all pdfs
|
||||
of silence phones are collapsed into a single class for
|
||||
the frame-error computation. This can possible reduce
|
||||
the insertions in the trained model. This is closer to
|
||||
the WER metric that we actually care about, since WER is
|
||||
generally computed after filtering out noises, but
|
||||
does penalize insertions.
|
||||
@param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the
|
||||
pseudo log-likelihoods of states at each frame.
|
||||
*/
|
||||
BaseFloat LatticeForwardBackwardMpeVariants(
|
||||
const TransitionInformation &trans,
|
||||
const std::vector<int32> &silence_phones,
|
||||
const Lattice &lat,
|
||||
const std::vector<int32> &num_ali,
|
||||
std::string criterion,
|
||||
bool one_silence_class,
|
||||
Posterior *post);
|
||||
|
||||
/// This function takes a CompactLattice that should only contain a single
|
||||
/// linear sequence (e.g. derived from lattice-1best), and that should have been
|
||||
/// processed so that the arcs in the CompactLattice align correctly with the
|
||||
/// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the
|
||||
/// same size, which give, for each word in the lattice (in sequence), the word
|
||||
/// label and the begin time and length in frames. This is done even for zero
|
||||
/// (epsilon) words, generally corresponding to optional silence-- if you don't
|
||||
/// want them, just ignore them in the output.
|
||||
/// This function will print a warning and return false, if the lattice
|
||||
/// did not have the correct format (e.g. if it is empty or it is not
|
||||
/// linear).
|
||||
bool CompactLatticeToWordAlignment(const CompactLattice &clat,
|
||||
std::vector<int32> *words,
|
||||
std::vector<int32> *begin_times,
|
||||
std::vector<int32> *lengths);
|
||||
|
||||
/// A form of the shortest-path/best-path algorithm that's specially coded for
|
||||
/// CompactLattice. Requires that clat be acyclic.
|
||||
void CompactLatticeShortestPath(const CompactLattice &clat,
|
||||
CompactLattice *shortest_path);
|
||||
|
||||
/// This function expands a CompactLattice to ensure high-probability paths
|
||||
/// have unique histories. Arcs with posteriors larger than epsilon get splitted.
|
||||
void ExpandCompactLattice(const CompactLattice &clat,
|
||||
double epsilon,
|
||||
CompactLattice *expand_clat);
|
||||
|
||||
/// For each state, compute forward and backward best (viterbi) costs and its
|
||||
/// traceback states (for generating best paths later). The forward best cost
|
||||
/// for a state is the cost of the best path from the start state to the state.
|
||||
/// The traceback state of this state is its predecessor state in the best path.
|
||||
/// The backward best cost for a state is the cost of the best path from the
|
||||
/// state to a final one. Its traceback state is the successor state in the best
|
||||
/// path in the forward direction.
|
||||
/// Note: final weights of states are in backward_best_cost_and_pred.
|
||||
/// Requires the input CompactLattice clat be acyclic.
|
||||
typedef std::vector<std::pair<double,
|
||||
CompactLatticeArc::StateId> > CostTraceType;
|
||||
void CompactLatticeBestCostsAndTracebacks(
|
||||
const CompactLattice &clat,
|
||||
CostTraceType *forward_best_cost_and_pred,
|
||||
CostTraceType *backward_best_cost_and_pred);
|
||||
|
||||
/// This function adds estimated neural language model scores of words in a
|
||||
/// minimal list of hypotheses that covers a lattice, to the graph scores on the
|
||||
/// arcs. The list of hypotheses are generated by latbin/lattice-path-cover.
|
||||
typedef unordered_map<std::pair<int32, int32>, double, PairHasher<int32> > MapT;
|
||||
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
|
||||
CompactLattice *clat);
|
||||
|
||||
/// This function add the word insertion penalty to graph score of each word
|
||||
/// in the compact lattice
|
||||
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
|
||||
CompactLattice *clat);
|
||||
|
||||
/// This function *adds* the negated scores obtained from the Decodable object,
|
||||
/// to the acoustic scores on the arcs. If you want to replace them, you should
|
||||
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
|
||||
/// true on success, false on error (typically some kind of mismatched inputs).
|
||||
bool RescoreCompactLattice(DecodableInterface *decodable,
|
||||
CompactLattice *clat);
|
||||
|
||||
|
||||
/// This function returns the number of words in the longest sentence in a
|
||||
/// CompactLattice (i.e. the the maximum of any path, of the count of
|
||||
/// olabels on that path).
|
||||
int32 LongestSentenceLength(const Lattice &lat);
|
||||
|
||||
/// This function returns the number of words in the longest sentence in a
|
||||
/// CompactLattice, i.e. the the maximum of any path, of the count of
|
||||
/// labels on that path... note, in CompactLattice, the ilabels and olabels
|
||||
/// are identical because it is an acceptor.
|
||||
int32 LongestSentenceLength(const CompactLattice &lat);
|
||||
|
||||
|
||||
/// This function is like RescoreCompactLattice, but it is modified to avoid
|
||||
/// computing probabilities on most frames where all the pdf-ids are the same.
|
||||
/// (it needs the transition-model to work out whether two transition-ids map to
|
||||
/// the same pdf-id, and it assumes that the lattice has transition-ids on it).
|
||||
/// The naive thing would be to just set all probabilities to zero on frames
|
||||
/// where all the pdf-ids are the same (because this value won't affect the
|
||||
/// lattice posterior). But this would become confusing when we compute
|
||||
/// corpus-level diagnostics such as the MMI objective function. Instead,
|
||||
/// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 /
|
||||
/// speedup_factor) we compute those likelihoods and multiply them by
|
||||
/// speedup_factor; otherwise we set them to zero. This gives the right
|
||||
/// expected probability so our corpus-level diagnostics will be about right.
|
||||
bool RescoreCompactLatticeSpeedup(
|
||||
const TransitionInformation &tmodel,
|
||||
BaseFloat speedup_factor,
|
||||
DecodableInterface *decodable,
|
||||
CompactLattice *clat);
|
||||
|
||||
|
||||
/// This function *adds* the negated scores obtained from the Decodable object,
|
||||
/// to the acoustic scores on the arcs. If you want to replace them, you should
|
||||
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
|
||||
/// true on success, false on error (e.g. some kind of mismatched inputs).
|
||||
/// The input labels, if nonzero, are interpreted as transition-ids or whatever
|
||||
/// other index the Decodable object expects.
|
||||
bool RescoreLattice(DecodableInterface *decodable,
|
||||
Lattice *lat);
|
||||
|
||||
/// This function Composes a CompactLattice format lattice with a
|
||||
/// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another
|
||||
/// CompactLattice format lattice. The first element (the one that corresponds
|
||||
/// to LM weight) in CompactLatticeWeight is used for composition.
|
||||
///
|
||||
/// Note that the DeterministicOnDemandFst interface is not "const", therefore
|
||||
/// we cannot use "const" for <det_fst>.
|
||||
void ComposeCompactLatticeDeterministic(
|
||||
const CompactLattice& clat,
|
||||
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
|
||||
CompactLattice* composed_clat);
|
||||
|
||||
/// This function computes the mapping from the pair
|
||||
/// (frame-index, transition-id) to the pair
|
||||
/// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the
|
||||
/// transition-id in that frame.
|
||||
/// frame-index in the lattice.
|
||||
/// This function is useful for retaining the acoustic scores in a
|
||||
/// non-compact lattice after a process like determinization where the
|
||||
/// frame-level acoustic scores are typically lost.
|
||||
/// The function ReplaceAcousticScoresFromMap is used to restore the
|
||||
/// acoustic scores computed by this function.
|
||||
///
|
||||
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
|
||||
/// function will crash.
|
||||
/// @param [out] acoustic_scores
|
||||
/// Pointer to a map from the pair (frame-index,
|
||||
/// transition-id) to a pair (sum-of-acoustic-scores,
|
||||
/// num-of-occurences).
|
||||
/// Usually the acoustic scores for a pdf-id (and hence
|
||||
/// transition-id) on a frame will be the same for all the
|
||||
/// occurences of the pdf-id in that frame.
|
||||
/// But if not, we will take the average of the acoustic
|
||||
/// scores. Hence, we store both the sum-of-acoustic-scores
|
||||
/// and the num-of-occurences of the transition-id in that
|
||||
/// frame.
|
||||
void ComputeAcousticScoresMap(
|
||||
const Lattice &lat,
|
||||
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
|
||||
PairHasher<int32> > *acoustic_scores);
|
||||
|
||||
/// This function restores acoustic scores computed using the function
|
||||
/// ComputeAcousticScoresMap into the lattice.
|
||||
///
|
||||
/// @param [in] acoustic_scores
|
||||
/// A map from the pair (frame-index, transition-id) to a
|
||||
/// pair (sum-of-acoustic-scores, num-of-occurences) of
|
||||
/// the occurences of the transition-id in that frame.
|
||||
/// See the comments for ComputeAcousticScoresMap for
|
||||
/// details.
|
||||
/// @param [out] lat Pointer to the output lattice.
|
||||
void ReplaceAcousticScoresFromMap(
|
||||
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
|
||||
PairHasher<int32> > &acoustic_scores,
|
||||
Lattice *lat);
|
||||
|
||||
} // namespace kaldi
|
||||
|
||||
#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_
|
@ -0,0 +1,2 @@
|
||||
aux_source_directory(. DIR_LIB_SRCS)
|
||||
add_library(nnet STATIC ${DIR_LIB_SRCS})
|
@ -0,0 +1,124 @@
|
||||
// itf/decodable-itf.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
|
||||
// Mirko Hannemann; Go Vivace Inc.;
|
||||
// 2013 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_ITF_DECODABLE_ITF_H_
|
||||
#define KALDI_ITF_DECODABLE_ITF_H_ 1
|
||||
#include "base/kaldi-common.h"
|
||||
|
||||
namespace kaldi {
|
||||
/// @ingroup Interfaces
|
||||
/// @{
|
||||
|
||||
|
||||
/**
|
||||
DecodableInterface provides a link between the (acoustic-modeling and
|
||||
feature-processing) code and the decoder. The idea is to make this
|
||||
interface as small as possible, and to make it as agnostic as possible about
|
||||
the form of the acoustic model (e.g. don't assume the probabilities are a
|
||||
function of just a vector of floats), and about the decoder (e.g. don't
|
||||
assume it accesses frames in strict left-to-right order). For normal
|
||||
models, without on-line operation, the "decodable" sub-class will just be a
|
||||
wrapper around a matrix of features and an acoustic model, and it will
|
||||
answer the question 'what is the acoustic likelihood for this index and this
|
||||
frame?'.
|
||||
|
||||
For online decoding, where the features are coming in in real time, it is
|
||||
important to understand the IsLastFrame() and NumFramesReady() functions.
|
||||
There are two ways these are used: the old online-decoding code, in ../online/,
|
||||
and the new online-decoding code, in ../online2/. In the old online-decoding
|
||||
code, the decoder would do:
|
||||
\code{.cc}
|
||||
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
|
||||
// Process this frame
|
||||
}
|
||||
\endcode
|
||||
and the call to IsLastFrame would block if the features had not arrived yet.
|
||||
The decodable object would have to know when to terminate the decoding. This
|
||||
online-decoding mode is still supported, it is what happens when you call, for
|
||||
example, LatticeFasterDecoder::Decode().
|
||||
|
||||
We realized that this "blocking" mode of decoding is not very convenient
|
||||
because it forces the program to be multi-threaded and makes it complex to
|
||||
control endpointing. In the "new" decoding code, you don't call (for example)
|
||||
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
|
||||
and then each time you get more features, you provide them to the decodable
|
||||
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
|
||||
something like this:
|
||||
\code{.cc}
|
||||
while (num_frames_decoded_ < decodable.NumFramesReady()) {
|
||||
// Decode one more frame [increments num_frames_decoded_]
|
||||
}
|
||||
\endcode
|
||||
So the decodable object never has IsLastFrame() called. For decoding where
|
||||
you are starting with a matrix of features, the NumFramesReady() function will
|
||||
always just return the number of frames in the file, and IsLastFrame() will
|
||||
return true for the last frame.
|
||||
|
||||
For truly online decoding, the "old" online decodable objects in ../online/
|
||||
have a "blocking" IsLastFrame() and will crash if you call NumFramesReady().
|
||||
The "new" online decodable objects in ../online2/ return the number of frames
|
||||
currently accessible if you call NumFramesReady(). You will likely not need
|
||||
to call IsLastFrame(), but we implement it to only return true for the last
|
||||
frame of the file once we've decided to terminate decoding.
|
||||
*/
|
||||
class DecodableInterface {
|
||||
public:
|
||||
/// Returns the log likelihood, which will be negated in the decoder.
|
||||
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
|
||||
/// before calling this.
|
||||
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
|
||||
|
||||
/// Returns true if this is the last frame. Frames are zero-based, so the
|
||||
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
|
||||
/// is empty (which is a case that I'm not sure all the code will handle, so
|
||||
/// be careful). Caution: the behavior of this function in an online setting
|
||||
/// is being changed somewhat. In future it may return false in cases where
|
||||
/// we haven't yet decided to terminate decoding, but later true if we decide
|
||||
/// to terminate decoding. The plan in future is to rely more on
|
||||
/// NumFramesReady(), and in future, IsLastFrame() would always return false
|
||||
/// in an online-decoding setting, and would only return true in a
|
||||
/// decoding-from-matrix setting where we want to allow the last delta or LDA
|
||||
/// features to be flushed out for compatibility with the baseline setup.
|
||||
virtual bool IsLastFrame(int32 frame) const = 0;
|
||||
|
||||
/// The call NumFramesReady() will return the number of frames currently available
|
||||
/// for this decodable object. This is for use in setups where you don't want the
|
||||
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
|
||||
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
|
||||
/// know when to stop decoding.
|
||||
virtual int32 NumFramesReady() const {
|
||||
KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Returns the number of states in the acoustic model
|
||||
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
|
||||
/// this is for compatibility with OpenFst).
|
||||
virtual int32 NumIndices() const = 0;
|
||||
|
||||
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
|
||||
|
||||
virtual ~DecodableInterface() {}
|
||||
};
|
||||
/// @}
|
||||
} // namespace Kaldi
|
||||
|
||||
#endif // KALDI_ITF_DECODABLE_ITF_H_
|
@ -0,0 +1,46 @@
|
||||
#include "nnet/decodable.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
|
||||
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet):
|
||||
frontend_(NULL),
|
||||
nnet_(nnet),
|
||||
finished_(false),
|
||||
frames_ready_(0) {
|
||||
}
|
||||
|
||||
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
|
||||
frames_ready_ += likelihood.NumRows();
|
||||
}
|
||||
|
||||
//Decodable::Init(DecodableConfig config) {
|
||||
//}
|
||||
|
||||
bool Decodable::IsLastFrame(int32 frame) const {
|
||||
CHECK_LE(frame, frames_ready_);
|
||||
return finished_ && (frame == frames_ready_ - 1);
|
||||
}
|
||||
|
||||
int32 Decodable::NumIndices() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
|
||||
nnet_->FeedForward(features, &nnet_cache_);
|
||||
frames_ready_ += nnet_cache_.NumRows();
|
||||
return ;
|
||||
}
|
||||
|
||||
void Decodable::Reset() {
|
||||
// frontend_.Reset();
|
||||
nnet_->Reset();
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,31 @@
|
||||
#include "nnet/decodable-itf.h"
|
||||
#include "base/common.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
#include "nnet/nnet_interface.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct DecodableOpts;
|
||||
|
||||
class Decodable : public kaldi::DecodableInterface {
|
||||
public:
|
||||
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
|
||||
//void Init(DecodableOpts config);
|
||||
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
|
||||
virtual bool IsLastFrame(int32 frame) const;
|
||||
virtual int32 NumIndices() const;
|
||||
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
|
||||
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
|
||||
std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
|
||||
void Reset();
|
||||
void InputFinished() { finished_ = true; }
|
||||
private:
|
||||
std::shared_ptr<FeatureExtractorInterface> frontend_;
|
||||
std::shared_ptr<NnetInterface> nnet_;
|
||||
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
|
||||
bool finished_;
|
||||
int32 frames_ready_;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,19 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/basic_types.h"
|
||||
#include "kaldi/base/kaldi-types.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class NnetInterface {
|
||||
public:
|
||||
virtual ~NnetInterface() {}
|
||||
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
|
||||
kaldi::Matrix<kaldi::BaseFloat>* inferences);
|
||||
virtual void Reset();
|
||||
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,182 @@
|
||||
#include "nnet/paddle_nnet.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using std::vector;
|
||||
using std::string;
|
||||
using std::shared_ptr;
|
||||
using kaldi::Matrix;
|
||||
|
||||
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
|
||||
std::vector<std::string> cache_names;
|
||||
cache_names = absl::StrSplit(opts.cache_names, ", ");
|
||||
std::vector<std::string> cache_shapes;
|
||||
cache_shapes = absl::StrSplit(opts.cache_shape, ", ");
|
||||
assert(cache_shapes.size() == cache_names.size());
|
||||
|
||||
for (size_t i = 0; i < cache_shapes.size(); i++) {
|
||||
std::vector<std::string> tmp_shape;
|
||||
tmp_shape = absl::StrSplit(cache_shapes[i], "- ");
|
||||
std::vector<int> cur_shape;
|
||||
std::transform(tmp_shape.begin(), tmp_shape.end(),
|
||||
std::back_inserter(cur_shape),
|
||||
[](const std::string& s) {
|
||||
return atoi(s.c_str());
|
||||
});
|
||||
cache_names_idx_[cache_names[i]] = i;
|
||||
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape);
|
||||
cache_encouts_.push_back(cache_eout);
|
||||
}
|
||||
}
|
||||
|
||||
PaddleNnet::PaddleNnet(const ModelOptions& opts) {
|
||||
paddle_infer::Config config;
|
||||
config.SetModel(opts.model_path, opts.params_path);
|
||||
if (opts.use_gpu) {
|
||||
config.EnableUseGpu(500, 0);
|
||||
}
|
||||
config.SwitchIrOptim(opts.switch_ir_optim);
|
||||
if (opts.enable_fc_padding) {
|
||||
config.DisableFCPadding();
|
||||
}
|
||||
if (opts.enable_profile) {
|
||||
config.EnableProfile();
|
||||
}
|
||||
pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num));
|
||||
if (pool == nullptr) {
|
||||
LOG(ERROR) << "create the predictor pool failed";
|
||||
}
|
||||
pool_usages.resize(opts.thread_num);
|
||||
std::fill(pool_usages.begin(), pool_usages.end(), false);
|
||||
LOG(INFO) << "load paddle model success";
|
||||
|
||||
LOG(INFO) << "start to check the predictor input and output names";
|
||||
LOG(INFO) << "input names: " << opts.input_names;
|
||||
LOG(INFO) << "output names: " << opts.output_names;
|
||||
vector<string> input_names_vec = absl::StrSplit(opts.input_names, ", ");
|
||||
vector<string> output_names_vec = absl::StrSplit(opts.output_names, ", ");
|
||||
paddle_infer::Predictor* predictor = GetPredictor();
|
||||
|
||||
std::vector<std::string> model_input_names = predictor->GetInputNames();
|
||||
assert(input_names_vec.size() == model_input_names.size());
|
||||
for (size_t i = 0; i < model_input_names.size(); i++) {
|
||||
assert(input_names_vec[i] == model_input_names[i]);
|
||||
}
|
||||
|
||||
std::vector<std::string> model_output_names = predictor->GetOutputNames();
|
||||
assert(output_names_vec.size() == model_output_names.size());
|
||||
for (size_t i = 0;i < output_names_vec.size(); i++) {
|
||||
assert(output_names_vec[i] == model_output_names[i]);
|
||||
}
|
||||
ReleasePredictor(predictor);
|
||||
|
||||
InitCacheEncouts(opts);
|
||||
}
|
||||
|
||||
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
|
||||
LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
|
||||
paddle_infer::Predictor* predictor = nullptr;
|
||||
std::lock_guard<std::mutex> guard(pool_mutex);
|
||||
int pred_id = 0;
|
||||
|
||||
while (pred_id < pool_usages.size()) {
|
||||
if (pool_usages[pred_id] == false) {
|
||||
predictor = pool->Retrive(pred_id);
|
||||
break;
|
||||
}
|
||||
++pred_id;
|
||||
}
|
||||
|
||||
if (predictor) {
|
||||
pool_usages[pred_id] = true;
|
||||
predictor_to_thread_id[predictor] = pred_id;
|
||||
LOG(INFO) << pred_id << " predictor create success";
|
||||
} else {
|
||||
LOG(INFO) << "Failed to get predictor from pool !!!";
|
||||
}
|
||||
|
||||
return predictor;
|
||||
}
|
||||
|
||||
int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
|
||||
LOG(INFO) << "attempt to releae a predictor";
|
||||
std::lock_guard<std::mutex> guard(pool_mutex);
|
||||
auto iter = predictor_to_thread_id.find(predictor);
|
||||
|
||||
if (iter == predictor_to_thread_id.end()) {
|
||||
LOG(INFO) << "there is no such predictor";
|
||||
return 0;
|
||||
}
|
||||
|
||||
LOG(INFO) << iter->second << " predictor will be release";
|
||||
pool_usages[iter->second] = false;
|
||||
predictor_to_thread_id.erase(predictor);
|
||||
LOG(INFO) << "release success";
|
||||
return 0;
|
||||
}
|
||||
|
||||
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
|
||||
auto iter = cache_names_idx_.find(name);
|
||||
if (iter == cache_names_idx_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
assert(iter->second < cache_encouts_.size());
|
||||
return cache_encouts_[iter->second];
|
||||
}
|
||||
|
||||
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {
|
||||
|
||||
paddle_infer::Predictor* predictor = GetPredictor();
|
||||
// 1. 得到所有的 input tensor 的名称
|
||||
int row = features.NumRows();
|
||||
int col = features.NumCols();
|
||||
std::vector<std::string> input_names = predictor->GetInputNames();
|
||||
std::vector<std::string> output_names = predictor->GetOutputNames();
|
||||
LOG(INFO) << "feat info: row=" << row << ", col=" << col;
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]);
|
||||
std::vector<int> INPUT_SHAPE = {1, row, col};
|
||||
input_tensor->Reshape(INPUT_SHAPE);
|
||||
input_tensor->CopyFromCpu(features.Data());
|
||||
// 3. 输入每个音频帧数
|
||||
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
|
||||
std::vector<int> input_len_size = {1};
|
||||
input_len->Reshape(input_len_size);
|
||||
std::vector<int64_t> audio_len;
|
||||
audio_len.push_back(row);
|
||||
input_len->CopyFromCpu(audio_len.data());
|
||||
// 输入流式的缓存数据
|
||||
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
|
||||
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
|
||||
h_box->Reshape(h_cache->get_shape());
|
||||
h_box->CopyFromCpu(h_cache->get_data().data());
|
||||
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
|
||||
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
|
||||
c_box->Reshape(c_cache->get_shape());
|
||||
c_box->CopyFromCpu(c_cache->get_data().data());
|
||||
bool success = predictor->Run();
|
||||
|
||||
if (success == false) {
|
||||
LOG(INFO) << "predictor run occurs error";
|
||||
}
|
||||
|
||||
LOG(INFO) << "get the model success";
|
||||
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
|
||||
assert(h_cache->get_shape() == h_out->shape());
|
||||
h_out->CopyToCpu(h_cache->get_data().data());
|
||||
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
|
||||
assert(c_cache->get_shape() == c_out->shape());
|
||||
c_out->CopyToCpu(c_cache->get_data().data());
|
||||
// 5. 得到最后的输出结果
|
||||
std::unique_ptr<paddle_infer::Tensor> output_tensor =
|
||||
predictor->GetOutputHandle(output_names[0]);
|
||||
std::vector<int> output_shape = output_tensor->shape();
|
||||
row = output_shape[1];
|
||||
col = output_shape[2];
|
||||
inferences->Resize(row, col);
|
||||
output_tensor->CopyToCpu(inferences->Data());
|
||||
ReleasePredictor(predictor);
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,108 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "nnet/nnet_interface.h"
|
||||
#include "base/common.h"
|
||||
#include "paddle_inference_api.h"
|
||||
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "kaldi/util/options-itf.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct ModelOptions {
|
||||
std::string model_path;
|
||||
std::string params_path;
|
||||
int thread_num;
|
||||
bool use_gpu;
|
||||
bool switch_ir_optim;
|
||||
std::string input_names;
|
||||
std::string output_names;
|
||||
std::string cache_names;
|
||||
std::string cache_shape;
|
||||
bool enable_fc_padding;
|
||||
bool enable_profile;
|
||||
ModelOptions() :
|
||||
model_path("model/final.zip"),
|
||||
params_path("model/avg_1.jit.pdmodel"),
|
||||
thread_num(2),
|
||||
use_gpu(false),
|
||||
input_names("audio"),
|
||||
output_names("probs"),
|
||||
cache_names("enouts"),
|
||||
cache_shape("1-1-1"),
|
||||
switch_ir_optim(false),
|
||||
enable_fc_padding(false),
|
||||
enable_profile(false) {
|
||||
}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
opts->Register("model-path", &model_path, "model file path");
|
||||
opts->Register("model-params", ¶ms_path, "params model file path");
|
||||
opts->Register("thread-num", &thread_num, "thread num");
|
||||
opts->Register("use-gpu", &use_gpu, "if use gpu");
|
||||
opts->Register("input-names", &input_names, "paddle input names");
|
||||
opts->Register("output-names", &output_names, "paddle output names");
|
||||
opts->Register("cache-names", &cache_names, "cache names");
|
||||
opts->Register("cache-shape", &cache_shape, "cache shape");
|
||||
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option");
|
||||
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option");
|
||||
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class Tensor {
|
||||
public:
|
||||
Tensor() {
|
||||
}
|
||||
Tensor(const std::vector<int>& shape) :
|
||||
_shape(shape) {
|
||||
int data_size = std::accumulate(_shape.begin(), _shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
LOG(INFO) << "data size: " << data_size;
|
||||
_data.resize(data_size, 0);
|
||||
}
|
||||
void reshape(const std::vector<int>& shape) {
|
||||
_shape = shape;
|
||||
int data_size = std::accumulate(_shape.begin(), _shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
_data.resize(data_size, 0);
|
||||
}
|
||||
const std::vector<int>& get_shape() const {
|
||||
return _shape;
|
||||
}
|
||||
std::vector<T>& get_data() {
|
||||
return _data;
|
||||
}
|
||||
private:
|
||||
std::vector<int> _shape;
|
||||
std::vector<T> _data;
|
||||
};
|
||||
|
||||
class PaddleNnet : public NnetInterface {
|
||||
public:
|
||||
PaddleNnet(const ModelOptions& opts);
|
||||
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
|
||||
kaldi::Matrix<kaldi::BaseFloat>* inferences);
|
||||
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
|
||||
void InitCacheEncouts(const ModelOptions& opts);
|
||||
|
||||
private:
|
||||
paddle_infer::Predictor* GetPredictor();
|
||||
int ReleasePredictor(paddle_infer::Predictor* predictor);
|
||||
|
||||
std::unique_ptr<paddle_infer::services::PredictorPool> pool;
|
||||
std::vector<bool> pool_usages;
|
||||
std::mutex pool_mutex;
|
||||
std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
|
||||
std::map<std::string, int> cache_names_idx_;
|
||||
std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;
|
||||
|
||||
public:
|
||||
DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,4 @@
|
||||
|
||||
add_library(utils
|
||||
file_utils.cc
|
||||
)
|
@ -0,0 +1,17 @@
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
bool ReadFileToVector(const std::string& filename,
|
||||
std::vector<std::string>* vocabulary) {
|
||||
std::ifstream file_in(filename);
|
||||
if (!file_in) {
|
||||
std::cerr << "please input a valid file" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
while (std::getline(file_in, line)) {
|
||||
vocabulary->emplace_back(line);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
#include "base/common.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
bool ReadFileToVector(const std::string& filename,
|
||||
std::vector<std::string>* data);
|
||||
|
||||
}
|
Loading…
Reference in new issue