From 88275aff05e918727b562312484d1cad10767194 Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Thu, 27 Jan 2022 13:28:58 +0800 Subject: [PATCH 001/124] add linear spectrogram feature extractor, test=doc --- speechx/speechx/base/common.h | 33 +++++ speechx/speechx/base/flags.h | 17 +++ speechx/speechx/base/log.h | 17 +++ speechx/speechx/base/thread_pool.h | 100 +++++++++++++ speechx/speechx/frontend/fbank.h | 36 +++++ .../frontend/feature_extractor_controller.h | 0 .../feature_extractor_controller_impl.h | 0 .../frontend/feature_extractor_interface.h | 29 ++++ .../speechx/frontend/linear_spectrogram.cc | 139 ++++++++++++++++++ speechx/speechx/frontend/linear_spectrogram.h | 46 ++++++ .../frontend/linear_spectrogram_main.cc | 39 +++++ speechx/speechx/frontend/mfcc.h | 16 ++ speechx/speechx/frontend/window.h | 16 ++ 13 files changed, 488 insertions(+) create mode 100644 speechx/speechx/base/common.h create mode 100644 speechx/speechx/base/flags.h create mode 100644 speechx/speechx/base/log.h create mode 100644 speechx/speechx/base/thread_pool.h create mode 100644 speechx/speechx/frontend/fbank.h create mode 100644 speechx/speechx/frontend/feature_extractor_controller.h create mode 100644 speechx/speechx/frontend/feature_extractor_controller_impl.h create mode 100644 speechx/speechx/frontend/feature_extractor_interface.h create mode 100644 speechx/speechx/frontend/linear_spectrogram.cc create mode 100644 speechx/speechx/frontend/linear_spectrogram.h create mode 100644 speechx/speechx/frontend/linear_spectrogram_main.cc create mode 100644 speechx/speechx/frontend/mfcc.h create mode 100644 speechx/speechx/frontend/window.h diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h new file mode 100644 index 00000000..a16fc55b --- /dev/null +++ b/speechx/speechx/base/common.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "base/log.h" +#include "base/basic_types.h" +#include "base/macros.h" diff --git a/speechx/speechx/base/flags.h b/speechx/speechx/base/flags.h new file mode 100644 index 00000000..41df0d45 --- /dev/null +++ b/speechx/speechx/base/flags.h @@ -0,0 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fst/flags.h" diff --git a/speechx/speechx/base/log.h b/speechx/speechx/base/log.h new file mode 100644 index 00000000..d1b7b169 --- /dev/null +++ b/speechx/speechx/base/log.h @@ -0,0 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "glog/logging.h" diff --git a/speechx/speechx/base/thread_pool.h b/speechx/speechx/base/thread_pool.h new file mode 100644 index 00000000..f6dada90 --- /dev/null +++ b/speechx/speechx/base/thread_pool.h @@ -0,0 +1,100 @@ +// this code is from https://github.com/progschj/ThreadPool + +#ifndef BASE_THREAD_POOL_H +#define BASE_THREAD_POOL_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + ~ThreadPool(); +private: + // need to keep track of threads so we can join them + std::vector< std::thread > workers; + // the task queue + std::queue< std::function > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) + : stop(false) +{ + for(size_t i = 0;i task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, + [this]{ return this->stop || !this->tasks.empty(); }); + if(this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + } + ); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> +{ + using return_type = typename std::result_of::type; + + auto task = std::make_shared< std::packaged_task >( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if(stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks.emplace([task](){ (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() +{ + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for(std::thread &worker: workers) + worker.join(); +} + +#endif \ No newline at end of file diff --git a/speechx/speechx/frontend/fbank.h b/speechx/speechx/frontend/fbank.h new file mode 100644 index 00000000..6956690d --- /dev/null +++ b/speechx/speechx/frontend/fbank.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// wrap the fbank feat of kaldi, todo (SmileGoat) + +#include "kaldi/feat/feature-mfcc.h" + +#incldue "kaldi/matrix/kaldi-vector.h" + +namespace ppspeech { + +class FbankExtractor : FeatureExtractorInterface { + public: + explicit FbankExtractor(const FbankOptions& opts, + share_ptr pre_extractor); + virtual void AcceptWaveform(const kaldi::Vector& input) = 0; + virtual void Read(kaldi::Vector* feat) = 0; + virtual size_t Dim() const = 0; + + private: + bool Compute(const kaldi::Vector& wave, + kaldi::Vector* feat) const; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/feature_extractor_controller.h b/speechx/speechx/frontend/feature_extractor_controller.h new file mode 100644 index 00000000..e69de29b diff --git a/speechx/speechx/frontend/feature_extractor_controller_impl.h b/speechx/speechx/frontend/feature_extractor_controller_impl.h new file mode 100644 index 00000000..e69de29b diff --git a/speechx/speechx/frontend/feature_extractor_interface.h b/speechx/speechx/frontend/feature_extractor_interface.h new file mode 100644 index 00000000..3f3f0ff9 --- /dev/null +++ b/speechx/speechx/frontend/feature_extractor_interface.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/basic_types.h" +#incldue "kaldi/matrix/kaldi-vector.h" + +namespace ppspeech { + +class FeatureExtractorInterface { + public: + virtual void AcceptWaveform(const kaldi::Vector& input) = 0; + virtual void Read(kaldi::Vector* feat) = 0; + virtual size_t Dim() const = 0; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc new file mode 100644 index 00000000..327c3f57 --- /dev/null +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "frontend/linear_spectrogram.h" +#include "kaldi/base/kaldi-math.h" +#include "kaldi/matrix/matrix-functions.h" + +using kaldi::int32; +using kaldi::BaseFloat; +using kaldi::Vector; +using kaldi::Matrix; +using std::vector; + +LinearSpectrogram::LinearSpectrogram(const LinearSpectrogramOptions& opts) { + int32 window_size = opts.frame_opts.WindowSize(); + int32 window_shift = opts.frame_opts.WindowShift(); + fft_points_ = window_size; + hanning_window_.resize(window_size); + + double a = M_2PI / (window_size - 1); + hanning_window_energy_ = 0; + for (int i = 0; i < window_size; ++i) { + hanning_window_[i] = 0.5 - 0.5 * cos(a * i); + hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; + } +} + +void LinearSpectrogram::AcceptWavefrom(const Vector& input) { + wavefrom_.resize(input.Dim()); + for (size_t idx = 0; idx < input.Dim(); ++idx) { + waveform_[idx] = input(idx); + } +} + +void LinearSpectrogram::Hanning(vector* data) const { + CHECK_GE(data->size(), hanning_window_.size()); + + for (size_t i = 0; i < hanning_window_.size(); ++i) { + data->at(i) *= hanning_window_[i]; + } +} + +bool LinearSpectrogram::NumpyFft(vector* v, + vector* real, + vector* img) { + if (RealFft(v, true)) { + LOG(ERROR) << "compute the fft occurs error"; + return false; + } + real->push_back(v->at(0)); + img->push_back(0); + for (int i = 1; i < v->size() / 2; i++) { + real->push_back(v->at(2 * i)); + img->push_back(v->at(2 * i + 1)); + } + real->push_back(v->at(1)); + img->push_back(0); + + return true; +} + +// todo refactor later +bool LinearSpectrogram::ReadFeats(Matrix* feats) const { + vector> feat; + if (wavefrom_.empty()) { + return false; + } + vector> result; + Compute(wavefrom_, result); + feats->Resize(result.size(), result[0].size()); + for (int row_idx = 0; row_idx < result.size(); ++row_idx) { + for (int col_idx = 0; col_idx < result.size(); ++col_idx) { + feats(row_idx, col_idx) = result[row_idx][col_idx]; + } + wavefrom_.clear(); + return true; +} + +// Compute spectrogram feat, return num frames +// todo: refactor later (SmileGoat) +int32 LinearSpectrogram::Compute(const vector& wave, + vector>& feat) { + int num_samples = wave.size(); + const int& frame_length = opts.frame_opts.WindowSize(); + const int& sample_rate = opts.frame_opts.samp_freq; + const int& frame_shift = opts.frame_opts.WindowShift(); + const int& fft_points = fft_points_; + const float scale = hanning_window_energy_ * frame_shift; + + if (num_samples < frame_length) { + return 0; + } + + int num_frames = 1 + ((num_samples - frame_length) / frame_shift); + feat.resize(num_frames); + vector fft_real((fft_points_ / 2 + 1), 0); + vector fft_img((fft_points_ / 2 + 1), 0); + vector v(frame_length, 0); + vector power((fft_points / 2 + 1)); + + for (int i = 0; i < num_frames; ++i) { + vector data(wave.data() + i * frame_shift, + wave.data() + i * frame_shift + frame_length); + Hanning(data); + fft_img.clear(); + fft_real.clear(); + v.assign(data.begin(), data.end()); + if (NumpyFft(&v, fft_real, fft_img)) { + LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data"; + return -1; + } + + feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz + for (int j = 0; j < (fft_points / 2 + 1); ++j) { + power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; + feat[i][j] = power[j]; + + if (j == 0 || j == feat[0].size() - 1) { + feat[i][j] /= scale; + } else { + feat[i][j] *= (2.0 / scale); + } + + // log added eps=1e-14 + feat[i][j] = std::log(feat[i][j] + 1e-14); + } + return 0; +} diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h new file mode 100644 index 00000000..b69050d1 --- /dev/null +++ b/speechx/speechx/frontend/linear_spectrogram.h @@ -0,0 +1,46 @@ + +#pragma once + +#include "frontend/feature_extractor_interface.h" +#include "kaldi/feat/feature-window.h" +#include "base/common.h" + +namespace ppspeech { + +struct LinearSpectrogramOptions { + kaldi::FrameExtrationOptions frame_opts; + LinearSpectrogramOptions(): + frame_opts() {} + + void Register(kaldi::OptionsItf* opts) { + frame_opts.Register(opts); + } +}; + +class LinearSpectrogram : public FeatureExtractorInterface { + public: + explict LinearSpectrogram(const LinearSpectrogramOptions& opts); + virtual void AcceptWavefrom(const kaldi::Vector& input); + virtual void Read(kaldi::Vector* feat); + virtual size_t Dim() const; + void ReadFeats(kaldi::Matrix* feats) const; + + private: + void Hanning(std::vector& data) const; + kaldi::int32 Compute(const std::vector& wave, + std::vector>& feat) const; + bool NumpyFft(std::vector* v, + std::vector* real, + std::vector* img) const; + + kaldi::int32 fft_points_; + size_t dim_; + std::vector hanning_window_; + kaldi::BaseFloat hanning_window_energy_; + LinearSpectrogramOptions opts_; + std::vector wavefrom_; // remove later, todo(SmileGoat) + DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); +}; + + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/linear_spectrogram_main.cc b/speechx/speechx/frontend/linear_spectrogram_main.cc new file mode 100644 index 00000000..455f4f91 --- /dev/null +++ b/speechx/speechx/frontend/linear_spectrogram_main.cc @@ -0,0 +1,39 @@ + +#include "frontend/linear_spectrogram.h" +#include "kaldi/util/table-types.h" +#include "base/log.h" +#include "base/flags.h" +#include "kaldi/feat/wave-reader.h" + +DEFINE_string(wav_rspecifier, "", "test wav path"); +DEFINE_string(feature_wspecifier, "", "test wav ark"); + +int main(int argc, char* argv[]) { + google::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialTableReader wav_reader(FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); + + int32 num_done = 0, num_err = 0; + ppspeech::LinearSpectrogramOptions opt; + ppspeech::LinearSpectrogram linear_spectrogram(opt); + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData &wave_data = wav_reader.Value(); + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), this_channel); + kaldi::Matrix features; + linear_spectrogram.AcceptWaveform(waveform); + linear_spectrogram.ReadFeats(&features); + + feat_writer.Write(utt, features); + if (num_done % 50 == 0 && num_done != 0) + KALDI_VLOG(2) << "Processed " << num_done << " utterances"; + num_done++; + } + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} \ No newline at end of file diff --git a/speechx/speechx/frontend/mfcc.h b/speechx/speechx/frontend/mfcc.h new file mode 100644 index 00000000..aa369655 --- /dev/null +++ b/speechx/speechx/frontend/mfcc.h @@ -0,0 +1,16 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// wrap the mfcc feat of kaldi, todo (SmileGoat) +#include "kaldi/feat/feature-mfcc.h" \ No newline at end of file diff --git a/speechx/speechx/frontend/window.h b/speechx/speechx/frontend/window.h new file mode 100644 index 00000000..5303cad8 --- /dev/null +++ b/speechx/speechx/frontend/window.h @@ -0,0 +1,16 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// extract the window of kaldi feat. + From a01fa866a4b20d9a8dffd672431c0fdd7fd2ceb1 Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Fri, 28 Jan 2022 14:21:26 +0800 Subject: [PATCH 002/124] add normalizer --- .../speechx/frontend/linear_spectrogram.cc | 8 +- speechx/speechx/frontend/linear_spectrogram.h | 2 +- speechx/speechx/frontend/normalizer.cc | 97 +++++++++++++++++++ speechx/speechx/frontend/normalizer.h | 65 +++++++++++++ 4 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 speechx/speechx/frontend/normalizer.cc create mode 100644 speechx/speechx/frontend/normalizer.h diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc index 327c3f57..df8a29f9 100644 --- a/speechx/speechx/frontend/linear_spectrogram.cc +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -89,7 +89,7 @@ bool LinearSpectrogram::ReadFeats(Matrix* feats) const { // Compute spectrogram feat, return num frames // todo: refactor later (SmileGoat) -int32 LinearSpectrogram::Compute(const vector& wave, +bool LinearSpectrogram::Compute(const vector& wave, vector>& feat) { int num_samples = wave.size(); const int& frame_length = opts.frame_opts.WindowSize(); @@ -99,7 +99,7 @@ int32 LinearSpectrogram::Compute(const vector& wave, const float scale = hanning_window_energy_ * frame_shift; if (num_samples < frame_length) { - return 0; + return true; } int num_frames = 1 + ((num_samples - frame_length) / frame_shift); @@ -118,7 +118,7 @@ int32 LinearSpectrogram::Compute(const vector& wave, v.assign(data.begin(), data.end()); if (NumpyFft(&v, fft_real, fft_img)) { LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data"; - return -1; + return false; } feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz @@ -135,5 +135,5 @@ int32 LinearSpectrogram::Compute(const vector& wave, // log added eps=1e-14 feat[i][j] = std::log(feat[i][j] + 1e-14); } - return 0; + return true; } diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h index b69050d1..981f92ea 100644 --- a/speechx/speechx/frontend/linear_spectrogram.h +++ b/speechx/speechx/frontend/linear_spectrogram.h @@ -28,7 +28,7 @@ class LinearSpectrogram : public FeatureExtractorInterface { private: void Hanning(std::vector& data) const; kaldi::int32 Compute(const std::vector& wave, - std::vector>& feat) const; + std::vector>& feat); bool NumpyFft(std::vector* v, std::vector* real, std::vector* img) const; diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc new file mode 100644 index 00000000..9a384484 --- /dev/null +++ b/speechx/speechx/frontend/normalizer.cc @@ -0,0 +1,97 @@ + +#include "frontend/normalizer.h" + +DecibelNormalizer::DecibelNormalizer( + const DecibelNormalizerOptions& opts, + const std::unique_ptr& pre_extractor) { + +} + +void DecibelNormalizer::AcceptWavefrom(const kaldi::Vector& input) { + +} + +void DecibelNormalizer::Read(kaldi::Vector* feat) { + +} + +bool DecibelNormalizer::Compute(const Vector& input, + kaldi::Vector* feat) { + // calculate db rms + float rms_db = 0.0; + float mean_square = 0.0; + float gain = 0.0; + vector smaples; + samples.resize(input.Size()); + for (int32 i = 0; i < samples.size(); ++i) { + samples[i] = input(i); + } + + // square + for (auto &d : samples) { + if (_opts.convert_int_float) { + d = d * WAVE_FLOAT_NORMALIZATION; + } + mean_square += d * d; + } + + // mean + mean_square /= samples.size(); + rms_db = 10 * std::log10(mean_square); + gain = opts.target_db - rms_db; + + if (gain > opts.max_gain_db) { + LOG(ERROR) << "Unable to normalize segment to " << opts.target_db << "dB," + << "because the the probable gain have exceeds opts.max_gain_db" + << opts.max_gain_db << "dB."; + return false; + } + + // Note that this is an in-place transformation. + for (auto &item : samples) { + // python item *= 10.0 ** (gain / 20.0) + item *= std::pow(10.0, gain / 20.0); + } + + return true; +} + + +PPNormalizer::PPNormalizer( + const PPNormalizerOptions& opts, + const std::unique_ptr& pre_extractor) { + +} + +void PPNormalizer::AcceptWavefrom(const kaldi::Vector& input) { + +} + +void PPNormalizer::Read(kaldi::Vector* feat) { + +} + +bool PPNormalizer::Compute(const Vector& input, + kaldi::Vector>* feat) { + if ((input.Dim() % mean_.Dim()) == 0) { + LOG(ERROR) << "CMVN dimension is wrong!"; + return false; + } + + try { + int32 size = mean_.Dim(); + feat->Resize(input.Dim()); + for (int32 row_idx = 0; row_idx < j; ++row_idx) { + int32 base_idx = row_idx * size; + for (int32 idx = 0; idx < mean_.Dim(); ++idx) { + (*feat)(base_idx + idx) = (input(base_dix + idx) - mean_(idx))* variance_(idx); + } + } + + } catch(const std::exception& e) { + std::cerr << e.what() << '\n'; + return false; + } + + return true; +} diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h new file mode 100644 index 00000000..f297403b --- /dev/null +++ b/speechx/speechx/frontend/normalizer.h @@ -0,0 +1,65 @@ + +#pragma once + +#include "frontend/feature_extractor_interface.h" + +namespace ppspeech { + + +struct DecibelNormalizerOptions { + float target_db; + float max_gain_db; + DecibelNormalizerOptions() : + target_db(-20), + max_gain_db(300.0), + convert_int_float(false) {} + + void Register(kaldi::OptionsItf* opts) { + opts->Register("target-db", &target_db, "target db for db normalization"); + opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization"); + opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float"); + } +}; + +class DecibelNormalizer : public FeatureExtractorInterface { + public: + explict DecibelNormalizer(const DecibelNormalizerOptions& opts, + const std::unique_ptr& pre_extractor); + virtual void AcceptWavefrom(const kaldi::Vector& input); + virtual void Read(kaldi::Vector* feat); + virtual size_t Dim() const; + bool Compute(const kaldi::Vector& input, + kaldi::Vector>* feat); + private: +}; + +struct NormalizerOptions { + std::string mean_std_path; + NormalizerOptions() : + mean_std_path("") {} + + void Register(kaldi::OptionsItf* opts) { + opts->Register("mean-std", &mean_std_path, "mean std file"); + } +}; + +// todo refactor later (SmileGoat) +class PPNormalizer : public FeatureExtractorInterface { + public: + explicit PPNormalizer(const NormalizerOptions& opts, + const std::unique_ptr& pre_extractor); + ~PPNormalizer() {} + virtual void AcceptWavefrom(const kaldi::Vector& input); + virtual void Read(kaldi::Vector* feat); + virtual size_t Dim() const; + bool Compute(const kaldi::Vector& input, + kaldi::Vector>& feat); + + private: + bool _initialized; + kaldi::Vector mean_; + kaldi::Vector variance_; + NormalizerOptions _opts; +}; + +} // namespace ppspeech \ No newline at end of file From f03d48f79bf7ae3a2cb8ffb1a2e641c0523ce6e9 Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Thu, 10 Feb 2022 20:19:42 +0800 Subject: [PATCH 003/124] test linear spectrogram feature --- .../frontend/feature_extractor_interface.h | 2 ++ .../speechx/frontend/linear_spectrogram.cc | 36 ++++++++++++++----- speechx/speechx/frontend/linear_spectrogram.h | 10 ++++-- .../frontend/linear_spectrogram_main.cc | 9 ++++- speechx/speechx/frontend/normalizer.cc | 3 +- speechx/speechx/nnet/nnet_interface.h | 16 +++++++++ 6 files changed, 62 insertions(+), 14 deletions(-) create mode 100644 speechx/speechx/nnet/nnet_interface.h diff --git a/speechx/speechx/frontend/feature_extractor_interface.h b/speechx/speechx/frontend/feature_extractor_interface.h index 3f3f0ff9..863c4281 100644 --- a/speechx/speechx/frontend/feature_extractor_interface.h +++ b/speechx/speechx/frontend/feature_extractor_interface.h @@ -23,6 +23,8 @@ class FeatureExtractorInterface { public: virtual void AcceptWaveform(const kaldi::Vector& input) = 0; virtual void Read(kaldi::Vector* feat) = 0; + virtual void Compute(const kaldi::VectorBase& input, + kaldi::VectorBae* feature) = 0; virtual size_t Dim() const = 0; }; diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc index df8a29f9..3545cb53 100644 --- a/speechx/speechx/frontend/linear_spectrogram.cc +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -22,7 +22,10 @@ using kaldi::Vector; using kaldi::Matrix; using std::vector; -LinearSpectrogram::LinearSpectrogram(const LinearSpectrogramOptions& opts) { +LinearSpectrogram::LinearSpectrogram( + const LinearSpectrogramOptions& opts, + const std::unique_ptr base_extractor) { + base_extractor_ = std::move(base_extractor); int32 window_size = opts.frame_opts.WindowSize(); int32 window_shift = opts.frame_opts.WindowShift(); fft_points_ = window_size; @@ -34,6 +37,8 @@ LinearSpectrogram::LinearSpectrogram(const LinearSpectrogramOptions& opts) { hanning_window_[i] = 0.5 - 0.5 * cos(a * i); hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; } + + dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz } void LinearSpectrogram::AcceptWavefrom(const Vector& input) { @@ -70,27 +75,42 @@ bool LinearSpectrogram::NumpyFft(vector* v, return true; } -// todo refactor later +//todo remove later +void CopyVector2StdVector(const kaldi::Vector& input, + vector* output) { +} + +// todo remove later bool LinearSpectrogram::ReadFeats(Matrix* feats) const { - vector> feat; - if (wavefrom_.empty()) { + if (wavefrom_.Dim() == 0) { return false; } + kaldi::Vector feats; + Compute(wavefrom_, &feats); vector> result; - Compute(wavefrom_, result); + vector feats_vec; + CopyVector2StdVector(feats, &feats_vec); + Compute(feats_vec, result); feats->Resize(result.size(), result[0].size()); for (int row_idx = 0; row_idx < result.size(); ++row_idx) { for (int col_idx = 0; col_idx < result.size(); ++col_idx) { feats(row_idx, col_idx) = result[row_idx][col_idx]; } - wavefrom_.clear(); + wavefrom_.Resize(0); return true; } -// Compute spectrogram feat, return num frames +// only for test, remove later +// todo: compute the feature frame by frame. +void LinearSpectrogram::Compute(const kaldi::VectorBase& input, + kaldi::VectorBae* feature) { + base_extractor_->Compute(input, feature); +} + +// Compute spectrogram feat, only for test, remove later // todo: refactor later (SmileGoat) bool LinearSpectrogram::Compute(const vector& wave, - vector>& feat) { + vector>& feat) { int num_samples = wave.size(); const int& frame_length = opts.frame_opts.WindowSize(); const int& sample_rate = opts.frame_opts.samp_freq; diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h index 981f92ea..16683890 100644 --- a/speechx/speechx/frontend/linear_spectrogram.h +++ b/speechx/speechx/frontend/linear_spectrogram.h @@ -19,16 +19,19 @@ struct LinearSpectrogramOptions { class LinearSpectrogram : public FeatureExtractorInterface { public: - explict LinearSpectrogram(const LinearSpectrogramOptions& opts); + explict LinearSpectrogram(const LinearSpectrogramOptions& opts, + const std::unique_ptr base_extractor); virtual void AcceptWavefrom(const kaldi::Vector& input); virtual void Read(kaldi::Vector* feat); - virtual size_t Dim() const; + virtual size_t Dim() const { return dim_; } void ReadFeats(kaldi::Matrix* feats) const; private: void Hanning(std::vector& data) const; kaldi::int32 Compute(const std::vector& wave, std::vector>& feat); + void Compute(const kaldi::VectorBase& input, + kaldi::VectorBae* feature); bool NumpyFft(std::vector* v, std::vector* real, std::vector* img) const; @@ -38,7 +41,8 @@ class LinearSpectrogram : public FeatureExtractorInterface { std::vector hanning_window_; kaldi::BaseFloat hanning_window_energy_; LinearSpectrogramOptions opts_; - std::vector wavefrom_; // remove later, todo(SmileGoat) + kaldi::Vector wavefrom_; // remove later, todo(SmileGoat) + std::unique_ptr base_extractor_; DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); }; diff --git a/speechx/speechx/frontend/linear_spectrogram_main.cc b/speechx/speechx/frontend/linear_spectrogram_main.cc index 455f4f91..352e7225 100644 --- a/speechx/speechx/frontend/linear_spectrogram_main.cc +++ b/speechx/speechx/frontend/linear_spectrogram_main.cc @@ -1,5 +1,7 @@ +// todo refactor, repalce with gtest #include "frontend/linear_spectrogram.h" +#include "frontend/normalizer.h" #include "kaldi/util/table-types.h" #include "base/log.h" #include "base/flags.h" @@ -15,9 +17,14 @@ int main(int argc, char* argv[]) { kaldi::SequentialTableReader wav_reader(FLAGS_wav_rspecifier); kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); + // test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn int32 num_done = 0, num_err = 0; ppspeech::LinearSpectrogramOptions opt; - ppspeech::LinearSpectrogram linear_spectrogram(opt); + ppspeech::DecibelNormalizerOptions db_norm_opt; + std::unique_ptr base_feature_extractor = + new DecibelNormalizer(db_norm_opt); + ppspeech::LinearSpectrogram linear_spectrogram(opt, base_featrue_extractor); + for (; !wav_reader.Done(); wav_reader.Next()) { std::string utt = wav_reader.Key(); const kaldi::WaveData &wave_data = wav_reader.Value(); diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc index 9a384484..ca27d6ac 100644 --- a/speechx/speechx/frontend/normalizer.cc +++ b/speechx/speechx/frontend/normalizer.cc @@ -2,8 +2,7 @@ #include "frontend/normalizer.h" DecibelNormalizer::DecibelNormalizer( - const DecibelNormalizerOptions& opts, - const std::unique_ptr& pre_extractor) { + const DecibelNormalizerOptions& opts) { } diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h new file mode 100644 index 00000000..e999b8f0 --- /dev/null +++ b/speechx/speechx/nnet/nnet_interface.h @@ -0,0 +1,16 @@ + +#pragma once + +#include "" + +namespace ppspeech { + +class NnetForwardInterface { + public: + virtual ~NnetForwardInterface() {} + virtual void FeedForward(const kaldi::Matrix& features, + kaldi::Vector* inference) const = 0; + +}; + +} // namespace ppspeech \ No newline at end of file From c60277515bb05aa79343c0fecd4ba5737cd5067d Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Sun, 13 Feb 2022 15:17:44 +0800 Subject: [PATCH 004/124] add frontend cmakelist --- docs/source/reference.md | 4 + speechx/CMakeLists.txt | 2 +- speechx/speechx/CMakeLists.txt | 11 +++ speechx/speechx/base/basic_types.h | 4 +- speechx/speechx/base/thread_pool.h | 22 ++++- .../feat_test}/linear_spectrogram_main.cc | 3 +- speechx/speechx/frontend/CMakeLists.txt | 8 ++ .../frontend/feature_extractor_interface.h | 8 +- .../speechx/frontend/linear_spectrogram.cc | 81 +++++++++++-------- speechx/speechx/frontend/linear_spectrogram.h | 24 +++--- speechx/speechx/frontend/normalizer.cc | 80 ++++++++++++------ speechx/speechx/frontend/normalizer.h | 20 +++-- 12 files changed, 179 insertions(+), 88 deletions(-) rename speechx/speechx/{frontend => codelab/feat_test}/linear_spectrogram_main.cc (94%) diff --git a/docs/source/reference.md b/docs/source/reference.md index a8327e92..f1a02d20 100644 --- a/docs/source/reference.md +++ b/docs/source/reference.md @@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks * [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md) - ISC License - Audio feature + +* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING) +- zlib License +- ThreadPool diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index e003136a..1876a4fa 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -65,7 +65,7 @@ FetchContent_Declare( URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc ) FetchContent_MakeAvailable(glog) -include_directories(${glog_BINARY_DIR}) +include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src) # gtest FetchContent_Declare(googletest diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index 71c7eb7c..d05c7034 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -4,11 +4,22 @@ project(speechx LANGUAGES CXX) link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") + include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/kaldi ) add_subdirectory(kaldi) +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/frontend +) +add_subdirectory(frontend) + add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) target_link_libraries(mfcc-test kaldi-mfcc) + +add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc) +target_link_libraries(linear_spectrogram_main frontend kaildi-util kaldi-feat) diff --git a/speechx/speechx/base/basic_types.h b/speechx/speechx/base/basic_types.h index 1966c021..1186efd5 100644 --- a/speechx/speechx/base/basic_types.h +++ b/speechx/speechx/base/basic_types.h @@ -16,7 +16,7 @@ #include "kaldi/base/kaldi-types.h" -#include +#include typedef float BaseFloat; typedef double double64; @@ -35,7 +35,7 @@ typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint32; -if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) +#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) typedef unsigned long uint64; #else typedef unsigned long long uint64; diff --git a/speechx/speechx/base/thread_pool.h b/speechx/speechx/base/thread_pool.h index f6dada90..3405af9d 100644 --- a/speechx/speechx/base/thread_pool.h +++ b/speechx/speechx/base/thread_pool.h @@ -1,3 +1,23 @@ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. + +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. // this code is from https://github.com/progschj/ThreadPool #ifndef BASE_THREAD_POOL_H @@ -97,4 +117,4 @@ inline ThreadPool::~ThreadPool() worker.join(); } -#endif \ No newline at end of file +#endif diff --git a/speechx/speechx/frontend/linear_spectrogram_main.cc b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc similarity index 94% rename from speechx/speechx/frontend/linear_spectrogram_main.cc rename to speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc index 352e7225..de6d42ec 100644 --- a/speechx/speechx/frontend/linear_spectrogram_main.cc +++ b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc @@ -2,6 +2,7 @@ #include "frontend/linear_spectrogram.h" #include "frontend/normalizer.h" +#include "frontend/feature_extractor_interface.h" #include "kaldi/util/table-types.h" #include "base/log.h" #include "base/flags.h" @@ -22,7 +23,7 @@ int main(int argc, char* argv[]) { ppspeech::LinearSpectrogramOptions opt; ppspeech::DecibelNormalizerOptions db_norm_opt; std::unique_ptr base_feature_extractor = - new DecibelNormalizer(db_norm_opt); + new ppspeech::DecibelNormalizer(db_norm_opt); ppspeech::LinearSpectrogram linear_spectrogram(opt, base_featrue_extractor); for (; !wav_reader.Done(); wav_reader.Next()) { diff --git a/speechx/speechx/frontend/CMakeLists.txt b/speechx/speechx/frontend/CMakeLists.txt index e69de29b..48a5267b 100644 --- a/speechx/speechx/frontend/CMakeLists.txt +++ b/speechx/speechx/frontend/CMakeLists.txt @@ -0,0 +1,8 @@ +project(frontend) + +add_library(frontend + normalizer.cc + linear_spectrogram.cc +) + +target_link_libraries(frontend kaldi-matrix) \ No newline at end of file diff --git a/speechx/speechx/frontend/feature_extractor_interface.h b/speechx/speechx/frontend/feature_extractor_interface.h index 863c4281..7395b792 100644 --- a/speechx/speechx/frontend/feature_extractor_interface.h +++ b/speechx/speechx/frontend/feature_extractor_interface.h @@ -15,16 +15,14 @@ #pragma once #include "base/basic_types.h" -#incldue "kaldi/matrix/kaldi-vector.h" +#include "kaldi/matrix/kaldi-vector.h" namespace ppspeech { class FeatureExtractorInterface { public: - virtual void AcceptWaveform(const kaldi::Vector& input) = 0; - virtual void Read(kaldi::Vector* feat) = 0; - virtual void Compute(const kaldi::VectorBase& input, - kaldi::VectorBae* feature) = 0; + virtual void AcceptWaveform(const kaldi::VectorBase& input) = 0; + virtual void Read(kaldi::VectorBase* feat) = 0; virtual size_t Dim() const = 0; }; diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc index 3545cb53..a1d72599 100644 --- a/speechx/speechx/frontend/linear_spectrogram.cc +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -16,15 +16,36 @@ #include "kaldi/base/kaldi-math.h" #include "kaldi/matrix/matrix-functions.h" +namespace ppspeech { + using kaldi::int32; using kaldi::BaseFloat; using kaldi::Vector; using kaldi::Matrix; using std::vector; +//todo remove later +void CopyVector2StdVector(const kaldi::Vector& input, + vector* output) { + if (input.Dim() == 0) return; + output->resize(input.Dim()); + for (size_t idx = 0; idx < input.Dim(); ++idx) { + (*output)[idx] = input(idx); + } +} + +void CopyStdVector2Vector(const vector& input, + Vector* output) { + if (input.empty()) return; + output->Resize(input.size()); + for (size_t idx = 0; idx < input.size(); ++idx) { + (*output)(idx) = input[idx]; + } +} + LinearSpectrogram::LinearSpectrogram( const LinearSpectrogramOptions& opts, - const std::unique_ptr base_extractor) { + std::unique_ptr base_extractor) { base_extractor_ = std::move(base_extractor); int32 window_size = opts.frame_opts.WindowSize(); int32 window_shift = opts.frame_opts.WindowShift(); @@ -41,11 +62,8 @@ LinearSpectrogram::LinearSpectrogram( dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz } -void LinearSpectrogram::AcceptWavefrom(const Vector& input) { - wavefrom_.resize(input.Dim()); - for (size_t idx = 0; idx < input.Dim(); ++idx) { - waveform_[idx] = input(idx); - } +void LinearSpectrogram::AcceptWavefrom(const kaldi::VectorBase& input) { + base_extractor_->AcceptWaveform(input); } void LinearSpectrogram::Hanning(vector* data) const { @@ -58,11 +76,11 @@ void LinearSpectrogram::Hanning(vector* data) const { bool LinearSpectrogram::NumpyFft(vector* v, vector* real, - vector* img) { - if (RealFft(v, true)) { - LOG(ERROR) << "compute the fft occurs error"; - return false; - } + vector* img) const { + Vector v_tmp; + CopyStdVector2Vector(*v, &v_tmp); + RealFft(&v_tmp, true); + CopyVector2StdVector(v_tmp, v); real->push_back(v->at(0)); img->push_back(0); for (int i = 1; i < v->size() / 2; i++) { @@ -75,36 +93,28 @@ bool LinearSpectrogram::NumpyFft(vector* v, return true; } -//todo remove later -void CopyVector2StdVector(const kaldi::Vector& input, - vector* output) { -} - // todo remove later -bool LinearSpectrogram::ReadFeats(Matrix* feats) const { - if (wavefrom_.Dim() == 0) { - return false; - } - kaldi::Vector feats; - Compute(wavefrom_, &feats); +void LinearSpectrogram::ReadFeats(Matrix* feats) { + Vector tmp; + Compute(tmp, &waveform_); vector> result; vector feats_vec; - CopyVector2StdVector(feats, &feats_vec); + CopyVector2StdVector(waveform_, &feats_vec); Compute(feats_vec, result); feats->Resize(result.size(), result[0].size()); for (int row_idx = 0; row_idx < result.size(); ++row_idx) { for (int col_idx = 0; col_idx < result.size(); ++col_idx) { - feats(row_idx, col_idx) = result[row_idx][col_idx]; + (*feats)(row_idx, col_idx) = result[row_idx][col_idx]; + } } - wavefrom_.Resize(0); - return true; + waveform_.Resize(0); } // only for test, remove later // todo: compute the feature frame by frame. -void LinearSpectrogram::Compute(const kaldi::VectorBase& input, - kaldi::VectorBae* feature) { - base_extractor_->Compute(input, feature); +void LinearSpectrogram::Compute(const kaldi::Vector& input, + kaldi::Vector* feature) { + base_extractor_->Read(feature); } // Compute spectrogram feat, only for test, remove later @@ -112,9 +122,9 @@ void LinearSpectrogram::Compute(const kaldi::VectorBase& input bool LinearSpectrogram::Compute(const vector& wave, vector>& feat) { int num_samples = wave.size(); - const int& frame_length = opts.frame_opts.WindowSize(); - const int& sample_rate = opts.frame_opts.samp_freq; - const int& frame_shift = opts.frame_opts.WindowShift(); + const int& frame_length = opts_.frame_opts.WindowSize(); + const int& sample_rate = opts_.frame_opts.samp_freq; + const int& frame_shift = opts_.frame_opts.WindowShift(); const int& fft_points = fft_points_; const float scale = hanning_window_energy_ * frame_shift; @@ -132,11 +142,11 @@ bool LinearSpectrogram::Compute(const vector& wave, for (int i = 0; i < num_frames; ++i) { vector data(wave.data() + i * frame_shift, wave.data() + i * frame_shift + frame_length); - Hanning(data); + Hanning(&data); fft_img.clear(); fft_real.clear(); v.assign(data.begin(), data.end()); - if (NumpyFft(&v, fft_real, fft_img)) { + if (NumpyFft(&v, &fft_real, &fft_img)) { LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data"; return false; } @@ -155,5 +165,8 @@ bool LinearSpectrogram::Compute(const vector& wave, // log added eps=1e-14 feat[i][j] = std::log(feat[i][j] + 1e-14); } + } return true; } + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h index 16683890..3e541f7f 100644 --- a/speechx/speechx/frontend/linear_spectrogram.h +++ b/speechx/speechx/frontend/linear_spectrogram.h @@ -8,7 +8,7 @@ namespace ppspeech { struct LinearSpectrogramOptions { - kaldi::FrameExtrationOptions frame_opts; + kaldi::FrameExtractionOptions frame_opts; LinearSpectrogramOptions(): frame_opts() {} @@ -19,19 +19,19 @@ struct LinearSpectrogramOptions { class LinearSpectrogram : public FeatureExtractorInterface { public: - explict LinearSpectrogram(const LinearSpectrogramOptions& opts, - const std::unique_ptr base_extractor); - virtual void AcceptWavefrom(const kaldi::Vector& input); - virtual void Read(kaldi::Vector* feat); + explicit LinearSpectrogram(const LinearSpectrogramOptions& opts, + std::unique_ptr base_extractor); + virtual void AcceptWavefrom(const kaldi::VectorBase& input); + virtual void Read(kaldi::VectorBase* feat); virtual size_t Dim() const { return dim_; } - void ReadFeats(kaldi::Matrix* feats) const; + void ReadFeats(kaldi::Matrix* feats); private: - void Hanning(std::vector& data) const; - kaldi::int32 Compute(const std::vector& wave, - std::vector>& feat); - void Compute(const kaldi::VectorBase& input, - kaldi::VectorBae* feature); + void Hanning(std::vector* data) const; + bool Compute(const std::vector& wave, + std::vector>& feat); + void Compute(const kaldi::Vector& input, + kaldi::Vector* feature); bool NumpyFft(std::vector* v, std::vector* real, std::vector* img) const; @@ -41,7 +41,7 @@ class LinearSpectrogram : public FeatureExtractorInterface { std::vector hanning_window_; kaldi::BaseFloat hanning_window_energy_; LinearSpectrogramOptions opts_; - kaldi::Vector wavefrom_; // remove later, todo(SmileGoat) + kaldi::Vector waveform_; // remove later, todo(SmileGoat) std::unique_ptr base_extractor_; DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); }; diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc index ca27d6ac..dde4a98a 100644 --- a/speechx/speechx/frontend/normalizer.cc +++ b/speechx/speechx/frontend/normalizer.cc @@ -1,35 +1,62 @@ #include "frontend/normalizer.h" -DecibelNormalizer::DecibelNormalizer( - const DecibelNormalizerOptions& opts) { +namespace ppspeech { +using kaldi::Vector; +using kaldi::BaseFloat; +using std::vector; + +DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) { + opts_ = opts; } -void DecibelNormalizer::AcceptWavefrom(const kaldi::Vector& input) { +void DecibelNormalizer::AcceptWavefrom(const Vector& input) { + waveform_ = input; +} +void DecibelNormalizer::Read(Vector* feat) { + if (waveform_.Dim() == 0) return; + Compute(waveform_, feat); } -void DecibelNormalizer::Read(kaldi::Vector* feat) { +//todo remove later +void CopyVector2StdVector(const kaldi::Vector& input, + vector* output) { + if (input.Dim() == 0) return; + output->resize(input.Dim()); + for (size_t idx = 0; idx < input.Dim(); ++idx) { + (*output)[idx] = input(idx); + } +} +void CopyStdVector2Vector(const vector& input, + Vector* output) { + if (input.empty()) return; + output->Resize(input.size()); + for (size_t idx = 0; idx < input.size(); ++idx) { + (*output)(idx) = input[idx]; + } } -bool DecibelNormalizer::Compute(const Vector& input, - kaldi::Vector* feat) { +bool DecibelNormalizer::Compute(const Vector& input, + Vector* feat) const { // calculate db rms - float rms_db = 0.0; - float mean_square = 0.0; - float gain = 0.0; - vector smaples; - samples.resize(input.Size()); + BaseFloat rms_db = 0.0; + BaseFloat mean_square = 0.0; + BaseFloat gain = 0.0; + BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1)); + + vector samples; + samples.resize(input.Dim()); for (int32 i = 0; i < samples.size(); ++i) { samples[i] = input(i); } // square for (auto &d : samples) { - if (_opts.convert_int_float) { - d = d * WAVE_FLOAT_NORMALIZATION; + if (opts_.convert_int_float) { + d = d * wave_float_normlization; } mean_square += d * d; } @@ -37,12 +64,12 @@ bool DecibelNormalizer::Compute(const Vector& input, // mean mean_square /= samples.size(); rms_db = 10 * std::log10(mean_square); - gain = opts.target_db - rms_db; + gain = opts_.target_db - rms_db; - if (gain > opts.max_gain_db) { - LOG(ERROR) << "Unable to normalize segment to " << opts.target_db << "dB," - << "because the the probable gain have exceeds opts.max_gain_db" - << opts.max_gain_db << "dB."; + if (gain > opts_.max_gain_db) { + LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB," + << "because the the probable gain have exceeds opts_.max_gain_db" + << opts_.max_gain_db << "dB."; return false; } @@ -51,27 +78,28 @@ bool DecibelNormalizer::Compute(const Vector& input, // python item *= 10.0 ** (gain / 20.0) item *= std::pow(10.0, gain / 20.0); } - + + CopyStdVector2Vector(samples, feat); return true; } - +/* PPNormalizer::PPNormalizer( const PPNormalizerOptions& opts, const std::unique_ptr& pre_extractor) { } -void PPNormalizer::AcceptWavefrom(const kaldi::Vector& input) { +void PPNormalizer::AcceptWavefrom(const Vector& input) { } -void PPNormalizer::Read(kaldi::Vector* feat) { +void PPNormalizer::Read(Vector* feat) { } -bool PPNormalizer::Compute(const Vector& input, - kaldi::Vector>* feat) { +bool PPNormalizer::Compute(const Vector& input, + Vector>* feat) { if ((input.Dim() % mean_.Dim()) == 0) { LOG(ERROR) << "CMVN dimension is wrong!"; return false; @@ -93,4 +121,6 @@ bool PPNormalizer::Compute(const Vector& input, } return true; -} +}*/ + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h index f297403b..4e914456 100644 --- a/speechx/speechx/frontend/normalizer.h +++ b/speechx/speechx/frontend/normalizer.h @@ -1,7 +1,9 @@ #pragma once +#include "base/common.h" #include "frontend/feature_extractor_interface.h" +#include "kaldi/util/options-itf.h" namespace ppspeech { @@ -9,6 +11,7 @@ namespace ppspeech { struct DecibelNormalizerOptions { float target_db; float max_gain_db; + bool convert_int_float; DecibelNormalizerOptions() : target_db(-20), max_gain_db(300.0), @@ -23,16 +26,19 @@ struct DecibelNormalizerOptions { class DecibelNormalizer : public FeatureExtractorInterface { public: - explict DecibelNormalizer(const DecibelNormalizerOptions& opts, - const std::unique_ptr& pre_extractor); - virtual void AcceptWavefrom(const kaldi::Vector& input); - virtual void Read(kaldi::Vector* feat); - virtual size_t Dim() const; + explicit DecibelNormalizer(const DecibelNormalizerOptions& opts); + virtual void AcceptWavefrom(const kaldi::VectorBase& input); + virtual void Read(kaldi::VectorBase* feat); + virtual size_t Dim() const { return 0; } bool Compute(const kaldi::Vector& input, - kaldi::Vector>* feat); + kaldi::Vector* feat) const; private: + DecibelNormalizerOptions opts_; + std::unique_ptr base_extractor_; + kaldi::Vector waveform_; }; +/* struct NormalizerOptions { std::string mean_std_path; NormalizerOptions() : @@ -61,5 +67,5 @@ class PPNormalizer : public FeatureExtractorInterface { kaldi::Vector variance_; NormalizerOptions _opts; }; - +*/ } // namespace ppspeech \ No newline at end of file From 42c8d0dd97cfc1437fb48fa4f50833022ff31d57 Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Sun, 13 Feb 2022 17:59:48 +0800 Subject: [PATCH 005/124] fix typo & make build success --- speechx/speechx/CMakeLists.txt | 2 +- speechx/speechx/base/flags.h | 2 +- .../feat_test/linear_spectrogram_main.cc | 8 +++---- .../speechx/frontend/linear_spectrogram.cc | 23 ++++++++++++------- speechx/speechx/frontend/linear_spectrogram.h | 6 ++--- speechx/speechx/frontend/normalizer.cc | 20 +++++++++------- speechx/speechx/frontend/normalizer.h | 7 +++--- 7 files changed, 40 insertions(+), 28 deletions(-) diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index d05c7034..25e7b1e3 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -22,4 +22,4 @@ add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) target_link_libraries(mfcc-test kaldi-mfcc) add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc) -target_link_libraries(linear_spectrogram_main frontend kaildi-util kaldi-feat) +target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog) diff --git a/speechx/speechx/base/flags.h b/speechx/speechx/base/flags.h index 41df0d45..2808fac3 100644 --- a/speechx/speechx/base/flags.h +++ b/speechx/speechx/base/flags.h @@ -14,4 +14,4 @@ #pragma once -#include "fst/flags.h" +#include "gflags/gflags.h" diff --git a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc index de6d42ec..00162abe 100644 --- a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc +++ b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc @@ -12,7 +12,7 @@ DEFINE_string(wav_rspecifier, "", "test wav path"); DEFINE_string(feature_wspecifier, "", "test wav ark"); int main(int argc, char* argv[]) { - google::ParseCommandLineFlags(&argc, &argv, false); + gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); kaldi::SequentialTableReader wav_reader(FLAGS_wav_rspecifier); @@ -22,9 +22,9 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; ppspeech::LinearSpectrogramOptions opt; ppspeech::DecibelNormalizerOptions db_norm_opt; - std::unique_ptr base_feature_extractor = - new ppspeech::DecibelNormalizer(db_norm_opt); - ppspeech::LinearSpectrogram linear_spectrogram(opt, base_featrue_extractor); + std::unique_ptr base_feature_extractor( + new ppspeech::DecibelNormalizer(db_norm_opt)); + ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor)); for (; !wav_reader.Done(); wav_reader.Next()) { std::string utt = wav_reader.Key(); diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc index a1d72599..a23b4494 100644 --- a/speechx/speechx/frontend/linear_spectrogram.cc +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -21,11 +21,12 @@ namespace ppspeech { using kaldi::int32; using kaldi::BaseFloat; using kaldi::Vector; +using kaldi::VectorBase; using kaldi::Matrix; using std::vector; //todo remove later -void CopyVector2StdVector(const kaldi::Vector& input, +void CopyVector2StdVector_(const VectorBase& input, vector* output) { if (input.Dim() == 0) return; output->resize(input.Dim()); @@ -34,7 +35,7 @@ void CopyVector2StdVector(const kaldi::Vector& input, } } -void CopyStdVector2Vector(const vector& input, +void CopyStdVector2Vector_(const vector& input, Vector* output) { if (input.empty()) return; output->Resize(input.size()); @@ -62,7 +63,7 @@ LinearSpectrogram::LinearSpectrogram( dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz } -void LinearSpectrogram::AcceptWavefrom(const kaldi::VectorBase& input) { +void LinearSpectrogram::AcceptWaveform(const VectorBase& input) { base_extractor_->AcceptWaveform(input); } @@ -78,9 +79,9 @@ bool LinearSpectrogram::NumpyFft(vector* v, vector* real, vector* img) const { Vector v_tmp; - CopyStdVector2Vector(*v, &v_tmp); + CopyStdVector2Vector_(*v, &v_tmp); RealFft(&v_tmp, true); - CopyVector2StdVector(v_tmp, v); + CopyVector2StdVector_(v_tmp, v); real->push_back(v->at(0)); img->push_back(0); for (int i = 1; i < v->size() / 2; i++) { @@ -96,10 +97,11 @@ bool LinearSpectrogram::NumpyFft(vector* v, // todo remove later void LinearSpectrogram::ReadFeats(Matrix* feats) { Vector tmp; + waveform_.Resize(base_extractor_->Dim()); Compute(tmp, &waveform_); vector> result; vector feats_vec; - CopyVector2StdVector(waveform_, &feats_vec); + CopyVector2StdVector_(waveform_, &feats_vec); Compute(feats_vec, result); feats->Resize(result.size(), result[0].size()); for (int row_idx = 0; row_idx < result.size(); ++row_idx) { @@ -110,10 +112,15 @@ void LinearSpectrogram::ReadFeats(Matrix* feats) { waveform_.Resize(0); } +void LinearSpectrogram::Read(VectorBase* feat) { + // todo + return; +} + // only for test, remove later // todo: compute the feature frame by frame. -void LinearSpectrogram::Compute(const kaldi::Vector& input, - kaldi::Vector* feature) { +void LinearSpectrogram::Compute(const VectorBase& input, + VectorBase* feature) { base_extractor_->Read(feature); } diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h index 3e541f7f..0923acee 100644 --- a/speechx/speechx/frontend/linear_spectrogram.h +++ b/speechx/speechx/frontend/linear_spectrogram.h @@ -21,7 +21,7 @@ class LinearSpectrogram : public FeatureExtractorInterface { public: explicit LinearSpectrogram(const LinearSpectrogramOptions& opts, std::unique_ptr base_extractor); - virtual void AcceptWavefrom(const kaldi::VectorBase& input); + virtual void AcceptWaveform(const kaldi::VectorBase& input); virtual void Read(kaldi::VectorBase* feat); virtual size_t Dim() const { return dim_; } void ReadFeats(kaldi::Matrix* feats); @@ -30,8 +30,8 @@ class LinearSpectrogram : public FeatureExtractorInterface { void Hanning(std::vector* data) const; bool Compute(const std::vector& wave, std::vector>& feat); - void Compute(const kaldi::Vector& input, - kaldi::Vector* feature); + void Compute(const kaldi::VectorBase& input, + kaldi::VectorBase* feature); bool NumpyFft(std::vector* v, std::vector* real, std::vector* img) const; diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc index dde4a98a..04e88bf4 100644 --- a/speechx/speechx/frontend/normalizer.cc +++ b/speechx/speechx/frontend/normalizer.cc @@ -4,24 +4,28 @@ namespace ppspeech { using kaldi::Vector; +using kaldi::VectorBase; using kaldi::BaseFloat; using std::vector; DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) { opts_ = opts; + dim_ = 0; } -void DecibelNormalizer::AcceptWavefrom(const Vector& input) { - waveform_ = input; +void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase& input) { + dim_ = input.Dim(); + waveform_.Resize(input.Dim()); + waveform_.CopyFromVec(input); } -void DecibelNormalizer::Read(Vector* feat) { +void DecibelNormalizer::Read(kaldi::VectorBase* feat) { if (waveform_.Dim() == 0) return; Compute(waveform_, feat); } //todo remove later -void CopyVector2StdVector(const kaldi::Vector& input, +void CopyVector2StdVector(const kaldi::VectorBase& input, vector* output) { if (input.Dim() == 0) return; output->resize(input.Dim()); @@ -31,16 +35,16 @@ void CopyVector2StdVector(const kaldi::Vector& input, } void CopyStdVector2Vector(const vector& input, - Vector* output) { + VectorBase* output) { if (input.empty()) return; - output->Resize(input.size()); + assert(input.size() == output->Dim()); for (size_t idx = 0; idx < input.size(); ++idx) { (*output)(idx) = input[idx]; } } -bool DecibelNormalizer::Compute(const Vector& input, - Vector* feat) const { +bool DecibelNormalizer::Compute(const VectorBase& input, + VectorBase* feat) const { // calculate db rms BaseFloat rms_db = 0.0; BaseFloat mean_square = 0.0; diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h index 4e914456..3bf36cfc 100644 --- a/speechx/speechx/frontend/normalizer.h +++ b/speechx/speechx/frontend/normalizer.h @@ -27,13 +27,14 @@ struct DecibelNormalizerOptions { class DecibelNormalizer : public FeatureExtractorInterface { public: explicit DecibelNormalizer(const DecibelNormalizerOptions& opts); - virtual void AcceptWavefrom(const kaldi::VectorBase& input); + virtual void AcceptWaveform(const kaldi::VectorBase& input); virtual void Read(kaldi::VectorBase* feat); virtual size_t Dim() const { return 0; } - bool Compute(const kaldi::Vector& input, - kaldi::Vector* feat) const; + bool Compute(const kaldi::VectorBase& input, + kaldi::VectorBase* feat) const; private: DecibelNormalizerOptions opts_; + size_t dim_; std::unique_ptr base_extractor_; kaldi::Vector waveform_; }; From e57efcb314d5d5a8aabb63b528b1e7c24365a237 Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Tue, 15 Feb 2022 19:56:14 +0800 Subject: [PATCH 006/124] add nnet module --- speechx/CMakeLists.txt | 1 + speechx/speechx/nnet/nnet_interface.h | 7 +- speechx/speechx/nnet/paddle_nnet.cc | 179 ++++++++++++++++++++++++++ speechx/speechx/nnet/paddle_nnet.h | 110 ++++++++++++++++ 4 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 speechx/speechx/nnet/paddle_nnet.cc create mode 100644 speechx/speechx/nnet/paddle_nnet.h diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 1876a4fa..ac3c683d 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -39,6 +39,7 @@ FetchContent_Declare( GIT_TAG "20210324.1" ) FetchContent_MakeAvailable(absl) +include_directories(${absl_SOURCE_DIR}/absl) # libsndfile include(FetchContent) diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h index e999b8f0..c32774fc 100644 --- a/speechx/speechx/nnet/nnet_interface.h +++ b/speechx/speechx/nnet/nnet_interface.h @@ -1,15 +1,16 @@ #pragma once -#include "" +#include "base/basic_types.h" +#include "kaldi/base/kaldi-types.h" namespace ppspeech { -class NnetForwardInterface { +class NnetInterface { public: virtual ~NnetForwardInterface() {} virtual void FeedForward(const kaldi::Matrix& features, - kaldi::Vector* inference) const = 0; + kaldi::Matrix* inferences) const = 0; }; diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/paddle_nnet.cc new file mode 100644 index 00000000..d6f82619 --- /dev/null +++ b/speechx/speechx/nnet/paddle_nnet.cc @@ -0,0 +1,179 @@ +#include "nnet/paddle_nnet.h" +#include "absl/strings/str_split.h" + +namespace ppspeech { + +void PaddleNnet::init_cache_encouts(const ModelOptions& opts) { + std::vector cache_names; + cache_names = absl::StrSplit(opts.cache_names, ", "); + std::vector cache_shapes; + cache_shapes = absl::StrSplit(opts.cache_shape, ", "); + assert(cache_shapes.size() == cache_names.size()); + + for (size_t i = 0; i < cache_shapes.size(); i++) { + std::vector tmp_shape; + tmp_shape = absl::StrSplit(cache_shapes[i], "- "); + std::vector cur_shape; + std::transform(tmp_shape.begin(), tmp_shape.end(), + std::back_inserter(cur_shape), + [](const std::string& s) { + return atoi(s.c_str()); + }); + cache_names_idx_[cache_names[i]] = i; + std::shared_ptr> cache_eout = std::make_shared>(cur_shape); + cache_encouts_.push_back(cache_eout); + } +} + +PaddleNet::PaddleNnet(const ModelOptions& opts) { + paddle_infer::Config config; + config.SetModel(opts.model_path, opts.params_path); + if (opts.use_gpu) { + config.EnableUseGpu(500, 0); + } + config.SwitchIrOptim(opts.switch_ir_optim); + if (opts.enbale_fc_padding) { + config.DisableFCPadding(); + } + if (opts.enable_profile) { + config.EnableProfile(); + } + pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num)); + if (pool == nullptr) { + LOG(ERROR) << "create the predictor pool failed"; + } + pool_usages.resize(num_thread); + std::fill(pool_usages.begin(), pool_usages.end(), false); + LOG(INFO) << "load paddle model success"; + + LOG(INFO) << "start to check the predictor input and output names"; + LOG(INFO) << "input names: " << opts.input_names; + LOG(INFO) << "output names: " << opts.output_names; + vector input_names_vec = absl::StrSplit(opts.input_names, ", "); + vector output_names_vec = absl::StrSplit(opts.output_names, ", "); + paddle_infer::Predictor* predictor = get_predictor(); + + std::vector model_input_names = predictor->GetInputNames(); + assert(input_names_vec.size() == model_input_names.size()); + for (size_t i = 0; i < model_input_names.size(); i++) { + assert(input_names_vec[i] == model_input_names[i]); + } + + std::vector model_output_names = predictor->GetOutputNames(); + assert(output_names_vec.size() == model_output_names.size()); + for (size_t i = 0;i < output_names_vec.size(); i++) { + assert(output_names_vec[i] == model_output_names[i]); + } + release_predictor(predictor); + + init_cache_encouts(opts); +} + +paddle_infer::Predictor* PaddleNnet::get_predictor() { + LOG(INFO) << "attempt to get a new predictor instance " << std::endl; + paddle_infer::Predictor* predictor = nullptr; + std::lock_guard guard(pool_mutex); + int pred_id = 0; + + while (pred_id < pool_usages.size()) { + if (pool_usages[pred_id] == false) { + predictor = pool->Retrive(pred_id); + break; + } + ++pred_id; + } + + if (predictor) { + pool_usages[pred_id] = true; + predictor_to_thread_id[predictor] = pred_id; + LOG(INFO) << pred_id << " predictor create success"; + } else { + LOG(INFO) << "Failed to get predictor from pool !!!"; + } + + return predictor; +} + +int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { + LOG(INFO) << "attempt to releae a predictor"; + std::lock_guard guard(pool_mutex); + auto iter = predictor_to_thread_id.find(predictor); + + if (iter == predictor_to_thread_id.end()) { + LOG(INFO) << "there is no such predictor"; + return 0; + } + + LOG(INFO) << iter->second << " predictor will be release"; + pool_usages[iter->second] = false; + predictor_to_thread_id.erase(predictor); + LOG(INFO) << "release success"; + return 0; +} + + + +shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { + auto iter = cache_names_idx_.find(name); + if (iter == cache_names_idx_.end()) { + return nullptr; + } + assert(iter->second < cache_encouts_.size()); + return cache_encouts_[iter->second].get(); +} + +void PaddleNet::FeedForward(const Matrix& features, Matrix* inferences) const { + + // 1. 得到所有的 input tensor 的名称 + int row = features.NumRows(); + int col = features.NumCols(); + std::vector input_names = predictor->GetInputNames(); + std::vector output_names = predictor->GetOutputNames(); + LOG(INFO) << "feat info: row=" << row << ", col=" << col; + + std::unique_ptr input_tensor = predictor->GetInputHandle(input_names[0]); + std::vector INPUT_SHAPE = {1, row, col}; + input_tensor->Reshape(INPUT_SHAPE); + input_tensor->CopyFromCpu(features.Data()); + // 3. 输入每个音频帧数 + std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); + std::vector input_len_size = {1}; + input_len->Reshape(input_len_size); + std::vector audio_len; + audio_len.push_back(row); + input_len->CopyFromCpu(audio_len.data()); + // 输入流式的缓存数据 + std::unique_ptr h_box = predictor->GetInputHandle(input_names[2]); + share_ptr> h_cache = GetCacheEncoder(input_names[2])); + h_box->Reshape(h_cache->get_shape()); + h_box->CopyFromCpu(h_cache->get_data().data()); + std::unique_ptr c_box = predictor->GetInputHandle(input_names[3]); + share_ptr> c_cache = GetCacheEncoder(input_names[3]); + c_box->Reshape(c_cache->get_shape()); + c_box->CopyFromCpu(c_cache->get_data().data()); + std::thread::id this_id = std::this_thread::get_id(); + LOG(INFO) << this_id << " start to compute the probability"; + bool success = predictor->Run(); + + if (success == false) { + LOG(INFO) << "predictor run occurs error"; + } + + LOG(INFO) << "get the model success"; + std::unique_ptr h_out = predictor->GetOutputHandle(output_names[2]); + assert(h_cache->get_shape() == h_out->shape()); + h_out->CopyToCpu(h_cache->get_data().data()); + std::unique_ptr c_out = predictor->GetOutputHandle(output_names[3]); + assert(c_cache->get_shape() == c_out->shape()); + c_out->CopyToCpu(c_cache->get_data().data()); + // 5. 得到最后的输出结果 + std::unique_ptr output_tensor = + predictor->GetOutputHandle(output_names[0]); + std::vector output_shape = output_tensor->shape(); + row = output_shape[1]; + col = output_shape[2]; + inference.Resize(row, col); + output_tensor->CopyToCpu(inference.Data()); +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/paddle_nnet.h new file mode 100644 index 00000000..1b3cad97 --- /dev/null +++ b/speechx/speechx/nnet/paddle_nnet.h @@ -0,0 +1,110 @@ + +#pragma once + +#include "nnet/nnet_interface.h" +#include "base/common.h" +#include "paddle/paddle_inference_api.h" + + +namespace ppspeech { + +struct ModelOptions { + std::string model_path; + std::string params_path; + int thread_num; + bool use_gpu; + bool switch_ir_optim; + std::string input_names; + std::string output_names; + std::string cache_names; + std::string cache_shape; + bool enable_fc_padding; + bool enable_profile; + ModelDecoderOptions() : + model_path("model/final.zip"), + params_path("model/avg_1.jit.pdmodel"), + thread_num(2), + use_gpu(false), + input_names("audio"), + output_names("probs"), + cache_names("enouts"), + cache_shape("1-1-1"), + switch_ir_optim(false), + enable_fc_padding(false), + enable_profile(false) { + } + + void Register(kaldi::OptionsItf* opts) { + opts->Register("model-path", &model_path, "model file path"); + opts->Register("model-params", ¶ms_path, "params model file path"); + opts->Register("thread-num", &thread_num, "thread num"); + opts->Register("use-gpu", &use_gpu, "if use gpu"); + opts->Register("input-names", &input_names, "paddle input names"); + opts->Register("output-names", &output_names, "paddle output names"); + opts->Register("cache-names", &cache_names, "cache names"); + opts->Register("cache-shape", &cache_shape, "cache shape"); + opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option"); + opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option"); + opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option"); + } +}; + + void Register(kaldi::OptionsItf* opts) { + _model_opts.Register(opts); + opts->Register("subsampling-rate", &subsampling_rate, + "subsampling rate for deepspeech model"); + opts->Register("receptive-field-length", &receptive_field_length, + "receptive field length for deepspeech model"); + } +}; + + +template +class Tensor { +public: + Tensor() { + } + Tensor(const std::vector& shape) : + _shape(shape) { + int data_size = std::accumulate(_shape.begin(), _shape.end(), + 1, std::multiplies()); + LOG(INFO) << "data size: " << data_size; + _data.resize(data_size, 0); + } + void reshape(const std::vector& shape) { + _shape = shape; + int data_size = std::accumulate(_shape.begin(), _shape.end(), + 1, std::multiplies()); + _data.resize(data_size, 0); + } + const std::vector& get_shape() const { + return _shape; + } + std::vector& get_data() { + return _data; + } +private: + std::vector _shape; + std::vector _data; +}; + +class PaddleNnet : public NnetInterface { + public: + PaddleNnet(const ModelOptions& opts); + virtual void FeedForward(const kaldi::Matrix& features, + kaldi::Matrix* inferences) const; + std::shared_ptr> GetCacheEncoder(const std::string& name); + void init_cache_encouts(const ModelOptions& opts); + + private: + std::unique_ptr pool; + std::vector pool_usages; + std::mutex pool_mutex; + std::map cache_names_idx_; + std::vector>> cache_encouts_; + + public: + DISALLOW_COPY_AND_ASSIGN(PaddleNnet); +}; + +} // namespace ppspeech From d14ee800656c3eaab2e605a42081158a64d605d0 Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Thu, 17 Feb 2022 12:07:46 +0800 Subject: [PATCH 007/124] add decodable & ctc_beam_search_deocder --- speechx/speechx/decoder/common.h | 7 + .../decoder/ctc_beam_search_decoder.cc | 264 +++ .../speechx/decoder/ctc_beam_search_decoder.h | 74 + speechx/speechx/decoder/ctc_decoders | 1 + .../kaldi/decoder/lattice-faster-decoder.cc | 1020 +++++++++ .../kaldi/decoder/lattice-faster-decoder.h | 549 +++++ .../decoder/lattice-faster-online-decoder.cc | 285 +++ .../decoder/lattice-faster-online-decoder.h | 147 ++ .../lat/determinize-lattice-pruned-test.cc | 147 ++ .../kaldi/lat/determinize-lattice-pruned.cc | 1541 ++++++++++++++ .../kaldi/lat/determinize-lattice-pruned.h | 296 +++ speechx/speechx/kaldi/lat/kaldi-lattice.cc | 506 +++++ speechx/speechx/kaldi/lat/kaldi-lattice.h | 156 ++ .../speechx/kaldi/lat/lattice-functions.cc | 1880 +++++++++++++++++ speechx/speechx/kaldi/lat/lattice-functions.h | 402 ++++ speechx/speechx/nnet/ctc_decodable.h | 0 speechx/speechx/nnet/decodable-itf.h | 122 ++ speechx/speechx/nnet/decodable.h | 18 + speechx/speechx/nnet/dnn_decodable.h | 0 speechx/speechx/nnet/nnet_interface.h | 2 +- 20 files changed, 7416 insertions(+), 1 deletion(-) create mode 100644 speechx/speechx/decoder/common.h create mode 100644 speechx/speechx/decoder/ctc_beam_search_decoder.cc create mode 100644 speechx/speechx/decoder/ctc_beam_search_decoder.h create mode 120000 speechx/speechx/decoder/ctc_decoders create mode 100644 speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc create mode 100644 speechx/speechx/kaldi/decoder/lattice-faster-decoder.h create mode 100644 speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc create mode 100644 speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h create mode 100644 speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc create mode 100644 speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc create mode 100644 speechx/speechx/kaldi/lat/determinize-lattice-pruned.h create mode 100644 speechx/speechx/kaldi/lat/kaldi-lattice.cc create mode 100644 speechx/speechx/kaldi/lat/kaldi-lattice.h create mode 100644 speechx/speechx/kaldi/lat/lattice-functions.cc create mode 100644 speechx/speechx/kaldi/lat/lattice-functions.h create mode 100644 speechx/speechx/nnet/ctc_decodable.h create mode 100644 speechx/speechx/nnet/decodable-itf.h create mode 100644 speechx/speechx/nnet/decodable.h create mode 100644 speechx/speechx/nnet/dnn_decodable.h diff --git a/speechx/speechx/decoder/common.h b/speechx/speechx/decoder/common.h new file mode 100644 index 00000000..4292a871 --- /dev/null +++ b/speechx/speechx/decoder/common.h @@ -0,0 +1,7 @@ +#include "base/basic_types.h" + +struct DecoderResult { + BaseFloat acoustic_score; + std::vector words_idx; + std::vector> time_stamp; +}; diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc new file mode 100644 index 00000000..dc21dcb4 --- /dev/null +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -0,0 +1,264 @@ +#include "decoder/ctc_beam_search_decoder.h" + +#include "base/basic_types.h" +#include "decoder/ctc_decoders/decoder_utils.h" + +namespace ppspeech { + +using std::vector; +using FSTMATCH = fst::SortedMatcher; + +CTCBeamSearch::CTCBeamSearch(std::shared_ptr opts) : + opts_(opts), + vocabulary_(nullptr), + init_ext_scorer_(nullptr), + blank_id(-1), + space_id(-1), + root(nullptr) { + + LOG(INFO) << "dict path: " << _opts.dict_file; + vocabulary_ = std::make_shared>(); + if (!basr::ReadDictToVector(_opts.dict_file, *vocabulary_)) { + LOG(INFO) << "load the dict failed"; + } + LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size(); + + LOG(INFO) << "language model path: " << _opts.lm_path; + init_ext_scorer_ = std::make_shared(_opts.alpha, + _opts.beta, + _opts.lm_path, + *vocabulary_); +} + +void CTCBeamSearch::InitDecoder() { + + blank_id = 0; + auto it = std::find(vocabulary_->begin(), vocabulary_->end(), " "); + + space_id = it - vocabulary_->begin(); + // if no space in vocabulary + if ((size_t)space_id >= vocabulary_->size()) { + space_id = -2; + } + + clear_prefixes(); + + root = std::make_shared(); + root->score = root->log_prob_b_prev = 0.0; + prefixes.push_back(root.get()); + if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { + auto fst_dict = + static_cast(init_ext_scorer_->dictionary); + fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); + root->set_dictionary(dict_ptr); + + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root->set_matcher(matcher); + } +} + +void CTCBeamSearch::ResetPrefixes() { + for (size_t i = 0; i < prefixes.size(); i++) { + if (prefixes[i] != nullptr) { + delete prefixes[i]; + prefixes[i] = nullptr; + } + } +} + +int CTCBeamSearch::DecodeLikelihoods(const vector>&probs, + vector& nbest_words) { + std::thread::id this_id = std::this_thread::get_id(); + Timer timer; + vector> double_probs(probs.size(), vector(probs[0].size(), 0)); + + int row = probs.size(); + int col = probs[0].size(); + for(int i = 0; i < row; i++) { + for (int j = 0; j < col; j++){ + double_probs[i][j] = static_cast(probs[i][j]); + } + } + + timer.Reset(); + vector> results = AdvanceDecoding(double_probs); + LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast(timer.Elapsed()) / 1000.0f; + for (const auto& item : results) { + nbest_words.push_back(item.second); + } + return 0; +} + +vector> CTCBeamSearch::AdvanceDecoding(const vector>& probs_seq) { + size_t num_time_steps = probs_seq.size(); + size_t beam_size = _opts.beam_size; + double cutoff_prob = _opts.cutoff_prob; + size_t cutoff_top_n = _opts.cutoff_top_n; + + for (size_t time_step = 0; time_step < num_time_steps; time_step++) { + const auto& prob = probs_seq[time_step]; + + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (init_ext_scorer_ != nullptr) { + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, + prefix_compare); + + if (num_prefixes == 0) { + continue; + } + min_cutoff = prefixes[num_prefixes - 1]->score + + std::log(prob[blank_id]) - + std::max(0.0, init_ext_scorer_->beta); + + full_beam = (num_prefixes == beam_size); + } + + vector> log_prob_idx = + get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); + + // loop over chars + size_t log_prob_idx_len = log_prob_idx.size(); + for (size_t index = 0; index < log_prob_idx_len; index++) { + SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); + + prefixes.clear(); + + // update log probs + root->iterate_to_vec(prefixes); + // only preserve top beam_size prefixes + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + for (size_t i = beam_size; i < prefixes.size(); ++i) { + prefixes[i]->remove(); + } + } // if + } // for probs_seq + + // score the last word of each prefix that doesn't end with space + LMRescore(); + CalculateApproxScore(); + return get_beam_search_result(prefixes, *vocabulary_, beam_size); +} + +int CTCBeamSearch::SearchOneChar(const bool& full_beam, + const std::pair& log_prob_idx, + const float& min_cutoff) { + size_t beam_size = _opts.beam_size; + const auto& c = log_prob_idx.first; + const auto& log_prob_c = log_prob_idx.second; + size_t prefixes_len = std::min(prefixes.size(), beam_size); + + for (size_t i = 0; i < prefixes_len; ++i) { + auto prefix = prefixes[i]; + if (full_beam && log_prob_c + prefix->score < min_cutoff) { + break; + } + + if (c == blank_id) { + prefix->log_prob_b_cur = log_sum_exp( + prefix->log_prob_b_cur, + log_prob_c + + prefix->score); + continue; + } + + // repeated character + if (c == prefix->character) { + // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) + prefix->log_prob_nb_cur = log_sum_exp( + prefix->log_prob_nb_cur, + log_prob_c + + prefix->log_prob_nb_prev); + } + + // get new prefix + auto prefix_new = prefix->get_path_trie(c); + if (prefix_new != nullptr) { + float log_p = -NUM_FLT_INF; + if (c == prefix->character && + prefix->log_prob_b_prev > -NUM_FLT_INF) { + // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1}) + log_p = log_prob_c + prefix->log_prob_b_prev; + } else if (c != prefix->character) { + // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1}) + log_p = log_prob_c + prefix->score; + } + + // language model scoring + if (init_ext_scorer_ != nullptr && + (c == space_id || init_ext_scorer_->is_character_based())) { + PathTrie *prefix_to_score = nullptr; + // skip scoring the space + if (init_ext_scorer_->is_character_based()) { + prefix_to_score = prefix_new; + } else { + prefix_to_score = prefix; + } + + float score = 0.0; + vector ngram; + ngram = init_ext_scorer_->make_ngram(prefix_to_score); + // lm score: p_{lm}(W)^{\alpha} + \beta + score = init_ext_scorer_->get_log_cond_prob(ngram) * + init_ext_scorer_->alpha; + log_p += score; + log_p += init_ext_scorer_->beta; + } + // p_{nb}(l;x_{1:t}) + prefix_new->log_prob_nb_cur = + log_sum_exp(prefix_new->log_prob_nb_cur, + log_p); + } + } // end of loop over prefix + return 0; +} + +void CTCBeamSearch::CalculateApproxScore() { + size_t beam_size = _opts.beam_size; + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort( + prefixes.begin(), + prefixes.begin() + num_prefixes, + prefix_compare); + + // compute aproximate ctc score as the return score, without affecting the + // return order of decoding result. To delete when decoder gets stable. + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + double approx_ctc = prefixes[i]->score; + if (init_ext_scorer_ != nullptr) { + vector output; + prefixes[i]->get_path_vec(output); + auto prefix_length = output.size(); + auto words = init_ext_scorer_->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; + // remove language model weight: + approx_ctc -= + (init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha; + } + prefixes[i]->approx_ctc = approx_ctc; + } +} + +void CTCBeamSearch::LMRescore() { + size_t beam_size = _opts.beam_size; + if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + vector ngram = init_ext_scorer_->make_ngram(prefix); + score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha; + score += init_ext_scorer_->beta; + prefix->score += score; + } + } + } +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h new file mode 100644 index 00000000..5bf388d3 --- /dev/null +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -0,0 +1,74 @@ +#include "base/basic_types.h" + +#pragma once + +namespace ppspeech { + +struct CTCBeamSearchOptions { + std::string dict_file; + std::string lm_path; + BaseFloat alpha; + BaseFloat beta; + BaseFloat cutoff_prob; + int beam_size; + int cutoff_top_n; + int num_proc_bsearch; + CTCBeamSearchOptions() : + dict_file("./model/words.txt"), + lm_path("./model/lm.arpa"), + alpha(1.9f), + beta(5.0), + beam_size(300), + cutoff_prob(0.99f), + cutoff_top_n(40), + num_proc_bsearch(0) { + } + + void Register(kaldi::OptionsItf* opts) { + opts->Register("dict", &dict_file, "dict file "); + opts->Register("lm-path", &lm_path, "language model file"); + opts->Register("alpha", &alpha, "alpha"); + opts->Register("beta", &beta, "beta"); + opts->Register("beam-size", &beam_size, "beam size for beam search method"); + opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs"); + opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n"); + opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); + } +}; + +class CTCBeamSearch { +public: + + CTCBeamSearch(std::shared_ptr opts); + + ~CTCBeamSearch() { + } + bool InitDecoder(); + int DecodeLikelihoods(const std::vector>&probs, + std::vector& nbest_words); + + std::vector& GetDecodeResult() { + return decoder_results_; + } + +private: + void ResetPrefixes(); + int32 SearchOneChar(const bool& full_beam, + const std::pair& log_prob_idx, + const BaseFloat& min_cutoff); + void CalculateApproxScore(); + void LMRescore(); + std::vector> + AdvanceDecoding(const std::vector>& probs_seq); + CTCBeamSearchOptions opts_; + std::shared_ptr init_ext_scorer_; // todo separate later + std::vector decoder_results_; + std::vector> vocabulary_; // todo remove later + + size_t blank_id; + int space_id; + std::shared_ptr root; + std::vector prefixes; +}; + +} // namespace basr \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_decoders b/speechx/speechx/decoder/ctc_decoders new file mode 120000 index 00000000..b280de09 --- /dev/null +++ b/speechx/speechx/decoder/ctc_decoders @@ -0,0 +1 @@ +../../../third_party/ctc_decoders \ No newline at end of file diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc new file mode 100644 index 00000000..42d1d2af --- /dev/null +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc @@ -0,0 +1,1020 @@ +// decoder/lattice-faster-decoder.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2018 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeFasterDecoderTpl::LatticeFasterDecoderTpl( + const FST &fst, const LatticeFasterDecoderConfig &config) + : fst_(&fst), + delete_fst_(false), + config_(config), + num_toks_(0), + token_pool_(config.memory_pool_tokens_block_size), + forward_link_pool_(config.memory_pool_links_block_size) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeFasterDecoderTpl::LatticeFasterDecoderTpl( + const LatticeFasterDecoderConfig &config, FST *fst) + : fst_(fst), + delete_fst_(true), + config_(config), + num_toks_(0), + token_pool_(config.memory_pool_tokens_block_size), + forward_link_pool_(config.memory_pool_links_block_size) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeFasterDecoderTpl::~LatticeFasterDecoderTpl() { + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeFasterDecoderTpl::InitDecoding() { + // clean up from last time: + DeleteElems(toks_.Clear()); + cost_offsets_.clear(); + ClearActiveTokens(); + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = + new (token_pool_.Allocate()) Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + toks_.Insert(start_state, start_tok); + num_toks_++; + ProcessNonemitting(config_.beam); +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + AdvanceDecoding(decodable); + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeFasterDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + Lattice raw_lat; + GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) const { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + ofst->SetStart(0); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Now create all arcs. + for (int32 f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. + KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + return (ofst->NumStates() > 0); +} + + +// This function is now deprecated, since now we do determinization from outside +// the LatticeFasterDecoder class. Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). +template +bool LatticeFasterDecoderTpl::GetLattice(CompactLattice *ofst, + bool use_final_probs) const { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +template +void LatticeFasterDecoderTpl::PossiblyResizeHash(size_t num_toks) { + size_t new_sz = static_cast(static_cast(num_toks) + * config_.hash_ratio); + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + (Note: we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash of toks_, +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline typename LatticeFasterDecoderTpl::Elem* +LatticeFasterDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, + Token *backpointer, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + Elem *e_found = toks_.Insert(state, NULL); + if (e_found->val == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new (token_pool_.Allocate()) + Token(tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + e_found->val = new_tok; + if (changed) *changed = true; + return e_found; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return e_found; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + forward_link_pool_.Free(link); + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + // We call DeleteElems() as a nicety, not because it's really necessary; + // otherwise there would be a time, after calling PruneTokensForFrame() on the + // final frame, when toks_.GetList() or toks_.Clear() would contain pointers + // to nonexistent tokens. + DeleteElems(toks_.Clear()); + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + forward_link_pool_.Free(link); + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderTpl::PruneTokensForFrame(int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + token_pool_.Free(tok); + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderTpl::PruneActiveTokens(BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeFasterDecoderTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + const Elem *final_toks = toks_.GetList(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderTpl::AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeFasterDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderTpl::FinalizeDecoding() { + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. Also counts the active tokens. +template +BaseFloat LatticeFasterDecoderTpl::GetCutoff(Elem *list_head, size_t *tok_count, + BaseFloat *adaptive_beam, Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? + tmp_array_.begin() + config_.max_active : + tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +BaseFloat LatticeFasterDecoderTpl::ProcessEmitting( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + + Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ + // in simple-decoder.h. Removes the Elems from + // being indexed in the hash in toks_. + Elem *best_elem = NULL; + BaseFloat adaptive_beam; + size_t tok_cnt; + BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. + + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // pruning "online" before having seen all tokens + + BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good + // dynamic range. + + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". + if (best_elem) { + StateId state = best_elem->key; + Token *tok = best_elem->val; + cost_offset = - tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // the tokens are now owned here, in final_toks, and the hash is empty. + // 'owned' is a complex thing here; the point is we need to call DeleteElem + // on each elem 'e' to let toks_ know we're done with them. + for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { + // loop this way because we delete "e" as we go. + StateId state = e->key; + Token *tok = e->val; + if (tok->tot_cost <= cur_cutoff) { + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat ac_cost = cost_offset - + decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost >= next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // prune by best current token + // Note: the frame indexes into active_toks_ are one-based, + // hence the + 1. + Elem *e_next = FindOrAddToken(arc.nextstate, + frame + 1, tot_cost, tok, NULL); + // NULL: no change indicator needed + + // Add ForwardLink from tok to next_tok (put on head of list tok->links) + tok->links = new (forward_link_pool_.Allocate()) + ForwardLinkT(e_next->val, arc.ilabel, arc.olabel, graph_cost, + ac_cost, tok->links); + } + } // for all arcs + } + e_tail = e->tail; + toks_.Delete(e); // delete Elem + } + return next_cutoff; +} + +// static inline +template +void LatticeFasterDecoderTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + forward_link_pool_.Free(l); + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame = static_cast(active_toks_.size()) - 2; + // Note: "frame" is the time-index we just processed, or -1 if + // we are processing the nonemitting transitions before the + // first frame (called from InitDecoding()). + + // Processes nonemitting arcs for one frame. Propagates within toks_. + // Note-- this queue structure is not very optimal as + // it may cause us to process states unnecessarily (e.g. more than once), + // but in the baseline code, turning this vector into a set to fix this + // problem did not improve overall speed. + + KALDI_ASSERT(queue_.empty()); + + if (toks_.GetList() == NULL) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens: frame is " << frame; + warned_ = true; + } + } + + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + StateId state = e->key; + if (fst_->NumInputEpsilons(state) != 0) + queue_.push_back(e); + } + + while (!queue_.empty()) { + const Elem *e = queue_.back(); + queue_.pop_back(); + + StateId state = e->key; + Token *tok = e->val; // would segfault if e is a NULL pointer but this can't happen. + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost >= cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + // but since most states are emitting it's not a huge issue. + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel == 0) { // propagate nonemitting only... + BaseFloat graph_cost = arc.weight.Value(), + tot_cost = cur_cost + graph_cost; + if (tot_cost < cutoff) { + bool changed; + + Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, &changed); + + tok->links = new (forward_link_pool_.Allocate()) ForwardLinkT( + e_new->val, 0, arc.olabel, graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new [if so, add into queue]. + if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) + queue_.push_back(e_new); + } + } + } // for all arcs + } // while queue not empty +} + + +template +void LatticeFasterDecoderTpl::DeleteElems(Elem *list) { + for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks_.Delete(e); + } +} + +template +void LatticeFasterDecoderTpl::ClearActiveTokens() { // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + token_pool_.Free(tok); + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. + for (loop_count = 0; + !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); + iter != reprocess.end(); ++iter) + reprocess_vec.push_back(*iter); + reprocess.clear(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); + iter != reprocess_vec.end(); ++iter) { + Token *tok = *iter; + int32 pos = token2pos[tok]; + // Repeat the processing we did above (for comments, see above). + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { + int32 next_pos = following_iter->second; + if (next_pos < pos) { + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + } + } + KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " + "graph (this is not allowed!)"); + + topsorted_list->clear(); + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) + (*topsorted_list)[iter->second] = iter->first; +} + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeFasterDecoderTpl, decoder::StdToken>; +template class LatticeFasterDecoderTpl, decoder::StdToken >; +template class LatticeFasterDecoderTpl, decoder::StdToken >; + +template class LatticeFasterDecoderTpl; +template class LatticeFasterDecoderTpl; + +template class LatticeFasterDecoderTpl , decoder::BackpointerToken>; +template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; +template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; +template class LatticeFasterDecoderTpl; +template class LatticeFasterDecoderTpl; + + +} // end namespace kaldi. diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h new file mode 100644 index 00000000..2016ad57 --- /dev/null +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h @@ -0,0 +1,549 @@ +// decoder/lattice-faster-decoder.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_ +#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_ + +#include "decoder/grammar-fst.h" +#include "fst/fstlib.h" +#include "fst/memory.h" +#include "fstext/fstext-lib.h" +#include "itf/decodable-itf.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "util/hash-list.h" +#include "util/stl-utils.h" + +namespace kaldi { + +struct LatticeFasterDecoderConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; + BaseFloat hash_ratio; + // Note: we don't make prune_scale configurable on the command line, it's not + // a very important parameter. It affects the algorithm that prunes the + // tokens as we go. + BaseFloat prune_scale; + + // Number of elements in the block for Token and ForwardLink memory + // pool allocation. + int32 memory_pool_tokens_block_size; + int32 memory_pool_links_block_size; + + // Most of the options inside det_opts are not actually queried by the + // LatticeFasterDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeFaster. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeFasterDecoderConfig() + : beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1), + memory_pool_tokens_block_size(1 << 8), + memory_pool_links_block_size(1 << 8) {} + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + opts->Register("memory-pool-tokens-block-size", &memory_pool_tokens_block_size, + "Memory pool block size suggestion for storing tokens (in elements). " + "Smaller uses less memory but increases cache misses."); + opts->Register("memory-pool-links-block-size", &memory_pool_links_block_size, + "Memory pool block size suggestion for storing links (in elements). " + "Smaller uses less memory but increases cache misses."); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + +namespace decoder { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; + + +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals the + // minimum difference between the cost of the best path that this link is a + // part of, and the cost of the absolute best path, under the assumption that + // any of the currently active states at the decoding front may eventually + // succeed (e.g. if you were to take the currently active states one by one + // and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + Token *next; + + // This function does nothing and should be optimized out; it's needed + // so we can share the regular LatticeFasterDecoderTpl code and the code + // for LatticeFasterOnlineDecoder that supports fast traceback. + inline void SetBackpointer (Token *backpointer) { } + + // This constructor just ignores the 'backpointer' argument. That argument is + // needed so that we can use the same decoder code for LatticeFasterDecoderTpl + // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a + // fast way to obtain the best path). + inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, + Token *next, Token *backpointer): + tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { } +}; + +struct BackpointerToken { + using ForwardLinkT = ForwardLink; + using Token = BackpointerToken; + + // BackpointerToken is like Token but also + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. + BackpointerToken *next; + + // Best preceding BackpointerToken (could be a on this frame, connected to + // this via an epsilon transition, or on a previous frame). This is only + // required for an efficient GetBestPath function in + // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation + // (the "links" list is what stores the forward links, for that). + Token *backpointer; + + inline void SetBackpointer (Token *backpointer) { + this->backpointer = backpointer; + } + + inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links, + Token *next, Token *backpointer): + tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), + backpointer(backpointer) { } +}; + +} // namespace decoder + + +/** This is the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder which is expected to equal + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template +class LatticeFasterDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderTpl(const FST &fst, + const LatticeFasterDecoderConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config, + FST *fst); + + void SetOptions(const LatticeFasterDecoderConfig &config) { + config_ = config; + } + + const LatticeFasterDecoderConfig &GetOptions() const { + return config_; + } + + ~LatticeFasterDecoderTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const; + + + + /// [Deprecated, users should now use GetRawLattice and determinize it + /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper]. + /// Outputs an FST corresponding to the lattice-determinized + /// lattice (one path per word sequence). Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state of the graph + /// then it will include those as final-probs, else it will treat all + /// final-probs as one. + bool GetLattice(CompactLattice *ofst, + bool use_final_probs = true) const; + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more completely, particularly toward the end of the + /// utterance. If you call this, you cannot call AdvanceDecoding again (it + /// will fail), and you cannot call GetLattice() and related functions with + /// use_final_probs = false. Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessEmitting(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + using Elem = typename HashList::Elem; + // Equivalent to: + // struct Elem { + // StateId key; + // Token *val; + // Elem *tail; + // }; + + void PossiblyResizeHash(size_t num_toks); + + // FindOrAddToken either locates a token in hash of toks_, or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. [note: it's inserted if necessary into hash toks_ and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one, + BaseFloat tot_cost, Token *backpointer, + bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, + BaseFloat delta); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, for all Tokens + // that correspond to states that have final-probs. This map will be + // empty if there were no final-probs. It outputs to + // final_relative_cost, if non-NULL, the difference between the best + // forward-cost including the final-prob cost, and the best forward-cost + // without including the final-prob cost (this will usually be positive), or + // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which + // outputs this quanitity]. It outputs to final_best_cost, if + // non-NULL, the lowest for any token t active on the final frame, of + // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in + // the graph of the state corresponding to token t, or the best of + // forward-cost[t] if there were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame_plus_one); + + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. + // delta controls when it considers a cost to have changed enough to continue + // going backward and propagating the change. larger delta -> will recurse + // less far. + void PruneActiveTokens(BaseFloat delta); + + /// Gets the weight cutoff. Also counts the active tokens. + BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, + BaseFloat *adaptive_beam, Elem **best_elem); + + /// Processes emitting arcs for one frame. Propagates from prev_toks_ to + /// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to + /// use. + BaseFloat ProcessEmitting(DecodableInterface *decodable); + + /// Processes nonemitting (epsilon) arcs for one frame. Called after + /// ProcessEmitting() on each frame. The cost cutoff is computed by the + /// preceding ProcessEmitting(). + void ProcessNonemitting(BaseFloat cost_cutoff); + + // HashList defined in ../util/hash-list.h. It actually allows us to maintain + // more than one list (e.g. for current and previous frames), but only one of + // them at a time can be indexed by StateId. It is indexed by frame-index + // plus one, where the frame-index is zero-based, as used in decodable object. + // That is, the emitting probs of frame t are accounted for in tokens at + // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of + // the graph. + HashList toks_; + + std::vector active_toks_; // Lists of tokens, indexed by + // frame (members of TokenList are toks, must_prune_forward_links, + // must_prune_tokens). + std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector tmp_array_; // used in GetCutoff. + + // fst_ is a pointer to the FST we are decoding from. + const FST *fst_; + // delete_fst_ is true if the pointer fst_ needs to be deleted when this + // object is destroyed. + bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + LatticeFasterDecoderConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // Memory pools for storing tokens and forward links. + // We use it to decrease the work put on allocator and to move some of data + // together. Too small block sizes will result in more work to allocator but + // bigger ones increase the memory usage. + fst::MemoryPool token_pool_; + fst::MemoryPool forward_link_pool_; + + // There are various cleanup tasks... the toks_ structure contains + // singly linked lists of Token pointers, where Elem is the list type. + // It also indexes them in a hash, indexed by state (this hash is only + // maintained for the most recent frame). toks_.Clear() + // deletes them from the hash and returns the list of Elems. The + // function DeleteElems calls toks_.Delete(elem) for each elem in + // the list, which returns ownership of the Elem to the toks_ structure + // for reuse, but does not delete the Token pointer. The Token pointers + // are reference-counted and are ultimately deleted in PruneTokensForFrame, + // but are also linked together on each frame by their own linked-list, + // using the "next" pointer. We delete them manually. + void DeleteElems(Elem *list); + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl); +}; + +typedef LatticeFasterDecoderTpl LatticeFasterDecoder; + + + +} // end namespace kaldi. + +#endif diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc new file mode 100644 index 00000000..ebdace7e --- /dev/null +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc @@ -0,0 +1,285 @@ +// decoder/lattice-faster-online-decoder.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.cc, about how to maintain this +// file in sync with lattice-faster-decoder.cc + +#include "decoder/lattice-faster-online-decoder.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +template +bool LatticeFasterOnlineDecoderTpl::TestGetBestPath( + bool use_final_probs) const { + Lattice lat1; + { + Lattice raw_lat; + this->GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, &lat1); + } + Lattice lat2; + GetBestPath(&lat2, use_final_probs); + BaseFloat delta = 0.1; + int32 num_paths = 1; + if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) { + KALDI_WARN << "Best-path test failed"; + return false; + } else { + return true; + } +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeFasterOnlineDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + olat->DeleteStates(); + BaseFloat final_graph_cost; + BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost); + if (iter.Done()) + return false; // would have printed warning. + StateId state = olat->AddState(); + olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0)); + while (!iter.Done()) { + LatticeArc arc; + iter = TraceBackBestPath(iter, &arc); + arc.nextstate = state; + StateId new_state = olat->AddState(); + olat->AddArc(new_state, arc); + state = new_state; + } + olat->SetStart(state); + return true; +} + +template +typename LatticeFasterOnlineDecoderTpl::BestPathIterator LatticeFasterOnlineDecoderTpl::BestPathEnd( + bool use_final_probs, + BaseFloat *final_cost_out) const { + if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "BestPathEnd() with use_final_probs == false"; + KALDI_ASSERT(this->NumFramesDecoded() > 0 && + "You cannot call BestPathEnd if no frames were decoded."); + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? this->final_costs_ :final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + // Singly linked list of tokens on last frame (access list through "next" + // pointer). + BaseFloat best_cost = std::numeric_limits::infinity(); + BaseFloat best_final_cost = 0; + Token *best_tok = NULL; + for (Token *tok = this->active_toks_.back().toks; + tok != NULL; tok = tok->next) { + BaseFloat cost = tok->tot_cost, final_cost = 0.0; + if (use_final_probs && !final_costs.empty()) { + // if we are instructed to use final-probs, and any final tokens were + // active on final frame, include the final-prob in the cost of the token. + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) { + final_cost = iter->second; + cost += final_cost; + } else { + cost = std::numeric_limits::infinity(); + } + } + if (cost < best_cost) { + best_cost = cost; + best_tok = tok; + best_final_cost = final_cost; + } + } + if (best_tok == NULL) { // this should not happen, and is likely a code error or + // caused by infinities in likelihoods, but I'm not making + // it a fatal error for now. + KALDI_WARN << "No final token found."; + } + if (final_cost_out) + *final_cost_out = best_final_cost; + return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); +} + + +template +typename LatticeFasterOnlineDecoderTpl::BestPathIterator LatticeFasterOnlineDecoderTpl::TraceBackBestPath( + BestPathIterator iter, LatticeArc *oarc) const { + KALDI_ASSERT(!iter.Done() && oarc != NULL); + Token *tok = static_cast(iter.tok); + int32 cur_t = iter.frame, step_t = 0; + if (tok->backpointer != NULL) { + // retrieve the correct forward link(with the best link cost) + BaseFloat best_cost = std::numeric_limits::infinity(); + ForwardLinkT *link; + for (link = tok->backpointer->links; + link != NULL; link = link->next) { + if (link->next_tok == tok) { // this is a link to "tok" + BaseFloat graph_cost = link->graph_cost, + acoustic_cost = link->acoustic_cost; + BaseFloat cost = graph_cost + acoustic_cost; + if (cost < best_cost) { + oarc->ilabel = link->ilabel; + oarc->olabel = link->olabel; + if (link->ilabel != 0) { + KALDI_ASSERT(static_cast(cur_t) < this->cost_offsets_.size()); + acoustic_cost -= this->cost_offsets_[cur_t]; + step_t = -1; + } else { + step_t = 0; + } + oarc->weight = LatticeWeight(graph_cost, acoustic_cost); + best_cost = cost; + } + } + } + if (link == NULL && + best_cost == std::numeric_limits::infinity()) { // Did not find correct link. + KALDI_ERR << "Error tracing best-path back (likely " + << "bug in token-pruning algorithm)"; + } + } else { + oarc->ilabel = 0; + oarc->olabel = 0; + oarc->weight = LatticeWeight::One(); // zero costs. + } + return BestPathIterator(tok->backpointer, cur_t + step_t); +} + +template +bool LatticeFasterOnlineDecoderTpl::GetRawLatticePruned( + Lattice *ofst, + bool use_final_probs, + BaseFloat beam) const { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? this->final_costs_ : final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = this->active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + for (int32 f = 0; f <= num_frames; f++) { + if (this->active_toks_[f].toks == NULL) { + KALDI_WARN << "No tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + } + unordered_map tok_map; + std::queue > tok_queue; + // First initialize the queue and states. Put the initial state on the queue; + // this is the last token in the list active_toks_[0].toks. + for (Token *tok = this->active_toks_[0].toks; + tok != NULL; tok = tok->next) { + if (tok->next == NULL) { + tok_map[tok] = ofst->AddState(); + ofst->SetStart(tok_map[tok]); + std::pair tok_pair(tok, 0); // #frame = 0 + tok_queue.push(tok_pair); + } + } + + // Next create states for "good" tokens + while (!tok_queue.empty()) { + std::pair cur_tok_pair = tok_queue.front(); + tok_queue.pop(); + Token *cur_tok = cur_tok_pair.first; + int32 cur_frame = cur_tok_pair.second; + KALDI_ASSERT(cur_frame >= 0 && + cur_frame <= this->cost_offsets_.size()); + + typename unordered_map::const_iterator iter = + tok_map.find(cur_tok); + KALDI_ASSERT(iter != tok_map.end()); + StateId cur_state = iter->second; + + for (ForwardLinkT *l = cur_tok->links; + l != NULL; + l = l->next) { + Token *next_tok = l->next_tok; + if (next_tok->extra_cost < beam) { + // so both the current and the next token are good; create the arc + int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1; + StateId nextstate; + if (tok_map.find(next_tok) == tok_map.end()) { + nextstate = tok_map[next_tok] = ofst->AddState(); + tok_queue.push(std::pair(next_tok, next_frame)); + } else { + nextstate = tok_map[next_tok]; + } + BaseFloat cost_offset = (l->ilabel != 0 ? + this->cost_offsets_[cur_frame] : 0); + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + } + if (cur_frame == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator iter = + final_costs.find(cur_tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + return (ofst->NumStates() != 0); +} + + + +// Instantiate the template for the FST types that we'll need. +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl; +template class LatticeFasterOnlineDecoderTpl; + + +} // end namespace kaldi. diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h new file mode 100644 index 00000000..8b10996f --- /dev/null +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h @@ -0,0 +1,147 @@ +// decoder/lattice-faster-online-decoder.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.h, about how to maintain this +// file in sync with lattice-faster-decoder.h + + +#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_ +#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + + + +/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also + supports an efficient way to get the best path (see the function + BestPathEnd()), which is useful in endpointing and in situations where you + might want to frequently access the best path. + + This is only templated on the FST type, since the Token type is required to + be BackpointerToken. Actually it only makes sense to instantiate + LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via + this child class. + */ +template +class LatticeFasterOnlineDecoderTpl: + public LatticeFasterDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Token = decoder::BackpointerToken; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterOnlineDecoderTpl(const FST &fst, + const LatticeFasterDecoderConfig &config): + LatticeFasterDecoderTpl(fst, config) { } + + // This version of the initializer takes ownership of 'fst', and will delete + // it when this object is destroyed. + LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config, + FST *fst): + LatticeFasterDecoderTpl(config, fst) { } + + + struct BestPathIterator { + void *tok; + int32 frame; + // note, "frame" is the frame-index of the frame you'll get the + // transition-id for next time, if you call TraceBackBestPath on this + // iterator (assuming it's not an epsilon transition). Note that this + // is one less than you might reasonably expect, e.g. it's -1 for + // the nonemitting transitions before the first frame. + BestPathIterator(void *t, int32 f): tok(t), frame(f) { } + bool Done() const { return tok == NULL; } + }; + + + /// Outputs an FST corresponding to the single best path through the lattice. + /// This is quite efficient because it doesn't get the entire raw lattice and find + /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator + /// so it basically traces it back through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + + /// This function does a self-test of GetBestPath(). Returns true on + /// success; returns false and prints a warning on failure. + bool TestGetBestPath(bool use_final_probs = true) const; + + + /// This function returns an iterator that can be used to trace back + /// the best path. If use_final_probs == true and at least one final state + /// survived till the end, it will use the final-probs in working out the best + /// final Token, and will output the final cost to *final_cost (if non-NULL), + /// else it will use only the forward likelihood, and will put zero in + /// *final_cost (if non-NULL). + /// Requires that NumFramesDecoded() > 0. + BestPathIterator BestPathEnd(bool use_final_probs, + BaseFloat *final_cost = NULL) const; + + + /// This function can be used in conjunction with BestPathEnd() to trace back + /// the best path one link at a time (e.g. this can be useful in endpoint + /// detection). By "link" we mean a link in the graph; not all links cross + /// frame boundaries, but each time you see a nonzero ilabel you can interpret + /// that as a frame. The return value is the updated iterator. It outputs + /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer, + /// while leaving its "nextstate" variable unchanged. + BestPathIterator TraceBackBestPath( + BestPathIterator iter, LatticeArc *arc) const; + + + /// Behaves the same as GetRawLattice but only processes tokens whose + /// extra_cost is smaller than the best-cost plus the specified beam. + /// It is only worthwhile to call this function if beam is less than + /// the lattice_beam specified in the config; otherwise, it would + /// return essentially the same thing as GetRawLattice, but more slowly. + bool GetRawLatticePruned(Lattice *ofst, + bool use_final_probs, + BaseFloat beam) const; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl); +}; + +typedef LatticeFasterOnlineDecoderTpl LatticeFasterOnlineDecoder; + + +} // end namespace kaldi. + +#endif diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc b/speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc new file mode 100644 index 00000000..f6684f0b --- /dev/null +++ b/speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc @@ -0,0 +1,147 @@ +// lat/determinize-lattice-pruned-test.cc + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2013 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "lat/determinize-lattice-pruned.h" +#include "fstext/lattice-utils.h" +#include "fstext/fst-test-utils.h" +#include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" + +namespace fst { +// Caution: these tests are not as generic as you might think from all the +// templates in the code. They are basically only valid for LatticeArc. +// This is partly due to the fact that certain templates need to be instantiated +// in other .cc files in this directory. + +// test that determinization proceeds correctly on general +// FSTs (not guaranteed determinzable, but we use the +// max-states option to stop it getting out of control). +template void TestDeterminizeLatticePruned() { + typedef kaldi::int32 Int; + typedef typename Arc::Weight Weight; + typedef ArcTpl > CompactArc; + + for(int i = 0; i < 100; i++) { + RandFstOptions opts; + opts.n_states = 4; + opts.n_arcs = 10; + opts.n_final = 2; + opts.allow_empty = false; + opts.weight_multiplier = 0.5; // impt for the randomly generated weights + opts.acyclic = true; + // to be exactly representable in float, + // or this test fails because numerical differences can cause symmetry in + // weights to be broken, which causes the wrong path to be chosen as far + // as the string part is concerned. + + VectorFst *fst = RandPairFst(opts); + + bool sorted = TopSort(fst); + KALDI_ASSERT(sorted); + + ILabelCompare ilabel_comp; + if (kaldi::Rand() % 2 == 0) + ArcSort(fst, ilabel_comp); + + std::cout << "FST before lattice-determinizing is:\n"; + { + FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + VectorFst det_fst; + try { + DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000); + lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20); + lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30); + bool ans = DeterminizeLatticePruned(*fst, 10.0, &det_fst, lat_opts); + + std::cout << "FST after lattice-determinizing is:\n"; + { + FstPrinter fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic); + // OK, now determinize it a different way and check equivalence. + // [note: it's not normal determinization, it's taking the best path + // for any input-symbol sequence.... + + + VectorFst pruned_fst(*fst); + if (pruned_fst.NumStates() != 0) + kaldi::PruneLattice(10.0, &pruned_fst); + + VectorFst compact_pruned_fst, compact_pruned_det_fst; + ConvertLattice(pruned_fst, &compact_pruned_fst, false); + std::cout << "Compact pruned FST is:\n"; + { + FstPrinter fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + ConvertLattice(det_fst, &compact_pruned_det_fst, false); + + std::cout << "Compact version of determinized FST is:\n"; + { + FstPrinter fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + + if (ans) + KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/)); + } catch (...) { + std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n"; + } + delete fst; + } +} + +// test that determinization proceeds without crash on acyclic FSTs +// (guaranteed determinizable in this sense). +template void TestDeterminizeLatticePruned2() { + typedef typename Arc::Weight Weight; + RandFstOptions opts; + opts.acyclic = true; + for(int i = 0; i < 100; i++) { + VectorFst *fst = RandPairFst(opts); + std::cout << "FST before lattice-determinizing is:\n"; + { + FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + VectorFst ofst; + DeterminizeLatticePruned(*fst, 10.0, &ofst); + std::cout << "FST after lattice-determinizing is:\n"; + { + FstPrinter fstprinter(ofst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + delete fst; + } +} + + +} // end namespace fst + +int main() { + using namespace fst; + TestDeterminizeLatticePruned(); + TestDeterminizeLatticePruned2(); + std::cout << "Tests succeeded\n"; +} diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc b/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc new file mode 100644 index 00000000..dbdd9af4 --- /dev/null +++ b/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc @@ -0,0 +1,1541 @@ +// lat/determinize-lattice-pruned.cc + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2013 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "fstext/determinize-lattice.h" // for LatticeStringRepository +#include "fstext/fstext-utils.h" +#include "lat/lattice-functions.h" // for PruneLattice +#include "lat/minimize-lattice.h" // for minimization +#include "lat/push-lattice.h" // for minimization +#include "lat/determinize-lattice-pruned.h" + +namespace fst { + +using std::vector; +using std::pair; +using std::greater; + +// class LatticeDeterminizerPruned is templated on the same types that +// CompactLatticeWeight is templated on: the base weight (Weight), typically +// LatticeWeightTpl etc. but could also be e.g. TropicalWeight, and the +// IntType, typically int32, used for the output symbols in the compact +// representation of strings [note: the output symbols would usually be +// p.d.f. id's in the anticipated use of this code] It has a special requirement +// on the Weight type: that there should be a Compare function on the weights +// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 > +// w2. This requires that there be a total order on the weights. + +template class LatticeDeterminizerPruned { + public: + // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 correspondence + // between our states and the states in ofst. If destroy == true, release memory as we go + // (but we cannot output again). + + typedef CompactLatticeWeightTpl CompactWeight; + typedef ArcTpl CompactArc; // arc in compact, acceptor form of lattice + typedef ArcTpl Arc; // arc in non-compact version of lattice + + // Output to standard FST with CompactWeightTpl as its weight type (the + // weight stores the original output-symbol strings). If destroy == true, + // release memory as we go (but we cannot output again). + void Output(MutableFst *ofst, bool destroy = true) { + KALDI_ASSERT(determinized_); + typedef typename Arc::StateId StateId; + StateId nStates = static_cast(output_states_.size()); + if (destroy) + FreeMostMemory(); + ofst->DeleteStates(); + ofst->SetStart(kNoStateId); + if (nStates == 0) { + return; + } + for (StateId s = 0;s < nStates;s++) { + OutputStateId news = ofst->AddState(); + KALDI_ASSERT(news == s); + } + ofst->SetStart(0); + // now process transitions. + for (StateId this_state_id = 0; this_state_id < nStates; this_state_id++) { + OutputState &this_state = *(output_states_[this_state_id]); + vector &this_vec(this_state.arcs); + typename vector::const_iterator iter = this_vec.begin(), end = this_vec.end(); + + for (;iter != end; ++iter) { + const TempArc &temp_arc(*iter); + CompactArc new_arc; + vector