test linear spectrogram feature

pull/1400/head
SmileGoat 3 years ago
parent a01fa866a4
commit f03d48f79b

@ -23,6 +23,8 @@ class FeatureExtractorInterface {
public: public:
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0; virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBae<kaldi::BaseFloat>* feature) = 0;
virtual size_t Dim() const = 0; virtual size_t Dim() const = 0;
}; };

@ -22,7 +22,10 @@ using kaldi::Vector;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
LinearSpectrogram::LinearSpectrogram(const LinearSpectrogramOptions& opts) { LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
const std::unique_ptr<FeatureExtractorInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize(); int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift(); int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size; fft_points_ = window_size;
@ -34,6 +37,8 @@ LinearSpectrogram::LinearSpectrogram(const LinearSpectrogramOptions& opts) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i); hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
} }
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
} }
void LinearSpectrogram::AcceptWavefrom(const Vector<BaseFloat>& input) { void LinearSpectrogram::AcceptWavefrom(const Vector<BaseFloat>& input) {
@ -70,27 +75,42 @@ bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
return true; return true;
} }
// todo refactor later //todo remove later
// Copies all elements of a kaldi vector into a std::vector.
//
// input:  source vector; read through Dim() and Data().
// output: destination; replaced with exactly input.Dim() elements.
//         A null pointer is ignored.
void CopyVector2StdVector(const kaldi::Vector<BaseFloat>& input,
                          vector<BaseFloat>* output) {
    if (output == nullptr) {
        return;
    }
    // assign() both resizes and overwrites, so stale contents of *output
    // never leak into the result; an empty input yields an empty output.
    const BaseFloat* first = input.Data();
    output->assign(first, first + input.Dim());
}
// todo remove later
bool LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) const { bool LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) const {
vector<vector<BaseFloat>> feat; if (wavefrom_.Dim() == 0) {
if (wavefrom_.empty()) {
return false; return false;
} }
kaldi::Vector<BaseFloat> feats;
Compute(wavefrom_, &feats);
vector<vector<BaseFloat>> result; vector<vector<BaseFloat>> result;
Compute(wavefrom_, result); vector<BaseFloat> feats_vec;
CopyVector2StdVector(feats, &feats_vec);
Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size()); feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) { for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result.size(); ++col_idx) { for (int col_idx = 0; col_idx < result.size(); ++col_idx) {
feats(row_idx, col_idx) = result[row_idx][col_idx]; feats(row_idx, col_idx) = result[row_idx][col_idx];
} }
wavefrom_.clear(); wavefrom_.Resize(0);
return true; return true;
} }
// Compute spectrogram feat, return num frames // only for test, remove later
// todo: compute the feature frame by frame.
// Delegates to the wrapped base extractor: passes `input` through to
// base_extractor_->Compute(), which writes its result into `*feature`.
void LinearSpectrogram::Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
                                kaldi::VectorBase<kaldi::BaseFloat>* feature) {
    base_extractor_->Compute(input, feature);
}
// Compute spectrogram feat, only for test, remove later
// todo: refactor later (SmileGoat) // todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& wave, bool LinearSpectrogram::Compute(const vector<float>& wave,
vector<vector<float>>& feat) { vector<vector<float>>& feat) {
int num_samples = wave.size(); int num_samples = wave.size();
const int& frame_length = opts.frame_opts.WindowSize(); const int& frame_length = opts.frame_opts.WindowSize();
const int& sample_rate = opts.frame_opts.samp_freq; const int& sample_rate = opts.frame_opts.samp_freq;

@ -19,16 +19,19 @@ struct LinearSpectrogramOptions {
class LinearSpectrogram : public FeatureExtractorInterface { class LinearSpectrogram : public FeatureExtractorInterface {
public: public:
explict LinearSpectrogram(const LinearSpectrogramOptions& opts); explict LinearSpectrogram(const LinearSpectrogramOptions& opts,
const std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input); virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input);
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat); virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat);
virtual size_t Dim() const; virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaesFloat>* feats) const; void ReadFeats(kaldi::Matrix<kaldi::BaesFloat>* feats) const;
private: private:
void Hanning(std::vector<kaldi::BaseFloat>& data) const; void Hanning(std::vector<kaldi::BaseFloat>& data) const;
kaldi::int32 Compute(const std::vector<kaldi::BaseFloat>& wave, kaldi::int32 Compute(const std::vector<kaldi::BaseFloat>& wave,
std::vector<std::vector<kaldi::BaseFloat>>& feat); std::vector<std::vector<kaldi::BaseFloat>>& feat);
void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBae<kaldi::BaseFloat>* feature);
bool NumpyFft(std::vector<kaldi::BaseFloat>* v, bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
std::vector<kaldi::BaseFloat>* real, std::vector<kaldi::BaseFloat>* real,
std::vector<kaldi::BaseFloat>* img) const; std::vector<kaldi::BaseFloat>* img) const;
@ -38,7 +41,8 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_; std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_; kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_; LinearSpectrogramOptions opts_;
std::vector<kaldi::BaseFloat> wavefrom_; // remove later, todo(SmileGoat) kaldi::Vector<kaldi::BaseFloat> wavefrom_; // remove later, todo(SmileGoat)
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
}; };

@ -1,5 +1,7 @@
// todo refactor, replace with gtest
#include "frontend/linear_spectrogram.h" #include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "base/log.h" #include "base/log.h"
#include "base/flags.h" #include "base/flags.h"
@ -15,9 +17,14 @@ int main(int argc, char* argv[]) {
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier); kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
// test feature linear_spectrogram: wave --> decibel_normalizer --> hanning window --> linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt; ppspeech::LinearSpectrogramOptions opt;
ppspeech::LinearSpectrogram linear_spectrogram(opt); ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor =
new DecibelNormalizer(db_norm_opt);
ppspeech::LinearSpectrogram linear_spectrogram(opt, base_featrue_extractor);
for (; !wav_reader.Done(); wav_reader.Next()) { for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key(); std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value(); const kaldi::WaveData &wave_data = wav_reader.Value();

@ -2,8 +2,7 @@
#include "frontend/normalizer.h" #include "frontend/normalizer.h"
DecibelNormalizer::DecibelNormalizer( DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts, const DecibelNormalizerOptions& opts) {
const std::unique_ptr<FeatureExtractorInterface>& pre_extractor) {
} }

@ -0,0 +1,16 @@
#pragma once
#include ""
namespace ppspeech {
// Pure-virtual interface exposing a single FeedForward() entry point.
// NOTE(review): presumably implemented by neural-network inference backends;
// confirm against the implementing classes.
class NnetForwardInterface {
  public:
    virtual ~NnetForwardInterface() = default;

    // Consumes `features` and writes the result into `*inference`.
    // The layout of both containers is defined by the implementation.
    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
                             kaldi::Vector<kaldi::BaseFloat>* inference) const = 0;
};
} // namespace ppspeech
Loading…
Cancel
Save