Merge pull request #1640 from zh794390558/frontend

[speechx] Frontend refactor
pull/1641/head
Hui Zhang 3 years ago committed by GitHub
commit 2e94e0f699
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,7 +17,7 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/data_cache.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

@ -17,7 +17,7 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/data_cache.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

@ -14,19 +14,18 @@
// todo refactor, repalce with gtest
#include "frontend/linear_spectrogram.h"
#include "base/flags.h"
#include "base/log.h"
#include "frontend/audio_cache.h"
#include "frontend/data_cache.h"
#include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include <glog/logging.h>
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
@ -170,13 +169,13 @@ int main(int argc, char* argv[]) {
// feature pipeline: wave cache --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
// std::unique_ptr<ppspeech::FrontendInterface> data_source(new
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm(
std::unique_ptr<ppspeech::FrontendInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
ppspeech::LinearSpectrogramOptions opt;
@ -185,12 +184,11 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_write_path,
std::move(linear_spectrogram)));
std::unique_ptr<ppspeech::FrontendInterface> cmvn(new ppspeech::CMVN(
FLAGS_cmvn_write_path, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();

@ -1,10 +1,2 @@
project(frontend)
add_library(frontend STATIC
normalizer.cc
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
add_subdirectory(audio)

@ -0,0 +1,11 @@
project(frontend)
add_library(frontend STATIC
cmvn.cc
db_norm.cc
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio_cache.h"
#include "frontend/audio/audio_cache.h"
#include "kaldi/base/timer.h"
namespace ppspeech {

@ -16,12 +16,12 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"
namespace ppspeech {
// waves cache
class AudioCache : public FeatureExtractorInterface {
class AudioCache : public FrontendInterface {
public:
explicit AudioCache(int buffer_size = kint16max);

@ -13,7 +13,7 @@
// limitations under the License.
#include "frontend/normalizer.h"
#include "frontend/audio/cmvn.h"
#include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h"
@ -26,73 +26,8 @@ using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
opts_ = opts;
dim_ = 1;
}
void DecibelNormalizer::Accept(const kaldi::VectorBase<BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
if (base_extractor_->Read(waves) == false || waves->Dim() == 0) {
return false;
}
Compute(waves);
return true;
}
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(waves->Dim());
for (size_t i = 0; i < samples.size(); ++i) {
samples[i] = (*waves)(i);
}
// square
for (auto& d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
}
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR)
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
std::memcpy(
waves->Data(), samples.data(), sizeof(BaseFloat) * samples.size());
return true;
}
CMVN::CMVN(std::string cmvn_file,
unique_ptr<FeatureExtractorInterface> base_extractor)
CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) {
base_extractor_ = std::move(base_extractor);
bool binary;
@ -185,4 +120,4 @@ void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm_, feats);
}
} // namespace ppspeech
} // namespace ppspeech

@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
class CMVN : public FrontendInterface {
public:
explicit CMVN(std::string cmvn_file,
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// the length of feats = feature_row * feature_dim,
// the Matrix is squashed into Vector
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the feautre dim.
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
private:
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
kaldi::Matrix<double> stats_;
std::unique_ptr<FrontendInterface> base_extractor_;
size_t dim_;
bool var_norm_;
};
} // namespace ppspeech

@ -17,13 +17,13 @@
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"
namespace ppspeech {
// A data source for testing different frontend module.
// It accepts waves or feats.
class DataCache : public FeatureExtractorInterface {
class DataCache : public FrontendInterface {
public:
explicit DataCache() { finished_ = false; }

@ -0,0 +1,95 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/db_norm.h"
#include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FrontendInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
opts_ = opts;
dim_ = 1;
}
void DecibelNormalizer::Accept(const kaldi::VectorBase<BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
if (base_extractor_->Read(waves) == false || waves->Dim() == 0) {
return false;
}
Compute(waves);
return true;
}
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(waves->Dim());
for (size_t i = 0; i < samples.size(); ++i) {
samples[i] = (*waves)(i);
}
// square
for (auto& d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
}
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR)
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
std::memcpy(
waves->Data(), samples.data(), sizeof(BaseFloat) * samples.size());
return true;
}
} // namespace ppspeech

