frontend itf

pull/1640/head
Hui Zhang 3 years ago
parent 36df70cbe6
commit a9f4ce47a3

@ -20,7 +20,7 @@
#include "frontend/audio_cache.h" #include "frontend/audio_cache.h"
#include "frontend/data_cache.h" #include "frontend/data_cache.h"
#include "frontend/feature_cache.h" #include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/frontend_itf.h"
#include "frontend/normalizer.h" #include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h" #include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
@ -170,13 +170,13 @@ int main(int argc, char* argv[]) {
// feature pipeline: wave cache --> decibel_normalizer --> hanning // feature pipeline: wave cache --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> global cmvn -> feat cache // window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new // std::unique_ptr<ppspeech::FrontendInterface> data_source(new
// ppspeech::DataCache()); // ppspeech::DataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source( std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache()); new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt; ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm( std::unique_ptr<ppspeech::FrontendInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source))); new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
ppspeech::LinearSpectrogramOptions opt; ppspeech::LinearSpectrogramOptions opt;
@ -185,10 +185,10 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms; LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms; LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram( std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm))); new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn( std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_write_path, new ppspeech::CMVN(FLAGS_cmvn_write_path,
std::move(linear_spectrogram))); std::move(linear_spectrogram)));

@ -16,12 +16,12 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/frontend_itf.h"
namespace ppspeech { namespace ppspeech {
// waves cache // waves cache
class AudioCache : public FeatureExtractorInterface { class AudioCache : public FrontendInterface {
public: public:
explicit AudioCache(int buffer_size = kint16max); explicit AudioCache(int buffer_size = kint16max);

@ -17,13 +17,13 @@
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/frontend_itf.h"
namespace ppspeech { namespace ppspeech {
// A data source for testing different frontend module. // A data source for testing different frontend module.
// It accepts waves or feats. // It accepts waves or feats.
class DataCache : public FeatureExtractorInterface { class DataCache : public FrontendInterface {
public: public:
explicit DataCache() { finished_ = false; } explicit DataCache() { finished_ = false; }

@ -20,10 +20,10 @@
namespace ppspeech { namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface { class FbankExtractor : FrontendInterface {
public: public:
explicit FbankExtractor(const FbankOptions& opts, explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor); share_ptr<FrontendInterface> pre_extractor);
virtual void AcceptWaveform( virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0; const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;

@ -24,7 +24,7 @@ using kaldi::SubVector;
using std::unique_ptr; using std::unique_ptr;
FeatureCache::FeatureCache( FeatureCache::FeatureCache(
int max_size, unique_ptr<FeatureExtractorInterface> base_extractor) { int max_size, unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size; max_size_ = max_size;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
} }

@ -15,15 +15,15 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/frontend_itf.h"
namespace ppspeech { namespace ppspeech {
class FeatureCache : public FeatureExtractorInterface { class FeatureCache : public FrontendInterface {
public: public:
explicit FeatureCache( explicit FeatureCache(
int32 max_size = kint16max, int32 max_size = kint16max,
std::unique_ptr<FeatureExtractorInterface> base_extractor = NULL); std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves // Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs); virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
@ -53,7 +53,7 @@ class FeatureCache : public FeatureExtractorInterface {
bool Compute(); bool Compute();
size_t max_size_; size_t max_size_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_; std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_; std::queue<kaldi::Vector<BaseFloat>> cache_;

@ -1,13 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -1,13 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -19,7 +19,7 @@
namespace ppspeech { namespace ppspeech {
class FeatureExtractorInterface { class FrontendInterface {
public: public:
// Feed inputs: features(2D saved in 1D) or waveforms(1D). // Feed inputs: features(2D saved in 1D) or waveforms(1D).
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0; virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;

@ -27,7 +27,7 @@ using std::vector;
LinearSpectrogram::LinearSpectrogram( LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts, const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) { std::unique_ptr<FrontendInterface> base_extractor) {
opts_ = opts; opts_ = opts;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize(); int32 window_size = opts.frame_opts.WindowSize();

@ -16,7 +16,7 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/frontend_itf.h"
#include "kaldi/feat/feature-window.h" #include "kaldi/feat/feature-window.h"
namespace ppspeech { namespace ppspeech {
@ -35,11 +35,11 @@ struct LinearSpectrogramOptions {
} }
}; };
class LinearSpectrogram : public FeatureExtractorInterface { class LinearSpectrogram : public FrontendInterface {
public: public:
explicit LinearSpectrogram( explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts, const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor); std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs); virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats); virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the dim of single frame feature // the dim_ is the dim of single frame feature
@ -61,7 +61,7 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_; std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_; kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_; LinearSpectrogramOptions opts_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
int chunk_sample_size_; int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
}; };

@ -28,7 +28,7 @@ using std::unique_ptr;
DecibelNormalizer::DecibelNormalizer( DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts, const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) { std::unique_ptr<FrontendInterface> base_extractor) {
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
opts_ = opts; opts_ = opts;
dim_ = 1; dim_ = 1;
@ -92,7 +92,7 @@ bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
} }
CMVN::CMVN(std::string cmvn_file, CMVN::CMVN(std::string cmvn_file,
unique_ptr<FeatureExtractorInterface> base_extractor) unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) { : var_norm_(true) {
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
bool binary; bool binary;

@ -16,7 +16,7 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h" #include "kaldi/util/options-itf.h"
@ -40,11 +40,11 @@ struct DecibelNormalizerOptions {
} }
}; };
class DecibelNormalizer : public FeatureExtractorInterface { class DecibelNormalizer : public FrontendInterface {
public: public:
explicit DecibelNormalizer( explicit DecibelNormalizer(
const DecibelNormalizerOptions& opts, const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor); std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves); virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves); virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// noramlize audio, the dim is 1. // noramlize audio, the dim is 1.
@ -57,15 +57,15 @@ class DecibelNormalizer : public FeatureExtractorInterface {
bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const; bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
DecibelNormalizerOptions opts_; DecibelNormalizerOptions opts_;
size_t dim_; size_t dim_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> waveform_; kaldi::Vector<kaldi::BaseFloat> waveform_;
}; };
class CMVN : public FeatureExtractorInterface { class CMVN : public FrontendInterface {
public: public:
explicit CMVN(std::string cmvn_file, explicit CMVN(std::string cmvn_file,
std::unique_ptr<FeatureExtractorInterface> base_extractor); std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs); virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// the length of feats = feature_row * feature_dim, // the length of feats = feature_row * feature_dim,
@ -81,7 +81,7 @@ class CMVN : public FeatureExtractorInterface {
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const; void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats); void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
kaldi::Matrix<double> stats_; kaldi::Matrix<double> stats_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
size_t dim_; size_t dim_;
bool var_norm_; bool var_norm_;
}; };

@ -22,7 +22,7 @@ using std::vector;
using kaldi::Vector; using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet, Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend) const std::shared_ptr<FrontendInterface>& frontend)
: frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {} : frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h" #include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h" #include "nnet/nnet_interface.h"
@ -26,7 +26,7 @@ class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable( explicit Decodable(
const std::shared_ptr<NnetInterface>& nnet, const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend); const std::shared_ptr<FrontendInterface>& frontend);
// void Init(DecodableOpts config); // void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const; virtual bool IsLastFrame(int32 frame) const;
@ -41,7 +41,7 @@ class Decodable : public kaldi::DecodableInterface {
private: private:
bool AdvanceChunk(); bool AdvanceChunk();
std::shared_ptr<FeatureExtractorInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_; kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_; // std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;

Loading…
Cancel
Save