frontend itf

pull/1640/head
Hui Zhang 2 years ago
parent 36df70cbe6
commit a9f4ce47a3

@ -20,7 +20,7 @@
#include "frontend/audio_cache.h"
#include "frontend/data_cache.h"
#include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/frontend_itf.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
@ -170,13 +170,13 @@ int main(int argc, char* argv[]) {
// feature pipeline: wave cache --> decibel_normalizer --> hanning
// window --> linear_spectrogram --> global cmvn --> feat cache
// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
// std::unique_ptr<ppspeech::FrontendInterface> data_source(new
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm(
std::unique_ptr<ppspeech::FrontendInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
ppspeech::LinearSpectrogramOptions opt;
@ -185,10 +185,10 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_write_path,
std::move(linear_spectrogram)));

@ -16,12 +16,12 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/frontend_itf.h"
namespace ppspeech {
// waves cache
class AudioCache : public FeatureExtractorInterface {
class AudioCache : public FrontendInterface {
public:
explicit AudioCache(int buffer_size = kint16max);

@ -17,13 +17,13 @@
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/frontend_itf.h"
namespace ppspeech {
// A data source for testing different frontend modules.
// It accepts waves or feats.
class DataCache : public FeatureExtractorInterface {
class DataCache : public FrontendInterface {
public:
explicit DataCache() { finished_ = false; }

@ -20,10 +20,10 @@
namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface {
class FbankExtractor : FrontendInterface {
public:
explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor);
share_ptr<FrontendInterface> pre_extractor);
virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;

@ -24,7 +24,7 @@ using kaldi::SubVector;
using std::unique_ptr;
FeatureCache::FeatureCache(
int max_size, unique_ptr<FeatureExtractorInterface> base_extractor) {
int max_size, unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size;
base_extractor_ = std::move(base_extractor);
}

@ -15,15 +15,15 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/frontend_itf.h"
namespace ppspeech {
class FeatureCache : public FeatureExtractorInterface {
class FeatureCache : public FrontendInterface {
public:
explicit FeatureCache(
int32 max_size = kint16max,
std::unique_ptr<FeatureExtractorInterface> base_extractor = NULL);
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
@ -53,7 +53,7 @@ class FeatureCache : public FeatureExtractorInterface {
bool Compute();
size_t max_size_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_;

@ -1,13 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -1,13 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -19,7 +19,7 @@
namespace ppspeech {
class FeatureExtractorInterface {
class FrontendInterface {
public:
// Feed inputs: features(2D saved in 1D) or waveforms(1D).
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;

@ -27,7 +27,7 @@ using std::vector;
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
std::unique_ptr<FrontendInterface> base_extractor) {
opts_ = opts;
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();

@ -16,7 +16,7 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/frontend_itf.h"
#include "kaldi/feat/feature-window.h"
namespace ppspeech {
@ -35,11 +35,11 @@ struct LinearSpectrogramOptions {
}
};
class LinearSpectrogram : public FeatureExtractorInterface {
class LinearSpectrogram : public FrontendInterface {
public:
explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// dim_ is the dimension of a single frame's feature
@ -61,7 +61,7 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FrontendInterface> base_extractor_;
int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};

@ -28,7 +28,7 @@ using std::unique_ptr;
DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
std::unique_ptr<FrontendInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
opts_ = opts;
dim_ = 1;
@ -92,7 +92,7 @@ bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
}
CMVN::CMVN(std::string cmvn_file,
unique_ptr<FeatureExtractorInterface> base_extractor)
unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) {
base_extractor_ = std::move(base_extractor);
bool binary;

@ -16,7 +16,7 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
@ -40,11 +40,11 @@ struct DecibelNormalizerOptions {
}
};
class DecibelNormalizer : public FeatureExtractorInterface {
class DecibelNormalizer : public FrontendInterface {
public:
explicit DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// normalize audio, the dim is 1.
@ -57,15 +57,15 @@ class DecibelNormalizer : public FeatureExtractorInterface {
bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
DecibelNormalizerOptions opts_;
size_t dim_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> waveform_;
};
class CMVN : public FeatureExtractorInterface {
class CMVN : public FrontendInterface {
public:
explicit CMVN(std::string cmvn_file,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// the length of feats = feature_row * feature_dim,
@ -81,7 +81,7 @@ class CMVN : public FeatureExtractorInterface {
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
kaldi::Matrix<double> stats_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FrontendInterface> base_extractor_;
size_t dim_;
bool var_norm_;
};

@ -22,7 +22,7 @@ using std::vector;
using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend)
const std::shared_ptr<FrontendInterface>& frontend)
: frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {

@ -13,7 +13,7 @@
// limitations under the License.
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h"
@ -26,7 +26,7 @@ class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(
const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend);
const std::shared_ptr<FrontendInterface>& frontend);
// void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
@ -41,7 +41,7 @@ class Decodable : public kaldi::DecodableInterface {
private:
bool AdvanceChunk();
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;

Loading…
Cancel
Save