@ -16,7 +16,7 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
@ -40,11 +40,11 @@ struct DecibelNormalizerOptions {
}
};
class DecibelNormalizer : public FeatureExtractorInterface {
class DecibelNormalizer : public FrontendInterface {
public:
explicit DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// noramlize audio, the dim is 1.
@ -57,33 +57,9 @@ class DecibelNormalizer : public FeatureExtractorInterface {
bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
DecibelNormalizerOptions opts_;
size_t dim_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> waveform_;
};
class CMVN : public FeatureExtractorInterface {
public:
explicit CMVN(std::string cmvn_file,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// the length of feats = feature_row * feature_dim,
// the Matrix is squashed into Vector
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the feautre dim.
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
private:
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
kaldi::Matrix<double> stats_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
size_t dim_;
bool var_norm_;
};
} // namespace ppspeech

@ -20,10 +20,10 @@
namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface {
class FbankExtractor : FrontendInterface {
public:
explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor);
share_ptr<FrontendInterface> pre_extractor);
virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/feature_cache.h"
#include "frontend/audio/feature_cache.h"
namespace ppspeech {
@ -23,8 +23,8 @@ using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
FeatureCache::FeatureCache(
int max_size, unique_ptr<FeatureExtractorInterface> base_extractor) {
FeatureCache::FeatureCache(int max_size,
unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size;
base_extractor_ = std::move(base_extractor);
}

@ -15,15 +15,15 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"
namespace ppspeech {
class FeatureCache : public FeatureExtractorInterface {
class FeatureCache : public FrontendInterface {
public:
explicit FeatureCache(
int32 max_size = kint16max,
std::unique_ptr<FeatureExtractorInterface> base_extractor = NULL);
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
@ -53,7 +53,7 @@ class FeatureCache : public FeatureExtractorInterface {
bool Compute();
size_t max_size_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_;

@ -19,7 +19,7 @@
namespace ppspeech {
class FeatureExtractorInterface {
class FrontendInterface {
public:
// Feed inputs: features(2D saved in 1D) or waveforms(1D).
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/linear_spectrogram.h"
#include "frontend/audio/linear_spectrogram.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/matrix/matrix-functions.h"
@ -27,7 +27,7 @@ using std::vector;
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
std::unique_ptr<FrontendInterface> base_extractor) {
opts_ = opts;
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();

@ -16,7 +16,7 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/feat/feature-window.h"
namespace ppspeech {
@ -35,11 +35,11 @@ struct LinearSpectrogramOptions {
}
};
class LinearSpectrogram : public FeatureExtractorInterface {
class LinearSpectrogram : public FrontendInterface {
public:
explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the dim of single frame feature
@ -61,7 +61,7 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FrontendInterface> base_extractor_;
int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};

@ -12,4 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// extract the window of kaldi feat.
#pragma once
#include "frontend/audio/cmvn.h"
#include "frontend/audio/db_norm.h"

@ -1,13 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -1,13 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -22,7 +22,7 @@ using std::vector;
using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend)
const std::shared_ptr<FrontendInterface>& frontend)
: frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {

@ -13,7 +13,7 @@
// limitations under the License.
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h"
@ -24,9 +24,8 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(
const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend);
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend);
// void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
@ -41,7 +40,7 @@ class Decodable : public kaldi::DecodableInterface {
private:
bool AdvanceChunk();
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;

@ -15,13 +15,14 @@
#pragma once
#include "base/common.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
#include "base/common.h"
#include "nnet/nnet_itf.h"
#include "paddle_inference_api.h"
#include <numeric>
namespace ppspeech {

Loading…
Cancel
Save