Merge pull request #1542 from SmileGoat/stream_feature

[speechx]add raw_audio & feature_cache
pull/1559/head
Hui Zhang 2 years ago committed by GitHub
commit bedd2de46c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,6 @@ add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear-spectrogram-main ${CMAKE_CURRENT_SOURCE_DIR}/linear-spectrogram-main.cc)
target_include_directories(linear-spectrogram-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear-spectrogram-main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)

@ -14,19 +14,20 @@
// todo refactor, repalce with gtest
#include "frontend/linear_spectrogram.h"
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/raw_audio.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
std::vector<float> mean_{
@ -158,38 +159,37 @@ int main(int argc, char* argv[]) {
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
//std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
//ppspeech::RawDataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
new ppspeech::RawAudioCache());
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(
opt, std::move(base_feature_extractor));
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt,
std::move(base_feature_extractor)));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_write_path,
std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
@ -199,54 +199,45 @@ int main(int argc, char* argv[]) {
this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
kaldi::Vector<BaseFloat> features;
feature_cache.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished();
}
feature_cache.Read(&features);
if (features.Dim() == 0) break;
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
feature_rows += features.Dim() / feature_cache.Dim();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feats[0].NumCols());
feature_cache.Dim());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
int num_rows = feat.Dim() / feature_cache.Dim();
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
++col_idx) {
features(cur_idx, col_idx) =
(feat(row_idx, col_idx) - mean_[col_idx]) *
variance_[col_idx];
feat(row_idx * feature_cache.Dim() + col_idx);
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;

@ -14,6 +14,7 @@
#pragma once
#include <condition_variable>
#include <deque>
#include <fstream>
#include <iostream>
@ -22,6 +23,7 @@
#include <memory>
#include <mutex>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>

@ -2,7 +2,9 @@ project(frontend)
add_library(frontend STATIC
normalizer.cc
linear_spectrogram.cc
linear_spectrogram.cc
raw_audio.cc
feature_cache.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
target_link_libraries(frontend PUBLIC kaldi-matrix)

@ -0,0 +1,84 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/feature_cache.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
FeatureCache::FeatureCache(
int max_size, unique_ptr<FeatureExtractorInterface> base_extractor) {
max_size_ = max_size;
base_extractor_ = std::move(base_extractor);
}
void FeatureCache::Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
base_extractor_->Accept(inputs);
// feed current data
bool result = false;
do {
result = Compute();
} while (result);
}
// pop feature chunk
bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) {
ready_read_condition_.wait(lock);
BaseFloat elapsed = timer.Elapsed() * 1000;
// todo replace 1.0 with timeout_
if (elapsed > 1.0) {
return false;
}
usleep(1000); // sleep 1 ms
}
if (cache_.empty()) return false;
feats->Resize(cache_.front().Dim());
feats->CopyFromVec(cache_.front());
cache_.pop();
ready_feed_condition_.notify_one();
return true;
}
// read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() {
// compute and feed
Vector<BaseFloat> feature_chunk;
bool result = base_extractor_->Read(&feature_chunk);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
if (feature_chunk.Dim() != 0) {
cache_.push(feature_chunk);
}
ready_read_condition_.notify_one();
return result;
}
void Reset() {
// std::lock_guard<std::mutex> lock(mutex_);
return;
}
} // namespace ppspeech

@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
namespace ppspeech {
class FeatureCache : public FeatureExtractorInterface {
public:
explicit FeatureCache(
int32 max_size = kint16max,
std::unique_ptr<FeatureExtractorInterface> base_extractor = NULL);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// feats dim = num_frames * feature_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feature cache only cache feature which from base extractor
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() {
base_extractor_->SetFinished();
// read the last chunk data
Compute();
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
private:
bool Compute();
bool finished_;
std::mutex mutex_;
size_t max_size_;
std::queue<kaldi::Vector<BaseFloat>> cache_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_;
//DISALLOW_COPY_AND_ASSGIN(FeatureCache);
};
} // namespace ppspeech

@ -21,10 +21,19 @@ namespace ppspeech {
class FeatureExtractorInterface {
public:
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
// accept input data, accept feature or raw waves which decided
// by the base_extractor
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;
// get the processed result
// the length of output = feature_row * feature_dim,
// the Matrix is squashed into Vector
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* outputs) = 0;
// the Dim is the feature dim
virtual size_t Dim() const = 0;
virtual void SetFinished() = 0;
virtual bool IsFinished() const = 0;
// virtual void Reset();
};
} // namespace ppspeech
} // namespace ppspeech

