diff --git a/speechx/speechx/frontend/feature_extractor_interface.h b/speechx/speechx/frontend/feature_extractor_interface.h
index 3f3f0ff9..863c4281 100644
--- a/speechx/speechx/frontend/feature_extractor_interface.h
+++ b/speechx/speechx/frontend/feature_extractor_interface.h
@@ -23,6 +23,8 @@ class FeatureExtractorInterface {
   public:
     virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
     virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
+    virtual void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
+                         kaldi::VectorBase<kaldi::BaseFloat>* feature) = 0;
     virtual size_t Dim() const = 0;
 };
 
diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc
index df8a29f9..3545cb53 100644
--- a/speechx/speechx/frontend/linear_spectrogram.cc
+++ b/speechx/speechx/frontend/linear_spectrogram.cc
@@ -22,7 +22,11 @@ using kaldi::Vector;
 using kaldi::Matrix;
 using std::vector;
 
-LinearSpectrogram::LinearSpectrogram(const LinearSpectrogramOptions& opts) {
+// Takes ownership of `base_extractor` (moved into base_extractor_).
+LinearSpectrogram::LinearSpectrogram(
+    const LinearSpectrogramOptions& opts,
+    std::unique_ptr<FeatureExtractorInterface> base_extractor) {
+    base_extractor_ = std::move(base_extractor);
     int32 window_size = opts.frame_opts.WindowSize();
     int32 window_shift = opts.frame_opts.WindowShift();
     fft_points_ = window_size;
@@ -34,6 +38,8 @@ LinearSpectrogram::LinearSpectrogram(const LinearSpectrogramOptions& opts) {
         hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
         hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
     }
+
+    dim_ = fft_points_ / 2 + 1;  // the dimension is Fs/2 Hz
 }
 
 void LinearSpectrogram::AcceptWavefrom(const Vector<kaldi::BaseFloat>& input) {
@@ -70,27 +76,51 @@ bool LinearSpectrogram::NumpyFft(vector<kaldi::BaseFloat>* v,
     return true;
 }
 
-// todo refactor later
+// todo remove later
+// Copies every element of `input` into `output`, resizing `output` to match.
+void CopyVector2StdVector(const kaldi::Vector<kaldi::BaseFloat>& input,
+                          vector<kaldi::BaseFloat>* output) {
+    output->resize(input.Dim());
+    for (int32 idx = 0; idx < input.Dim(); ++idx) {
+        (*output)[idx] = input(idx);
+    }
+}
+
+// todo remove later
 bool LinearSpectrogram::ReadFeats(Matrix<kaldi::BaseFloat>* feats) const {
-    vector<vector<kaldi::BaseFloat>> feat;
-    if (wavefrom_.empty()) {
+    if (wavefrom_.Dim() == 0) {
         return false;
     }
+    // A VectorBase cannot be resized by the callee, so size the output here;
+    // assumes the base extractor keeps the input length (TODO confirm).
+    kaldi::Vector<kaldi::BaseFloat> base_feats(wavefrom_.Dim());
+    Compute(wavefrom_, &base_feats);
     vector<vector<kaldi::BaseFloat>> result;
-    Compute(wavefrom_, result);
+    vector<kaldi::BaseFloat> feats_vec;
+    CopyVector2StdVector(base_feats, &feats_vec);
+    Compute(feats_vec, result);
     feats->Resize(result.size(), result[0].size());
     for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
-        for (int col_idx = 0; col_idx < result.size(); ++col_idx) {
-            feats(row_idx, col_idx) = result[row_idx][col_idx];
+        for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) {
+            (*feats)(row_idx, col_idx) = result[row_idx][col_idx];
         }
-    wavefrom_.clear();
+    }
+    wavefrom_.Resize(0);
     return true;
 }
 
-// Compute spectrogram feat, return num frames
+// only for test, remove later
+// todo: compute the feature frame by frame.
+void LinearSpectrogram::Compute(
+    const kaldi::VectorBase<kaldi::BaseFloat>& input,
+    kaldi::VectorBase<kaldi::BaseFloat>* feature) {
+    base_extractor_->Compute(input, feature);
+}
+
+// Compute spectrogram feat, only for test, remove later
 // todo: refactor later (SmileGoat)
 bool LinearSpectrogram::Compute(const vector<kaldi::BaseFloat>& wave,
-    vector<vector<kaldi::BaseFloat>>& feat) {
+                                vector<vector<kaldi::BaseFloat>>& feat) {
     int num_samples = wave.size();
     const int& frame_length = opts.frame_opts.WindowSize();
     const int& sample_rate = opts.frame_opts.samp_freq;
diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h
index 981f92ea..16683890 100644
--- a/speechx/speechx/frontend/linear_spectrogram.h
+++ b/speechx/speechx/frontend/linear_spectrogram.h
@@ -19,16 +19,20 @@ struct LinearSpectrogramOptions {
 
 class LinearSpectrogram : public FeatureExtractorInterface {
   public:
-    explict LinearSpectrogram(const LinearSpectrogramOptions& opts);
+    explicit LinearSpectrogram(
+        const LinearSpectrogramOptions& opts,
+        std::unique_ptr<FeatureExtractorInterface> base_extractor);
     virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input);
     virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat);
-    virtual size_t Dim() const;
+    virtual size_t Dim() const { return dim_; }
     void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats) const;
 
   private:
     void Hanning(std::vector<kaldi::BaseFloat>& data) const;
     kaldi::int32 Compute(const std::vector<kaldi::BaseFloat>& wave,
                          std::vector<std::vector<kaldi::BaseFloat>>& feat);
+    void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
+                 kaldi::VectorBase<kaldi::BaseFloat>* feature);
     bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
                   std::vector<kaldi::BaseFloat>* real,
                   std::vector<kaldi::BaseFloat>* img) const;