@ -52,6 +52,8 @@ LinearSpectrogram::LinearSpectrogram(
int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size;
chunk_sample_size_ =
static_cast<int32>(opts.streaming_chunk * opts.frame_opts.samp_freq);
hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1);
@ -64,8 +66,29 @@ LinearSpectrogram::LinearSpectrogram(
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
}
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
base_extractor_->AcceptWaveform(input);
void LinearSpectrogram::Accept(const VectorBase<BaseFloat>& inputs) {
base_extractor_->Accept(inputs);
}
bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
Vector<BaseFloat> input_feats(chunk_sample_size_);
bool flag = base_extractor_->Read(&input_feats);
if (flag == false || input_feats.Dim() == 0) return false;
vector<BaseFloat> input_feats_vec(input_feats.Dim());
CopyVector2StdVector_(input_feats, &input_feats_vec);
vector<vector<BaseFloat>> result;
Compute(input_feats_vec, result);
int32 feat_size = 0;
if (result.size() != 0) {
feat_size = result.size() * result[0].size();
}
feats->Resize(feat_size);
// todo refactor (SimleGoat)
for (size_t idx = 0; idx < feat_size; ++idx) {
(*feats)(idx) = result[idx / dim_][idx % dim_];
}
return true;
}
void LinearSpectrogram::Hanning(vector<float>* data) const {
@ -95,41 +118,11 @@ bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
return true;
}
// todo remove later
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Vector<BaseFloat> tmp;
waveform_.Resize(base_extractor_->Dim());
Compute(tmp, &waveform_);
vector<vector<BaseFloat>> result;
vector<BaseFloat> feats_vec;
CopyVector2StdVector_(waveform_, &feats_vec);
Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx];
}
}
waveform_.Resize(0);
}
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
// todo
return;
}
// only for test, remove later
// todo: compute the feature frame by frame.
void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
VectorBase<kaldi::BaseFloat>* feature) {
base_extractor_->Read(feature);
}
// Compute spectrogram feat, only for test, remove later
// Compute spectrogram feat
// todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& wave,
vector<vector<float>>& feat) {
int num_samples = wave.size();
bool LinearSpectrogram::Compute(const vector<float>& waves,
vector<vector<float>>& feats) {
int num_samples = waves.size();
const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift();
@ -141,34 +134,34 @@ bool LinearSpectrogram::Compute(const vector<float>& wave,
}
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feat.resize(num_frames);
feats.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) {
vector<float> data(wave.data() + i * frame_shift,
wave.data() + i * frame_shift + frame_length);
vector<float> data(waves.data() + i * frame_shift,
waves.data() + i * frame_shift + frame_length);
Hanning(&data);
fft_img.clear();
fft_real.clear();
v.assign(data.begin(), data.end());
NumpyFft(&v, &fft_real, &fft_img);
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
feats[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feat[i][j] = power[j];
feats[i][j] = power[j];
if (j == 0 || j == feat[0].size() - 1) {
feat[i][j] /= scale;
if (j == 0 || j == feats[0].size() - 1) {
feats[i][j] /= scale;
} else {
feat[i][j] *= (2.0 / scale);
feats[i][j] *= (2.0 / scale);
}
// log added eps=1e-14
feat[i][j] = std::log(feat[i][j] + 1e-14);
feats[i][j] = std::log(feats[i][j] + 1e-14);
}
}
return true;

@ -23,9 +23,14 @@ namespace ppspeech {
struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions() : frame_opts() {}
kaldi::BaseFloat streaming_chunk;
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
void Register(kaldi::OptionsItf* opts) {
opts->Register(
"streaming-chunk", &streaming_chunk, "streaming chunk size");
frame_opts.Register(opts);
}
};
class LinearSpectrogram : public FeatureExtractorInterface {
@ -33,18 +38,18 @@ class LinearSpectrogram : public FeatureExtractorInterface {
explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the dim of single frame feature
virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
private:
void Hanning(std::vector<kaldi::BaseFloat>* data) const;
bool Compute(const std::vector<kaldi::BaseFloat>& wave,
std::vector<std::vector<kaldi::BaseFloat>>& feat);
void Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feature);
bool Compute(const std::vector<kaldi::BaseFloat>& waves,
std::vector<std::vector<kaldi::BaseFloat>>& feats);
bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
std::vector<kaldi::BaseFloat>* real,
std::vector<kaldi::BaseFloat>* img) const;
@ -54,8 +59,8 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};

@ -24,22 +24,28 @@ using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
opts_ = opts;
dim_ = 0;
dim_ = 1;
}
void DecibelNormalizer::AcceptWaveform(
const kaldi::VectorBase<BaseFloat>& input) {
dim_ = input.Dim();
waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
void DecibelNormalizer::Accept(
const kaldi::VectorBase<BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
if (waveform_.Dim() == 0) return;
Compute(waveform_, feat);
bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
if (base_extractor_->Read(waves) == false ||
waves->Dim() == 0) {
return false;
}
Compute(waves);
return true;
}
// todo remove later
@ -61,8 +67,7 @@ void CopyStdVector2Vector(const vector<BaseFloat>& input,
}
}
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const {
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
@ -70,9 +75,9 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(input.Dim());
for (int32 i = 0; i < samples.size(); ++i) {
samples[i] = input(i);
samples.resize(waves->Dim());
for (size_t i = 0; i < samples.size(); ++i) {
samples[i] = (*waves)(i);
}
// square
@ -102,24 +107,35 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, feat);
CopyStdVector2Vector(samples, waves);
return true;
}
CMVN::CMVN(std::string cmvn_file) : var_norm_(true) {
CMVN::CMVN(std::string cmvn_file,
unique_ptr<FeatureExtractorInterface> base_extractor)
: var_norm_(true) {
base_extractor_ = std::move(base_extractor);
bool binary;
kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary);
dim_ = stats_.NumCols() - 1;
}
void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
void CMVN::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
base_extractor_->Accept(inputs);
return;
}
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { return; }
bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
if (base_extractor_->Read(feats) == false) {
return false;
}
Compute(feats);
return true;
}
// feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
@ -127,7 +143,7 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x';
}
if (stats_.NumRows() == 1 && var_norm) {
if (stats_.NumRows() == 1 && var_norm_) {
KALDI_ERR
<< "You requested variance normalization but no variance stats_ "
<< "are supplied.";
@ -141,7 +157,7 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
"normalization: "
<< "count = " << count;
if (!var_norm) {
if (!var_norm_) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
@ -185,14 +201,8 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
feats->AddVec(1.0, norm.Row(0));
}
void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm, feats);
void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm_, feats);
}
bool CMVN::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const {
return false;
}
} // namespace ppspeech

@ -42,15 +42,19 @@ struct DecibelNormalizerOptions {
class DecibelNormalizer : public FeatureExtractorInterface {
public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
explicit DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// noramlize audio, the dim is 1.
virtual size_t Dim() const { return dim_; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
private:
bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
DecibelNormalizerOptions opts_;
size_t dim_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
@ -60,20 +64,24 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface {
public:
explicit CMVN(std::string cmvn_file);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return stats_.NumCols() - 1; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
// for test
void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);
explicit CMVN(std::string cmvn_file,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// the length of feats = feature_row * feature_dim,
// the Matrix is squashed into Vector
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the feautre dim.
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
private:
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
kaldi::Matrix<double> stats_;
std::shared_ptr<FeatureExtractorInterface> base_extractor_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
size_t dim_;
bool var_norm_;
};

@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/raw_audio.h"
#include "kaldi/base/timer.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::VectorBase;
using kaldi::Vector;
RawAudioCache::RawAudioCache(int buffer_size)
: finished_(false), data_length_(0), start_(0), timeout_(1) {
ring_buffer_.resize(buffer_size);
}
void RawAudioCache::Accept(const VectorBase<BaseFloat>& waves) {
std::unique_lock<std::mutex> lock(mutex_);
while (data_length_ + waves.Dim() > ring_buffer_.size()) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + start_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
}
data_length_ += waves.Dim();
}
bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
size_t chunk_size = waves->Dim();
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > data_length_) {
// when audio is empty and no more data feed
// ready_read_condition will block in dead lock. so replace with timeout_
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
if (elapsed > timeout_) {
if (finished_ == true) { // read last chunk data
break;
}
if (chunk_size > data_length_) {
return false;
}
}
usleep(100); // sleep 0.1 ms
}
// read last chunk data
if (chunk_size > data_length_) {
chunk_size = data_length_;
waves->Resize(chunk_size);
}
for (size_t idx = 0; idx < chunk_size; ++idx) {
int buff_idx = (start_ + idx) % ring_buffer_.size();
waves->Data()[idx] = ring_buffer_[buff_idx];
}
data_length_ -= chunk_size;
start_ = (start_ + chunk_size) % ring_buffer_.size();
ready_feed_condition_.notify_one();
return true;
}
} // namespace ppspeech

@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
namespace ppspeech {
class RawAudioCache : public FeatureExtractorInterface {
public:
explicit RawAudioCache(int buffer_size = kint16max);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// the audio dim is 1
virtual size_t Dim() const { return 1; }
virtual void SetFinished() {
std::lock_guard<std::mutex> lock(mutex_);
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
private:
std::vector<kaldi::BaseFloat> ring_buffer_;
size_t start_;
size_t data_length_;
bool finished_;
mutable std::mutex mutex_;
std::condition_variable ready_feed_condition_;
kaldi::int32 timeout_;
DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
};
// it is a data source to test different frontend module.
// it Accepts waves or feats.
class RawDataCache: public FeatureExtractorInterface {
public:
explicit RawDataCache() { finished_ = false; }
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
data_ = inputs;
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
if (data_.Dim() == 0) {
return false;
}
(*feats) = data_;
data_.Resize(0);
return true;
}
//the dim is data_ length
virtual size_t Dim() const { return data_.Dim(); }
virtual void SetFinished() { finished_ = true; }
virtual bool IsFinished() const { return finished_; }
private:
kaldi::Vector<kaldi::BaseFloat> data_;
bool finished_;
DISALLOW_COPY_AND_ASSIGN(RawDataCache);
};
} // namespace ppspeech
Loading…
Cancel
Save