@@ -38,7 +42,8 @@ class LinearSpectrogram : public FeatureExtractorInterface {
     std::vector<kaldi::BaseFloat> hanning_window_;
     kaldi::BaseFloat hanning_window_energy_;
     LinearSpectrogramOptions opts_;
-    std::vector<kaldi::BaseFloat> wavefrom_; // remove later, todo(SmileGoat)
+    kaldi::Vector<kaldi::BaseFloat> wavefrom_; // remove later, todo(SmileGoat)
+    std::unique_ptr<FeatureExtractorInterface> base_extractor_;
     DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
 };
 
diff --git a/speechx/speechx/frontend/linear_spectrogram_main.cc b/speechx/speechx/frontend/linear_spectrogram_main.cc
index 455f4f91..352e7225 100644
--- a/speechx/speechx/frontend/linear_spectrogram_main.cc
+++ b/speechx/speechx/frontend/linear_spectrogram_main.cc
@@ -1,5 +1,10 @@
+// todo refactor, replace with gtest
+
+#include <memory>
+#include <utility>
 
 #include "frontend/linear_spectrogram.h"
+#include "frontend/normalizer.h"
 #include "kaldi/util/table-types.h"
 #include "base/log.h"
 #include "base/flags.h"
@@ -15,9 +20,15 @@ int main(int argc, char* argv[]) {
     kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
     kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
 
+    // test feature linear_spectrogram: wave --> decibel_normalizer --> hanning window --> linear_spectrogram --> cmvn
     int32 num_done = 0, num_err = 0;
     ppspeech::LinearSpectrogramOptions opt;
-    ppspeech::LinearSpectrogram linear_spectrogram(opt);
+    ppspeech::DecibelNormalizerOptions db_norm_opt;
+    std::unique_ptr<ppspeech::FeatureExtractorInterface>
+        base_feature_extractor(new ppspeech::DecibelNormalizer(db_norm_opt));
+    ppspeech::LinearSpectrogram linear_spectrogram(
+        opt, std::move(base_feature_extractor));
+
     for (; !wav_reader.Done(); wav_reader.Next()) {
         std::string utt = wav_reader.Key();
         const kaldi::WaveData &wave_data = wav_reader.Value();
diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc
index 9a384484..ca27d6ac 100644
--- a/speechx/speechx/frontend/normalizer.cc
+++ b/speechx/speechx/frontend/normalizer.cc
@@ -2,8 +2,7 @@
 #include "frontend/normalizer.h"
 
 DecibelNormalizer::DecibelNormalizer(
-    const DecibelNormalizerOptions& opts,
-    const std::unique_ptr<FeatureExtractorInterface>& pre_extractor) {
+    const DecibelNormalizerOptions& opts) {
 
 }
 
diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h
new file mode 100644
index 00000000..e999b8f0
--- /dev/null
+++ b/speechx/speechx/nnet/nnet_interface.h
@@ -0,0 +1,16 @@
+
+#pragma once
+
+#include "kaldi/matrix/kaldi-matrix.h"
+
+namespace ppspeech {
+
+class NnetForwardInterface {
+  public:
+    virtual ~NnetForwardInterface() {}
+    virtual void FeedForward(
+        const kaldi::Matrix<kaldi::BaseFloat>& features,
+        kaldi::Vector<kaldi::BaseFloat>* inference) const = 0;
+};
+
+}  // namespace ppspeech