[speechx]add kaldi-native-fbank && refactor frontend (#2794)

* replace kaldi-fbank with kaldi-native-fbank

* make kaldi-native-fbank work
pull/2854/head
YangZhou 2 years ago committed by GitHub
parent acf1d27230
commit c1b1ae0515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,21 +19,12 @@ aishell_wav_scp=aishell_test.scp
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_fbank_main \
--num_bins 80 \
--cmvn_file=$exp/cmvn.ark \
--cmvn_file=$model_dir/mean_std.json \
--streaming_chunk=36 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp

@ -19,7 +19,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--cmvn_file=$model_dir/mean_std.json \
--model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \

@ -19,7 +19,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.quant.log \
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--cmvn_file=$model_dir/mean_std.json \
--model_path=$model_dir/export \
--vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \

@ -22,7 +22,6 @@ if [ ! -d ${SPEECHX_BUILD} ]; then
popd
fi
ckpt_dir=$data/model
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
@ -72,7 +71,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# process cmvn and compute fbank feat
# process compute fbank feat
./local/feat.sh
fi

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "base/common.h"
#include "frontend/audio/data_cache.h"
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h"
@ -124,15 +124,14 @@ int main(int argc, char* argv[]) {
}
kaldi::Vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
std::vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim);
feature_chunk_row.CopyFromVec(feat_row);
std::memcpy(feature_chunk.data() + row_id * feat_dim,
feat_row.Data(),
feat_dim * sizeof(kaldi::BaseFloat));
++start;
}

@ -71,7 +71,7 @@ struct ModelOptions {
struct NnetOut {
// nnet out. maybe logprob or prob. Almost time this is logprob.
kaldi::Vector<kaldi::BaseFloat> logprobs;
std::vector<kaldi::BaseFloat> logprobs;
int32 vocab_dim;
// nnet state. Only using in Attention model.
@ -89,7 +89,7 @@ class NnetInterface {
// nnet do not cache feats, feats cached by frontend.
// nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset.
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
virtual void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) = 0;
@ -105,7 +105,7 @@ class NnetInterface {
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};

@ -17,13 +17,14 @@
namespace ppspeech {
using kaldi::Vector;
using std::vector;
using kaldi::BaseFloat;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend)
: nnet_(nnet), frontend_(frontend) {}
void NnetProducer::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
bool result = false;
do {
@ -49,26 +50,24 @@ bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
}
bool NnetProducer::Compute() {
Vector<BaseFloat> features;
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(3) << "no feat avalible";
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats.";
VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
Vector<BaseFloat>& logprobs = out.logprobs;
size_t nframes = logprobs.Dim() / vocab_dim;
size_t nframes = out.logprobs.size() / vocab_dim;
VLOG(2) << "Forward out " << nframes << " decoder frames.";
std::vector<BaseFloat> logprob(vocab_dim);
for (size_t idx = 0; idx < nframes; ++idx) {
for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) {
logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx);
}
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob);
}
return true;
@ -80,4 +79,4 @@ void NnetProducer::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
}
} // namespace ppspeech
} // namespace ppspeech

@ -27,7 +27,7 @@ class NnetProducer {
std::shared_ptr<FrontendInterface> frontend = NULL);
// Feed feats or waves
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
void Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood);

@ -165,23 +165,16 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) {
}
void U2Nnet::FeedForward(const kaldi::Vector<BaseFloat>& features,
void U2Nnet::FeedForward(const std::vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
kaldi::Timer timer;
std::vector<kaldi::BaseFloat> chunk_feats(features.Data(),
features.Data() + features.Dim());
std::vector<kaldi::BaseFloat> ctc_probs;
ForwardEncoderChunkImpl(
chunk_feats, feature_dim, &ctc_probs, &out->vocab_dim);
out->logprobs.Resize(ctc_probs.size(), kaldi::kSetZero);
std::memcpy(out->logprobs.Data(),
ctc_probs.data(),
ctc_probs.size() * sizeof(kaldi::BaseFloat));
features, feature_dim, &out->logprobs, &out->vocab_dim);
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
<< chunk_feats.size() / feature_dim << " frames.";
<< features.size() / feature_dim << " frames.";
}
@ -638,7 +631,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
void U2Nnet::EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D)
int size = encoder_outs_.size();
VLOG(3) << "encoder_outs_ size: " << size;
@ -657,8 +650,8 @@ void U2Nnet::EncoderOuts(
const float* this_tensor_ptr = item.data<float>();
for (int j = 0; j < T; j++) {
const float* cur = this_tensor_ptr + j * D;
kaldi::Vector<kaldi::BaseFloat> out(D);
std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat));
std::vector<kaldi::BaseFloat> out(D);
std::memcpy(out.data(), cur, D * sizeof(kaldi::BaseFloat));
encoder_out->emplace_back(out);
}
}

@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase {
explicit U2Nnet(const ModelOptions& opts);
U2Nnet(const U2Nnet& other);
void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) override;
@ -111,7 +111,7 @@ class U2Nnet : public U2NnetBase {
void FeedEncoderOuts(const paddle::Tensor& encoder_out);
void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const;
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
private:
ModelOptions opts_;

@ -15,8 +15,8 @@ set(TEST_BINS
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-feat-common)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach()
endforeach()

@ -19,9 +19,6 @@
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
@ -67,10 +64,10 @@ void U2Recognizer::ResetContinuousDecoding() {
}
void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
kaldi::Timer timer;
nnet_producer_->Accept(waves);
VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim()
VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.size()
<< " samples.";
}

@ -115,7 +115,7 @@ class U2Recognizer {
void Reset();
void ResetContinuousDecoding();
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
void Accept(const std::vector<kaldi::BaseFloat>& waves);
void Decode();
void Rescoring();

@ -71,9 +71,9 @@ int main(int argc, char* argv[]) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
wav_chunk[i] = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);

@ -81,9 +81,9 @@ int main(int argc, char* argv[]) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
wav_chunk[i] = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);

@ -1,16 +1,10 @@
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/base
)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/../
${CMAKE_CURRENT_SOURCE_DIR}/utils
)
add_subdirectory(utils)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/frontend
)
add_subdirectory(frontend)

@ -1,29 +1,27 @@
add_library(kaldi-native-fbank-core
feature-fbank.cc
feature-functions.cc
feature-window.cc
fftsg.c
mel-computations.cc
rfft.cc
)
add_library(frontend STATIC
cmvn.cc
db_norm.cc
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
feature_pipeline.cc
fbank.cc
assembler.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common kaldi-fbank)
set(bin_name cmvn_json2kaldi_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils)
set(BINS
compute_linear_spectrogram_main
compute_fbank_main
)
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog)
target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog kaldi-feat-common)
endforeach()

@ -17,8 +17,8 @@
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Vector;
using kaldi::VectorBase;
using std::vector;
using std::vector;
using std::unique_ptr;
Assembler::Assembler(AssemblerOptions opts,
@ -33,13 +33,13 @@ Assembler::Assembler(AssemblerOptions opts,
dim_ = base_extractor_->Dim();
}
void Assembler::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
void Assembler::Accept(const std::vector<BaseFloat>& inputs) {
// read inputs
base_extractor_->Accept(inputs);
}
// pop feature chunk
bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
bool Assembler::Read(std::vector<BaseFloat>* feats) {
kaldi::Timer timer;
bool result = Compute(feats);
VLOG(1) << "Assembler::Read cost: " << timer.Elapsed() << " sec.";
@ -47,14 +47,14 @@ bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
}
// read frame by frame from base_feature_extractor_ into cache_
bool Assembler::Compute(Vector<BaseFloat>* feats) {
bool Assembler::Compute(vector<BaseFloat>* feats) {
// compute and feed frame by frame
while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature;
vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) {
if (result == false || feature.size() == 0) {
VLOG(3) << "result: " << result
<< " feature dim: " << feature.Dim();
<< " feature dim: " << feature.size();
if (IsFinished() == false) {
VLOG(3) << "finished reading feature. cache size: "
<< feature_cache_.size();
@ -65,7 +65,7 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
}
}
CHECK(feature.Dim() == dim_);
CHECK(feature.size() == dim_);
feature_cache_.push(feature);
nframes_ += 1;
@ -73,14 +73,14 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
}
if (feature_cache_.size() < receptive_filed_length_) {
VLOG(3) << "feature_cache less than receptive_filed_lenght. "
VLOG(3) << "feature_cache less than receptive_filed_length. "
<< feature_cache_.size() << ": " << receptive_filed_length_;
return false;
}
if (fill_zero_) {
while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature(dim_, kaldi::kSetZero);
vector<BaseFloat> feature(dim_, kaldi::kSetZero);
nframes_ += 1;
feature_cache_.push(feature);
}
@ -88,16 +88,17 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
int32 this_chunk_size =
std::min(static_cast<int32>(feature_cache_.size()), frame_chunk_size_);
feats->Resize(dim_ * this_chunk_size);
feats->resize(dim_ * this_chunk_size);
VLOG(3) << "read " << this_chunk_size << " feat.";
int32 counter = 0;
while (counter < this_chunk_size) {
Vector<BaseFloat>& val = feature_cache_.front();
CHECK(val.Dim() == dim_) << val.Dim();
vector<BaseFloat>& val = feature_cache_.front();
CHECK(val.size() == dim_) << val.size();
int32 start = counter * dim_;
feats->Range(start, dim_).CopyFromVec(val);
std::memcpy(feats->data() + start,
val.data(), val.size() * sizeof(BaseFloat));
if (this_chunk_size - counter <= cache_size_) {
feature_cache_.push(val);
@ -115,7 +116,7 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
void Assembler::Reset() {
std::queue<kaldi::Vector<kaldi::BaseFloat>> empty;
std::queue<std::vector<BaseFloat>> empty;
std::swap(feature_cache_, empty);
nframes_ = 0;
base_extractor_->Reset();

@ -36,10 +36,10 @@ class Assembler : public FrontendInterface {
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override;
void Accept(const std::vector<kaldi::BaseFloat>& inputs) override;
// feats size = num_frames * feat_dim
bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) override;
bool Read(std::vector<kaldi::BaseFloat>* feats) override;
// feat dim
size_t Dim() const override { return dim_; }
@ -51,7 +51,7 @@ class Assembler : public FrontendInterface {
void Reset() override;
private:
bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats);
bool Compute(std::vector<kaldi::BaseFloat>* feats);
bool fill_zero_{false};
@ -60,7 +60,7 @@ class Assembler : public FrontendInterface {
int32 frame_chunk_stride_; // stride
int32 cache_size_; // window - stride
int32 receptive_filed_length_;
std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_;
std::queue<std::vector<kaldi::BaseFloat>> feature_cache_;
std::unique_ptr<FrontendInterface> base_extractor_;
int32 nframes_; // num frame computed

@ -19,8 +19,7 @@
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Vector;
using kaldi::VectorBase;
using std::vector;
AudioCache::AudioCache(int buffer_size, bool to_float32)
: finished_(false),
@ -37,25 +36,25 @@ BaseFloat AudioCache::Convert2PCM32(BaseFloat val) {
return val * (1. / std::pow(2.0, 15));
}
void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
void AudioCache::Accept(const vector<BaseFloat>& waves) {
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (size_ + waves.Dim() > ring_buffer_.size()) {
while (size_ + waves.size() > ring_buffer_.size()) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
for (size_t idx = 0; idx < waves.size(); ++idx) {
int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));
ring_buffer_[buffer_idx] = waves[idx];
if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves[idx]);
}
size_ += waves.Dim();
size_ += waves.size();
VLOG(1) << "AudioCache::Accept cost: " << timer.Elapsed() << " sec. "
<< waves.Dim() << " samples.";
<< waves.size() << " samples.";
}
bool AudioCache::Read(Vector<BaseFloat>* waves) {
bool AudioCache::Read(vector<BaseFloat>* waves) {
kaldi::Timer timer;
size_t chunk_size = waves->Dim();
size_t chunk_size = waves->size();
std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > size_) {
// when audio is empty and no more data feed
@ -78,12 +77,12 @@ bool AudioCache::Read(Vector<BaseFloat>* waves) {
// read last chunk data
if (chunk_size > size_) {
chunk_size = size_;
waves->Resize(chunk_size);
waves->resize(chunk_size);
}
for (size_t idx = 0; idx < chunk_size; ++idx) {
int buff_idx = (offset_ + idx) % ring_buffer_.size();
waves->Data()[idx] = ring_buffer_[buff_idx];
waves->at(idx) = ring_buffer_[buff_idx];
}
size_ -= chunk_size;
offset_ = (offset_ + chunk_size) % ring_buffer_.size();

@ -26,9 +26,9 @@ class AudioCache : public FrontendInterface {
explicit AudioCache(int buffer_size = 1000 * kint16max,
bool to_float32 = false);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
virtual void Accept(const std::vector<BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
virtual bool Read(std::vector<kaldi::BaseFloat>* waves);
// the audio dim is 1, one sample, which is useless,
// so we return size_(cache samples) instead.

@ -15,15 +15,12 @@
#include "frontend/audio/cmvn.h"
#include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/picojson.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
@ -32,22 +29,46 @@ CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) {
CHECK_NE(cmvn_file, "");
base_extractor_ = std::move(base_extractor);
ReadCMVNFromJson(cmvn_file);
dim_ = mean_stats_.size() - 1;
}
void CMVN::ReadCMVNFromJson(string cmvn_file) {
std::string json_str = ppspeech::ReadFile2String(cmvn_file);
picojson::value value;
std::string err;
const char* json_end = picojson::parse(
value, json_str.c_str(), json_str.c_str() + json_str.size(), &err);
if (!value.is<picojson::object>()) {
LOG(ERROR) << "Input json file format error.";
}
const picojson::value::array& mean_stat =
value.get("mean_stat").get<picojson::array>();
for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
mean_stats_.push_back((*it).get<double>());
}
const picojson::value::array& var_stat =
value.get("var_stat").get<picojson::array>();
for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
var_stats_.push_back((*it).get<double>());
}
bool binary;
kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary);
dim_ = stats_.NumCols() - 1;
kaldi::int32 frame_num = value.get("frame_num").get<int64_t>();
LOG(INFO) << "nframe: " << frame_num;
mean_stats_.push_back(frame_num);
var_stats_.push_back(0);
}
void CMVN::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
void CMVN::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
// feed waves/feats to compute feature
base_extractor_->Accept(inputs);
return;
}
bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
bool CMVN::Read(std::vector<BaseFloat>* feats) {
// compute feature
if (base_extractor_->Read(feats) == false || feats->Dim() == 0) {
if (base_extractor_->Read(feats) == false || feats->size() == 0) {
return false;
}
@ -59,74 +80,78 @@ bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
}
// feats contain num_frames feature.
void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
void CMVN::Compute(vector<BaseFloat>* feats) const {
KALDI_ASSERT(feats != NULL);
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim_ != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ','
<< stats_.NumCols() - 1 << ", feats " << feats->Dim() << 'x';
if (feats->size() % dim_ != 0) {
LOG(ERROR)<< "Dim mismatch: cmvn " << mean_stats_.size() << ','
<< var_stats_.size() - 1 << ", feats " << feats->size() << 'x';
}
if (stats_.NumRows() == 1 && var_norm_) {
KALDI_ERR
if (var_stats_.size() == 0 && var_norm_) {
LOG(ERROR)
<< "You requested variance normalization but no variance stats_ "
<< "are supplied.";
}
double count = stats_(0, dim_);
double count = mean_stats_[dim_];
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
LOG(ERROR) << "Insufficient stats_ for cepstral mean and variance "
"normalization: "
<< "count = " << count;
if (!var_norm_) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim_);
Vector<double> mean_stats_apply(feats->Dim());
vector<BaseFloat> offset(feats->size());
vector<double> mean_stats(mean_stats_);
for (size_t i = 0; i < mean_stats.size(); ++i) {
mean_stats[i] /= count;
}
vector<double> mean_stats_apply(feats->size());
// fill the datat of mean_stats in mean_stats_appy whose dim_ is equal
// with the dim_ of feature.
// the dim_ of feats = dim_ * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim_; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim_ * idx,
dim_);
stats_tmp.CopyFromVec(mean_stats);
for (int32 idx = 0; idx < feats->size() / dim_; ++idx) {
std::memcpy(mean_stats_apply.data() + dim_ * idx,
mean_stats.data(), dim_* sizeof(double));
}
for (size_t idx = 0; idx < feats->size(); ++idx) {
feats->at(idx) += offset[idx];
}
offset.AddVec(-1.0 / count, mean_stats_apply);
feats->AddVec(1.0, offset);
return;
}
// norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
vector<BaseFloat> norm0(feats->size());
vector<BaseFloat> norm1(feats->size());
for (int32 d = 0; d < dim_; d++) {
double mean, offset, scale;
mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
mean = mean_stats_[d] / count;
double var = (var_stats_[d] / count) - mean * mean, floor = 1.0e-20;
if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
LOG(WARNING) << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
}
scale = 1.0 / sqrt(var);
if (scale != scale || 1 / scale == 0.0)
KALDI_ERR
LOG(ERROR)
<< "NaN or infinity in cepstral mean/variance computation";
offset = -(mean * scale);
for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset;
norm(1, d_skip) = scale;
for (int32 d_skip = d; d_skip < feats->size();) {
norm0[d_skip] = offset;
norm1[d_skip] = scale;
d_skip = d_skip + dim_;
}
}
// Apply the normalization.
feats->MulElements(norm.Row(1));
feats->AddVec(1.0, norm.Row(0));
}
for (size_t idx = 0; idx < feats->size(); ++idx) {
feats->at(idx) *= norm1[idx];
}
void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm_, feats);
for (size_t idx = 0; idx < feats->size(); ++idx) {
feats->at(idx) += norm0[idx];
}
}
} // namespace ppspeech

@ -25,11 +25,11 @@ class CMVN : public FrontendInterface {
public:
explicit CMVN(std::string cmvn_file,
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual void Accept(const std::vector<kaldi::BaseFloat>& inputs);
// the length of feats = feature_row * feature_dim,
// the Matrix is squashed into Vector
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
virtual bool Read(std::vector<kaldi::BaseFloat>* feats);
// the dim_ is the feautre dim.
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
@ -37,9 +37,10 @@ class CMVN : public FrontendInterface {
virtual void Reset() { base_extractor_->Reset(); }
private:
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
kaldi::Matrix<double> stats_;
void ReadCMVNFromJson(std::string cmvn_file);
void Compute(std::vector<kaldi::BaseFloat>* feats) const;
std::vector<double> mean_stats_;
std::vector<double> var_stats_;
std::unique_ptr<FrontendInterface> base_extractor_;
size_t dim_;
bool var_norm_;

@ -1,98 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note: Do not print/log ondemand object.
#include "base/common.h"
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/picojson.h"
DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
auto ifs = std::ifstream(FLAGS_json_file);
std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
picojson::value value;
std::string err;
const char* json_end = picojson::parse(
value, json_str.c_str(), json_str.c_str() + json_str.size(), &err);
if (!value.is<picojson::object>()) {
LOG(ERROR) << "Input json file format error.";
}
const picojson::value::object& obj = value.get<picojson::object>();
for (picojson::value::object::const_iterator elem = obj.begin();
elem != obj.end();
++elem) {
if (elem->first == "mean_stat") {
VLOG(2) << "mean_stat:" << elem->second;
// const picojson::value tmp =
// elem->second.get(0);//<picojson::array>();
double tmp =
elem->second.get(0).get<double>(); //<picojson::array>();
VLOG(2) << "tmp: " << tmp;
}
if (elem->first == "var_stat") {
VLOG(2) << "var_stat: " << elem->second;
}
if (elem->first == "frame_num") {
VLOG(2) << "frame_num: " << elem->second;
}
}
const picojson::value::array& mean_stat =
value.get("mean_stat").get<picojson::array>();
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
mean_stat_vec.push_back((*it).get<double>());
}
const picojson::value::array& var_stat =
value.get("var_stat").get<picojson::array>();
std::vector<kaldi::BaseFloat> var_stat_vec;
for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
var_stat_vec.push_back((*it).get<double>());
}
kaldi::int32 frame_num = value.get("frame_num").get<int64_t>();
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
VLOG(2) << cmvn_stats;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
return 0;
}

@ -56,7 +56,7 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache(3600 * 1600, false));
kaldi::FbankOptions opt;
knf::FbankOptions opt;
opt.frame_opts.frame_length_ms = 25;
opt.frame_opts.frame_shift_ms = 10;
opt.mel_opts.num_bins = FLAGS_num_bins;
@ -117,9 +117,9 @@ int main(int argc, char* argv[]) {
std::min(chunk_sample_size, tot_samples - sample_offset);
// get chunk wav
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
wav_chunk[i] = waveform(sample_offset + i);
}
// compute feat
@ -131,10 +131,14 @@ int main(int argc, char* argv[]) {
}
// read feat
kaldi::Vector<BaseFloat> features;
kaldi::Vector<BaseFloat> features(feature_cache.Dim());
bool flag = true;
do {
flag = feature_cache.Read(&features);
std::vector<BaseFloat> tmp;
flag = feature_cache.Read(&tmp);
std::memcpy(features.Data(),
tmp.data(),
tmp.size() * sizeof(BaseFloat));
if (flag && features.Dim() != 0) {
feats.push_back(features);
feature_rows += features.Dim() / feature_cache.Dim();

@ -15,10 +15,10 @@
#pragma once
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
using std::vector;
namespace ppspeech {
@ -30,16 +30,16 @@ class DataCache : public FrontendInterface {
DataCache() : finished_{false}, dim_{0} {}
// accept waves/feats
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override {
data_ = inputs;
void Accept(const std::vector<kaldi::BaseFloat>& inputs) override {
data_ = std::move(inputs);
}
bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) override {
if (data_.Dim() == 0) {
bool Read(vector<kaldi::BaseFloat>* feats) override {
if (data_.size() == 0) {
return false;
}
(*feats) = data_;
data_.Resize(0);
(*feats) = std::move(data_);
data_.resize(0);
return true;
}
@ -53,7 +53,7 @@ class DataCache : public FrontendInterface {
}
private:
kaldi::Vector<kaldi::BaseFloat> data_;
std::vector<kaldi::BaseFloat> data_;
bool finished_;
int32 dim_;

@ -16,35 +16,10 @@
#include "base/common.h"
#include "frontend/audio/feature_common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/feat/feature-fbank.h"
#include "kaldi/feat/feature-mfcc.h"
#include "kaldi/matrix/kaldi-vector.h"
#include "frontend/audio/feature-fbank.h"
namespace ppspeech {
class FbankComputer {
public:
typedef kaldi::FbankOptions Options;
explicit FbankComputer(const Options& opts);
kaldi::FrameExtractionOptions& GetFrameOptions() {
return opts_.frame_opts;
}
bool Compute(kaldi::Vector<kaldi::BaseFloat>* window,
kaldi::Vector<kaldi::BaseFloat>* feat);
int32 Dim() const;
bool NeedRawLogEnergy();
private:
Options opts_;
kaldi::FbankComputer computer_;
DISALLOW_COPY_AND_ASSIGN(FbankComputer);
};
typedef StreamingFeatureTpl<FbankComputer> Fbank;
typedef StreamingFeatureTpl<knf::FbankComputer> Fbank;
} // namespace ppspeech

@ -0,0 +1,123 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-fbank.cc
//
#include "frontend/audio/feature-fbank.h"
#include <cmath>
#include "frontend/audio/feature-functions.h"
namespace knf {
static void Sqrt(float *in_out, int32_t n) {
for (int32_t i = 0; i != n; ++i) {
in_out[i] = std::sqrt(in_out[i]);
}
}
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) {
os << opts.ToString();
return os;
}
FbankComputer::FbankComputer(const FbankOptions &opts)
: opts_(opts), rfft_(opts.frame_opts.PaddedWindowSize()) {
if (opts.energy_floor > 0.0f) {
log_energy_floor_ = logf(opts.energy_floor);
}
// We'll definitely need the filterbanks info for VTLN warping factor 1.0.
// [note: this call caches it.]
GetMelBanks(1.0f);
}
FbankComputer::~FbankComputer() {
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
delete iter->second;
}
const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
MelBanks *this_mel_banks = nullptr;
// std::map<float, MelBanks *>::iterator iter = mel_banks_.find(vtln_warp);
auto iter = mel_banks_.find(vtln_warp);
if (iter == mel_banks_.end()) {
this_mel_banks =
new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp);
mel_banks_[vtln_warp] = this_mel_banks;
} else {
this_mel_banks = iter->second;
}
return this_mel_banks;
}
void FbankComputer::Compute(float signal_raw_log_energy,
float vtln_warp,
std::vector<float> *signal_frame,
float *feature) {
const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));
CHECK_EQ(signal_frame->size(), opts_.frame_opts.PaddedWindowSize());
// Compute energy after window function (not the raw one).
if (opts_.use_energy && !opts_.raw_energy) {
signal_raw_log_energy =
std::log(std::max<float>(InnerProduct(signal_frame->data(),
signal_frame->data(),
signal_frame->size()),
std::numeric_limits<float>::epsilon()));
}
rfft_.Compute(signal_frame->data()); // signal_frame is modified in-place
ComputePowerSpectrum(signal_frame);
// Use magnitude instead of power if requested.
if (!opts_.use_power) {
Sqrt(signal_frame->data(), signal_frame->size() / 2 + 1);
}
int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);
// Its length is opts_.mel_opts.num_bins
float *mel_energies = feature + mel_offset;
// Sum with mel filter banks over the power spectrum
mel_banks.Compute(signal_frame->data(), mel_energies);
if (opts_.use_log_fbank) {
// Avoid log of zero (which should be prevented anyway by dithering).
for (int32_t i = 0; i != opts_.mel_opts.num_bins; ++i) {
auto t = std::max(mel_energies[i],
std::numeric_limits<float>::epsilon());
mel_energies[i] = std::log(t);
}
}
// Copy energy as first value (or the last, if htk_compat == true).
if (opts_.use_energy) {
if (opts_.energy_floor > 0.0 &&
signal_raw_log_energy < log_energy_floor_) {
signal_raw_log_energy = log_energy_floor_;
}
int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0;
feature[energy_index] = signal_raw_log_energy;
}
}
} // namespace knf

@ -0,0 +1,137 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-fbank.h
#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
#include <map>
#include "frontend/audio/feature-window.h"
#include "frontend/audio/mel-computations.h"
#include "frontend/audio/rfft.h"
namespace knf {
struct FbankOptions {
FrameExtractionOptions frame_opts;
MelBanksOptions mel_opts;
// append an extra dimension with energy to the filter banks
bool use_energy = false;
float energy_floor = 0.0f; // active iff use_energy==true
// If true, compute log_energy before preemphasis and windowing
// If false, compute log_energy after preemphasis ans windowing
bool raw_energy = true; // active iff use_energy==true
// If true, put energy last (if using energy)
// If false, put energy first
bool htk_compat = false; // active iff use_energy==true
// if true (default), produce log-filterbank, else linear
bool use_log_fbank = true;
// if true (default), use power in filterbank
// analysis, else magnitude.
bool use_power = true;
FbankOptions() { mel_opts.num_bins = 23; }
std::string ToString() const {
std::ostringstream os;
os << "frame_opts: \n";
os << frame_opts << "\n";
os << "\n";
os << "mel_opts: \n";
os << mel_opts << "\n";
os << "use_energy: " << use_energy << "\n";
os << "energy_floor: " << energy_floor << "\n";
os << "raw_energy: " << raw_energy << "\n";
os << "htk_compat: " << htk_compat << "\n";
os << "use_log_fbank: " << use_log_fbank << "\n";
os << "use_power: " << use_power << "\n";
return os.str();
}
};
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
class FbankComputer {
public:
using Options = FbankOptions;
explicit FbankComputer(const FbankOptions &opts);
~FbankComputer();
int32_t Dim() const {
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
}
// if true, compute log_energy_pre_window but after dithering and dc removal
bool NeedRawLogEnergy() const {
return opts_.use_energy && opts_.raw_energy;
}
const FrameExtractionOptions &GetFrameOptions() const {
return opts_.frame_opts;
}
const FbankOptions &GetOptions() const { return opts_; }
/**
Function that computes one frame of features from
one frame of signal.
@param [in] signal_raw_log_energy The log-energy of the frame of the
signal
prior to windowing and pre-emphasis, or
log(numeric_limits<float>::min()), whichever is greater. Must be
ignored by this function if this class returns false from
this->NeedsRawLogEnergy().
@param [in] vtln_warp The VTLN warping factor that the user wants
to be applied when computing features for this utterance. Will
normally be 1.0, meaning no warping is to be done. The value will
be ignored for feature types that don't support VLTN, such as
spectrogram features.
@param [in] signal_frame One frame of the signal,
as extracted using the function ExtractWindow() using the options
returned by this->GetFrameOptions(). The function will use the
vector as a workspace, which is why it's a non-const pointer.
@param [out] feature Pointer to a vector of size this->Dim(), to which
the computed feature will be written. It should be pre-allocated.
*/
void Compute(float signal_raw_log_energy,
float vtln_warp,
std::vector<float> *signal_frame,
float *feature);
private:
const MelBanks *GetMelBanks(float vtln_warp);
FbankOptions opts_;
float log_energy_floor_;
std::map<float, MelBanks *> mel_banks_; // float is VTLN coefficient.
Rfft rfft_;
};
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_

@ -0,0 +1,49 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-functions.cc
#include "frontend/audio/feature-functions.h"
#include <cstdint>
#include <vector>
namespace knf {
void ComputePowerSpectrum(std::vector<float> *complex_fft) {
int32_t dim = complex_fft->size();
// now we have in complex_fft, first half of complex spectrum
// it's stored as [real0, realN/2, real1, im1, real2, im2, ...]
float *p = complex_fft->data();
int32_t half_dim = dim / 2;
float first_energy = p[0] * p[0];
float last_energy = p[1] * p[1]; // handle this special case
for (int32_t i = 1; i < half_dim; ++i) {
float real = p[i * 2];
float im = p[i * 2 + 1];
p[i] = real * real + im * im;
}
p[0] = first_energy;
p[half_dim] = last_energy; // Will actually never be used, and anyway
// if the signal has been bandlimited sensibly this should be zero.
}
} // namespace knf

@ -0,0 +1,38 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-functions.h
#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H
#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H
#include <vector>
namespace knf {
// ComputePowerSpectrum converts a complex FFT (as produced by the FFT
// functions in csrc/rfft.h), and converts it into
// a power spectrum. If the complex FFT is a vector of size n (representing
// half of the complex FFT of a real signal of size n, as described there),
// this function computes in the first (n/2) + 1 elements of it, the
// energies of the fft bins from zero to the Nyquist frequency. Contents of the
// remaining (n/2) - 1 elements are undefined at output.
void ComputePowerSpectrum(std::vector<float> *complex_fft);
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H

@ -0,0 +1,247 @@
// kaldi-native-fbank/csrc/feature-window.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-window.cc
#include "frontend/audio/feature-window.h"
#include <cmath>
#include <vector>
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
namespace knf {
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
os << opts.ToString();
return os;
}
FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts)
: window_(opts.WindowSize()) {
int32_t frame_length = opts.WindowSize();
CHECK_GT(frame_length, 0);
float *window_data = window_.data();
double a = M_2PI / (frame_length - 1);
for (int32_t i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
if (opts.window_type == "hanning") {
window_data[i] = 0.5 - 0.5 * cos(a * i_fl);
} else if (opts.window_type == "sine") {
// when you are checking ws wikipedia, please
// note that 0.5 * a = M_PI/(frame_length-1)
window_data[i] = sin(0.5 * a * i_fl);
} else if (opts.window_type == "hamming") {
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
} else if (opts.window_type ==
"povey") { // like hamming but goes to zero at edges.
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
} else if (opts.window_type == "rectangular") {
window_data[i] = 1.0;
} else if (opts.window_type == "blackman") {
window_data[i] = opts.blackman_coeff - 0.5 * cos(a * i_fl) +
(0.5 - opts.blackman_coeff) * cos(2 * a * i_fl);
} else {
LOG(FATAL) << "Invalid window type " << opts.window_type;
}
}
}
void FeatureWindowFunction::Apply(float *wave) const {
int32_t window_size = window_.size();
const float *p = window_.data();
for (int32_t k = 0; k != window_size; ++k) {
wave[k] *= p[k];
}
}
int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts) {
int64_t frame_shift = opts.WindowShift();
if (opts.snip_edges) {
return frame * frame_shift;
} else {
int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2,
beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2;
return beginning_of_frame;
}
}
int32_t NumFrames(int64_t num_samples,
const FrameExtractionOptions &opts,
bool flush /*= true*/) {
int64_t frame_shift = opts.WindowShift();
int64_t frame_length = opts.WindowSize();
if (opts.snip_edges) {
// with --snip-edges=true (the default), we use a HTK-like approach to
// determining the number of frames-- all frames have to fit completely
// into
// the waveform, and the first frame begins at sample zero.
if (num_samples < frame_length)
return 0;
else
return (1 + ((num_samples - frame_length) / frame_shift));
// You can understand the expression above as follows: 'num_samples -
// frame_length' is how much room we have to shift the frame within the
// waveform; 'frame_shift' is how much we shift it each time; and the
// ratio
// is how many times we can shift it (integer arithmetic rounds down).
} else {
// if --snip-edges=false, the number of frames is determined by rounding
// the
// (file-length / frame-shift) to the nearest integer. The point of
// this
// formula is to make the number of frames an obvious and predictable
// function of the frame shift and signal length, which makes many
// segmentation-related questions simpler.
//
// Because integer division in C++ rounds toward zero, we add (half the
// frame-shift minus epsilon) before dividing, to have the effect of
// rounding towards the closest integer.
int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift;
if (flush) return num_frames;
// note: 'end' always means the last plus one, i.e. one past the last.
int64_t end_sample_of_last_frame =
FirstSampleOfFrame(num_frames - 1, opts) + frame_length;
// the following code is optimized more for clarity than efficiency.
// If flush == false, we can't output frames that extend past the end
// of the signal.
while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
num_frames--;
end_sample_of_last_frame -= frame_shift;
}
return num_frames;
}
}
void ExtractWindow(int64_t sample_offset,
const std::vector<float> &wave,
int32_t f,
const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function,
std::vector<float> *window,
float *log_energy_pre_window /*= nullptr*/) {
CHECK(sample_offset >= 0 && wave.size() != 0);
int32_t frame_length = opts.WindowSize();
int32_t frame_length_padded = opts.PaddedWindowSize();
int64_t num_samples = sample_offset + wave.size();
int64_t start_sample = FirstSampleOfFrame(f, opts);
int64_t end_sample = start_sample + frame_length;
if (opts.snip_edges) {
CHECK(start_sample >= sample_offset && end_sample <= num_samples);
} else {
CHECK(sample_offset == 0 || start_sample >= sample_offset);
}
if (window->size() != frame_length_padded) {
window->resize(frame_length_padded);
}
// wave_start and wave_end are start and end indexes into 'wave', for the
// piece of wave that we're trying to extract.
int32_t wave_start = int32_t(start_sample - sample_offset);
int32_t wave_end = wave_start + frame_length;
if (wave_start >= 0 && wave_end <= wave.size()) {
// the normal case-- no edge effects to consider.
std::copy(wave.begin() + wave_start,
wave.begin() + wave_start + frame_length,
window->data());
} else {
// Deal with any end effects by reflection, if needed. This code will
// only
// be reached for about two frames per utterance, so we don't concern
// ourselves excessively with efficiency.
int32_t wave_dim = wave.size();
for (int32_t s = 0; s < frame_length; ++s) {
int32_t s_in_wave = s + wave_start;
while (s_in_wave < 0 || s_in_wave >= wave_dim) {
// reflect around the beginning or end of the wave.
// e.g. -1 -> 0, -2 -> 1.
// dim -> dim - 1, dim + 1 -> dim - 2.
// the code supports repeated reflections, although this
// would only be needed in pathological cases.
if (s_in_wave < 0)
s_in_wave = -s_in_wave - 1;
else
s_in_wave = 2 * wave_dim - 1 - s_in_wave;
}
(*window)[s] = wave[s_in_wave];
}
}
ProcessWindow(opts, window_function, window->data(), log_energy_pre_window);
}
static void RemoveDcOffset(float *d, int32_t n) {
float sum = 0;
for (int32_t i = 0; i != n; ++i) {
sum += d[i];
}
float mean = sum / n;
for (int32_t i = 0; i != n; ++i) {
d[i] -= mean;
}
}
float InnerProduct(const float *a, const float *b, int32_t n) {
float sum = 0;
for (int32_t i = 0; i != n; ++i) {
sum += a[i] * b[i];
}
return sum;
}
static void Preemphasize(float *d, int32_t n, float preemph_coeff) {
if (preemph_coeff == 0.0) {
return;
}
CHECK(preemph_coeff >= 0.0 && preemph_coeff <= 1.0);
for (int32_t i = n - 1; i > 0; --i) {
d[i] -= preemph_coeff * d[i - 1];
}
d[0] -= preemph_coeff * d[0];
}
void ProcessWindow(const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function,
float *window,
float *log_energy_pre_window /*= nullptr*/) {
int32_t frame_length = opts.WindowSize();
// TODO(fangjun): Remove dither
CHECK_EQ(opts.dither, 0);
if (opts.remove_dc_offset) {
RemoveDcOffset(window, frame_length);
}
if (log_energy_pre_window != NULL) {
float energy =
std::max<float>(InnerProduct(window, window, frame_length),
std::numeric_limits<float>::epsilon());
*log_energy_pre_window = std::log(energy);
}
if (opts.preemph_coeff != 0.0) {
Preemphasize(window, frame_length, opts.preemph_coeff);
}
window_function.Apply(window);
}
} // namespace knf

@ -0,0 +1,183 @@
// kaldi-native-fbank/csrc/feature-window.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-window.h
#ifndef KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_
#define KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_
#include <sstream>
#include <string>
#include <vector>
#include "base/log.h"
namespace knf {
inline int32_t RoundUpToNearestPowerOfTwo(int32_t n) {
// copied from kaldi/src/base/kaldi-math.cc
CHECK_GT(n, 0);
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
return n + 1;
}
struct FrameExtractionOptions {
float samp_freq = 16000;
float frame_shift_ms = 10.0f; // in milliseconds.
float frame_length_ms = 25.0f; // in milliseconds.
float dither = 1.0f; // Amount of dithering, 0.0 means no dither.
float preemph_coeff = 0.97f; // Preemphasis coefficient.
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey"; // e.g. Hamming window
// May be "hamming", "rectangular", "povey", "hanning", "sine", "blackman"
// "povey" is a window I made to be similar to Hamming but to go to zero at
// the edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) I just don't think
// the
// Hamming window makes sense as a windowing function.
bool round_to_power_of_two = true;
float blackman_coeff = 0.42f;
bool snip_edges = true;
// bool allow_downsample = false;
// bool allow_upsample = false;
// Used for streaming feature extraction. It indicates the number
// of feature frames to keep in the recycling vector. -1 means to
// keep all feature frames.
int32_t max_feature_vectors = -1;
int32_t WindowShift() const {
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
}
int32_t WindowSize() const {
return static_cast<int32_t>(samp_freq * 0.001f * frame_length_ms);
}
int32_t PaddedWindowSize() const {
return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
: WindowSize());
}
std::string ToString() const {
std::ostringstream os;
#define KNF_PRINT(x) os << #x << ": " << x << "\n"
KNF_PRINT(samp_freq);
KNF_PRINT(frame_shift_ms);
KNF_PRINT(frame_length_ms);
KNF_PRINT(dither);
KNF_PRINT(preemph_coeff);
KNF_PRINT(remove_dc_offset);
KNF_PRINT(window_type);
KNF_PRINT(round_to_power_of_two);
KNF_PRINT(blackman_coeff);
KNF_PRINT(snip_edges);
// KNF_PRINT(allow_downsample);
// KNF_PRINT(allow_upsample);
KNF_PRINT(max_feature_vectors);
#undef KNF_PRINT
return os.str();
}
};
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
class FeatureWindowFunction {
public:
FeatureWindowFunction() = default;
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
/**
* @param wave Pointer to a 1-D array of shape [window_size].
* It is modified in-place: wave[i] = wave[i] * window_[i].
* @param
*/
void Apply(float *wave) const;
private:
std::vector<float> window_; // of size opts.WindowSize()
};
int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts);
/**
This function returns the number of frames that we can extract from a wave
file with the given number of samples in it (assumed to have the same
sampling rate as specified in 'opts').
@param [in] num_samples The number of samples in the wave file.
@param [in] opts The frame-extraction options class
@param [in] flush True if we are asserting that this number of samples
is 'all there is', false if we expecting more data to possibly come in. This
only makes a difference to the answer
if opts.snips_edges== false. For offline feature extraction you always want
flush == true. In an online-decoding context, once you know (or decide) that
no more data is coming in, you'd call it with flush == true at the end to
flush out any remaining data.
*/
int32_t NumFrames(int64_t num_samples,
const FrameExtractionOptions &opts,
bool flush = true);
/*
ExtractWindow() extracts a windowed frame of waveform (possibly with a
power-of-two, padded size, depending on the config), including all the
processing done by ProcessWindow().
@param [in] sample_offset If 'wave' is not the entire waveform, but
part of it to the left has been discarded, then the
number of samples prior to 'wave' that we have
already discarded. Set this to zero if you are
processing the entire waveform in one piece, or
if you get 'no matching function' compilation
errors when updating the code.
@param [in] wave The waveform
@param [in] f The frame index to be extracted, with
0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true)
@param [in] opts The options class to be used
@param [in] window_function The windowing function, as derived from the
options class.
@param [out] window The windowed, possibly-padded waveform to be
extracted. Will be resized as needed.
@param [out] log_energy_pre_window If non-NULL, the log-energy of
the signal prior to pre-emphasis and multiplying by
the windowing function will be written to here.
*/
void ExtractWindow(int64_t sample_offset,
const std::vector<float> &wave,
int32_t f,
const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function,
std::vector<float> *window,
float *log_energy_pre_window = nullptr);
/**
This function does all the windowing steps after actually
extracting the windowed signal: depending on the
configuration, it does dithering, dc offset removal,
preemphasis, and multiplication by the windowing function.
@param [in] opts The options class to be used
@param [in] window_function The windowing function-- should have
been initialized using 'opts'.
@param [in,out] window A vector of size opts.WindowSize(). Note:
it will typically be a sub-vector of a larger vector of size
opts.PaddedWindowSize(), with the remaining samples zero,
as the FFT code is more efficient if it operates on data with
power-of-two size.
@param [out] log_energy_pre_window If non-NULL, then after dithering and
DC offset removal, this function will write to this pointer the log of
the total energy (i.e. sum-squared) of the frame.
*/
void ProcessWindow(const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function,
float *window,
float *log_energy_pre_window = nullptr);
// Compute the inner product of two vectors
float InnerProduct(const float *a, const float *b, int32_t n);
} // namespace knf
#endif // KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_

@ -17,9 +17,6 @@
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
@ -31,7 +28,7 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts,
dim_ = base_extractor_->Dim();
}
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
void FeatureCache::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
// read inputs
base_extractor_->Accept(inputs);
@ -43,7 +40,7 @@ void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
}
// pop feature chunk
bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
bool FeatureCache::Read(std::vector<kaldi::BaseFloat>* feats) {
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
@ -59,8 +56,7 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
if (cache_.empty()) return false;
// read from cache
feats->Resize(cache_.front().Dim());
feats->CopyFromVec(cache_.front());
*feats = cache_.front();
cache_.pop();
ready_feed_condition_.notify_one();
VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
@ -70,21 +66,20 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() {
// compute and feed
Vector<BaseFloat> feature;
vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
if (result == false || feature.size() == 0) return false;
kaldi::Timer timer;
int32 num_chunk = feature.Dim() / dim_;
int32 num_chunk = feature.size() / dim_;
nframe_ += num_chunk;
VLOG(3) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_;
Vector<BaseFloat> feature_chunk(dim_);
SubVector<BaseFloat> tmp(feature.Data() + start, dim_);
feature_chunk.CopyFromVec(tmp);
vector<BaseFloat> feature_chunk(feature.data() + start,
feature.data() + start + dim_);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {

@ -32,10 +32,10 @@ class FeatureCache : public FrontendInterface {
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual void Accept(const std::vector<kaldi::BaseFloat>& inputs);
// feats size = num_frames * feat_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
virtual bool Read(std::vector<kaldi::BaseFloat>* feats);
// feat dim
virtual size_t Dim() const { return dim_; }
@ -54,7 +54,7 @@ class FeatureCache : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
void Reset() override {
std::queue<kaldi::Vector<BaseFloat>> empty;
std::queue<std::vector<BaseFloat>> empty;
std::swap(cache_, empty);
nframe_ = 0;
base_extractor_->Reset();
@ -71,8 +71,8 @@ class FeatureCache : public FrontendInterface {
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::int32 timeout_; // ms
kaldi::Vector<kaldi::BaseFloat> remained_feature_;
std::queue<kaldi::Vector<BaseFloat>> cache_; // feature cache
std::vector<kaldi::BaseFloat> remained_feature_;
std::queue<std::vector<BaseFloat>> cache_; // feature cache
std::mutex mutex_;
std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_;

@ -15,7 +15,7 @@
#pragma once
#include "frontend_itf.h"
#include "kaldi/feat/feature-window.h"
#include "frontend/audio/feature-window.h"
namespace ppspeech {
@ -25,8 +25,8 @@ class StreamingFeatureTpl : public FrontendInterface {
typedef typename F::Options Options;
StreamingFeatureTpl(const Options& opts,
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
virtual void Accept(const std::vector<kaldi::BaseFloat>& waves);
virtual bool Read(std::vector<kaldi::BaseFloat>* feats);
// the dim_ is the dim of single frame feature
virtual size_t Dim() const { return computer_.Dim(); }
@ -37,16 +37,16 @@ class StreamingFeatureTpl : public FrontendInterface {
virtual void Reset() {
base_extractor_->Reset();
remained_wav_.Resize(0);
remained_wav_.resize(0);
}
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
kaldi::Vector<kaldi::BaseFloat>* feats);
bool Compute(const std::vector<kaldi::BaseFloat>& waves,
std::vector<kaldi::BaseFloat>* feats);
Options opts_;
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::FeatureWindowFunction window_function_;
kaldi::Vector<kaldi::BaseFloat> remained_wav_;
knf::FeatureWindowFunction window_function_;
std::vector<kaldi::BaseFloat> remained_wav_;
F computer_;
};

@ -24,75 +24,77 @@ StreamingFeatureTpl<F>::StreamingFeatureTpl(
template <class F>
void StreamingFeatureTpl<F>::Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
const std::vector<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
template <class F>
bool StreamingFeatureTpl<F>::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
kaldi::Vector<kaldi::BaseFloat> wav(base_extractor_->Dim());
bool StreamingFeatureTpl<F>::Read(std::vector<kaldi::BaseFloat>* feats) {
std::vector<kaldi::BaseFloat> wav(base_extractor_->Dim());
bool flag = base_extractor_->Read(&wav);
if (flag == false || wav.Dim() == 0) return false;
if (flag == false || wav.size() == 0) return false;
kaldi::Timer timer;
// append remaned waves
int32 wav_len = wav.Dim();
int32 left_len = remained_wav_.Dim();
kaldi::Vector<kaldi::BaseFloat> waves(left_len + wav_len);
waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, wav_len).CopyFromVec(wav);
int32 wav_len = wav.size();
int32 left_len = remained_wav_.size();
std::vector<kaldi::BaseFloat> waves(left_len + wav_len);
std::memcpy(waves.data(),
remained_wav_.data(),
left_len * sizeof(kaldi::BaseFloat));
std::memcpy(waves.data() + left_len,
wav.data(),
wav_len * sizeof(kaldi::BaseFloat));
// compute speech feature
Compute(waves, feats);
// cache remaned waves
kaldi::FrameExtractionOptions frame_opts = computer_.GetFrameOptions();
int32 num_frames = kaldi::NumFrames(waves.Dim(), frame_opts);
knf::FrameExtractionOptions frame_opts = computer_.GetFrameOptions();
int32 num_frames = knf::NumFrames(waves.size(), frame_opts);
int32 frame_shift = frame_opts.WindowShift();
int32 left_samples = waves.Dim() - frame_shift * num_frames;
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
VLOG(1) << "StreamingFeatureTpl<F>::Read cost: " << timer.Elapsed()
<< " sec.";
int32 left_samples = waves.size() - frame_shift * num_frames;
remained_wav_.resize(left_samples);
std::memcpy(remained_wav_.data(),
waves.data() + frame_shift * num_frames,
left_samples * sizeof(BaseFloat));
return true;
}
// Compute feat
template <class F>
bool StreamingFeatureTpl<F>::Compute(
const kaldi::Vector<kaldi::BaseFloat>& waves,
kaldi::Vector<kaldi::BaseFloat>* feats) {
const kaldi::FrameExtractionOptions& frame_opts =
computer_.GetFrameOptions();
int32 num_samples = waves.Dim();
bool StreamingFeatureTpl<F>::Compute(const std::vector<kaldi::BaseFloat>& waves,
std::vector<kaldi::BaseFloat>* feats) {
const knf::FrameExtractionOptions& frame_opts = computer_.GetFrameOptions();
int32 num_samples = waves.size();
int32 frame_length = frame_opts.WindowSize();
int32 sample_rate = frame_opts.samp_freq;
if (num_samples < frame_length) {
return true;
}
int32 num_frames = kaldi::NumFrames(num_samples, frame_opts);
feats->Resize(num_frames * Dim());
int32 num_frames = knf::NumFrames(num_samples, frame_opts);
feats->resize(num_frames * Dim());
kaldi::Vector<kaldi::BaseFloat> window;
std::vector<kaldi::BaseFloat> window;
bool need_raw_log_energy = computer_.NeedRawLogEnergy();
for (int32 frame = 0; frame < num_frames; frame++) {
std::fill(window.begin(), window.end(), 0);
kaldi::BaseFloat raw_log_energy = 0.0;
kaldi::ExtractWindow(0,
waves,
frame,
frame_opts,
window_function_,
&window,
need_raw_log_energy ? &raw_log_energy : NULL);
kaldi::BaseFloat vtln_warp = 1.0;
knf::ExtractWindow(0,
waves,
frame,
frame_opts,
window_function_,
&window,
need_raw_log_energy ? &raw_log_energy : NULL);
kaldi::Vector<kaldi::BaseFloat> this_feature(computer_.Dim(),
kaldi::kUndefined);
computer_.Compute(&window, &this_feature);
kaldi::SubVector<kaldi::BaseFloat> output_row(
feats->Data() + frame * Dim(), Dim());
output_row.CopyFromVec(this_feature);
std::vector<kaldi::BaseFloat> this_feature(computer_.Dim());
computer_.Compute(
raw_log_energy, vtln_warp, &window, this_feature.data());
std::memcpy(feats->data() + frame * Dim(),
this_feature.data(),
sizeof(BaseFloat) * Dim());
}
return true;
}

@ -21,17 +21,12 @@ using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
: opts_(opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
new ppspeech::AudioCache(1000 * kint16max, false));
unique_ptr<FrontendInterface> base_feature;
if (opts.use_fbank) {
base_feature.reset(
new ppspeech::Fbank(opts.fbank_opts, std::move(data_source)));
} else {
base_feature.reset(new ppspeech::LinearSpectrogram(
opts.linear_spectrogram_opts, std::move(data_source)));
}
base_feature.reset(
new ppspeech::Fbank(opts.fbank_opts, std::move(data_source)));
CHECK_NE(opts.cmvn_file, "");
unique_ptr<FrontendInterface> cmvn(

@ -22,11 +22,9 @@
#include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
// feature
DECLARE_bool(use_fbank);
DECLARE_bool(fill_zero);
DECLARE_int32(num_bins);
DECLARE_string(cmvn_file);
@ -40,10 +38,7 @@ namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file{};
bool to_float32{false}; // true, only for linear feature
bool use_fbank{true};
LinearSpectrogramOptions linear_spectrogram_opts{};
kaldi::FbankOptions fbank_opts{};
knf::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts{};
@ -53,30 +48,17 @@ struct FeaturePipelineOptions {
LOG(INFO) << "cmvn file: " << opts.cmvn_file;
// frame options
kaldi::FrameExtractionOptions frame_opts;
knf::FrameExtractionOptions frame_opts;
frame_opts.dither = 0.0;
LOG(INFO) << "dither: " << frame_opts.dither;
frame_opts.frame_shift_ms = 10;
LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms;
opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
opts.fbank_opts.frame_opts = frame_opts;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
opts.fbank_opts.frame_opts = frame_opts;
LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms;
// assembler opts
@ -100,10 +82,10 @@ struct FeaturePipelineOptions {
class FeaturePipeline : public FrontendInterface {
public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
virtual void Accept(const std::vector<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
virtual bool Read(std::vector<kaldi::BaseFloat>* feats) {
return base_extractor_->Read(feats);
}
virtual size_t Dim() const { return base_extractor_->Dim(); }

File diff suppressed because it is too large Load Diff

@ -22,13 +22,13 @@ namespace ppspeech {
class FrontendInterface {
public:
// Feed inputs: features(2D saved in 1D) or waveforms(1D).
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;
virtual void Accept(const std::vector<float>& inputs) = 0;
// Fetch processed data: features or waveforms.
// For features(2D saved in 1D), the Matrix is squashed into Vector,
// the length of output = feature_row * feature_dim.
// For waveforms(1D), samples saved in vector.
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* outputs) = 0;
virtual bool Read(std::vector<float>* outputs) = 0;
// Dim is the feature dim. For waveforms(1D), Dim is zero; else is specific,
// e.g 80 for fbank.

@ -0,0 +1,277 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/mel-computations.cc
#include "frontend/audio/mel-computations.h"
#include <algorithm>
#include <sstream>
#include "frontend/audio/feature-window.h"
namespace knf {
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
os << opts.ToString();
return os;
}
float MelBanks::VtlnWarpFreq(
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
float vtln_high_cutoff,
float low_freq, // upper+lower frequency cutoffs in mel computation
float high_freq,
float vtln_warp_factor,
float freq) {
/// This computes a VTLN warping function that is not the same as HTK's one,
/// but has similar inputs (this function has the advantage of never
/// producing
/// empty bins).
/// This function computes a warp function F(freq), defined between low_freq
/// and high_freq inclusive, with the following properties:
/// F(low_freq) == low_freq
/// F(high_freq) == high_freq
/// The function is continuous and piecewise linear with two inflection
/// points.
/// The lower inflection point (measured in terms of the unwarped
/// frequency) is at frequency l, determined as described below.
/// The higher inflection point is at a frequency h, determined as
/// described below.
/// If l <= f <= h, then F(f) = f/vtln_warp_factor.
/// If the higher inflection point (measured in terms of the unwarped
/// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
/// Since (by the last point) F(h) == h/vtln_warp_factor, then
/// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
/// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
/// = vtln_high_cutoff * min(1, vtln_warp_factor).
/// If the lower inflection point (measured in terms of the unwarped
/// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
/// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
/// = vtln_low_cutoff * max(1, vtln_warp_factor)
if (freq < low_freq || freq > high_freq)
return freq; // in case this gets called
// for out-of-range frequencies, just return the freq.
CHECK_GT(vtln_low_cutoff, low_freq);
CHECK_LT(vtln_high_cutoff, high_freq);
float one = 1.0f;
float l = vtln_low_cutoff * std::max(one, vtln_warp_factor);
float h = vtln_high_cutoff * std::min(one, vtln_warp_factor);
float scale = 1.0f / vtln_warp_factor;
float Fl = scale * l; // F(l);
float Fh = scale * h; // F(h);
CHECK(l > low_freq && h < high_freq);
// slope of left part of the 3-piece linear function
float scale_left = (Fl - low_freq) / (l - low_freq);
// [slope of center part is just "scale"]
// slope of right part of the 3-piece linear function
float scale_right = (high_freq - Fh) / (high_freq - h);
if (freq < l) {
return low_freq + scale_left * (freq - low_freq);
} else if (freq < h) {
return scale * freq;
} else { // freq >= h
return high_freq + scale_right * (freq - high_freq);
}
}
float MelBanks::VtlnWarpMelFreq(
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
float vtln_high_cutoff,
float low_freq, // upper+lower frequency cutoffs in mel computation
float high_freq,
float vtln_warp_factor,
float mel_freq) {
return MelScale(VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
vtln_warp_factor,
InverseMelScale(mel_freq)));
}
MelBanks::MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts,
float vtln_warp_factor)
: htk_mode_(opts.htk_mode) {
int32_t num_bins = opts.num_bins;
if (num_bins < 3) LOG(FATAL) << "Must have at least 3 mel bins";
float sample_freq = frame_opts.samp_freq;
int32_t window_length_padded = frame_opts.PaddedWindowSize();
CHECK_EQ(window_length_padded % 2, 0);
int32_t num_fft_bins = window_length_padded / 2;
float nyquist = 0.5f * sample_freq;
float low_freq = opts.low_freq, high_freq;
if (opts.high_freq > 0.0f)
high_freq = opts.high_freq;
else
high_freq = nyquist + opts.high_freq;
if (low_freq < 0.0f || low_freq >= nyquist || high_freq <= 0.0f ||
high_freq > nyquist || high_freq <= low_freq) {
LOG(FATAL) << "Bad values in options: low-freq " << low_freq
<< " and high-freq " << high_freq << " vs. nyquist "
<< nyquist;
}
float fft_bin_width = sample_freq / window_length_padded;
// fft-bin width [think of it as Nyquist-freq / half-window-length]
float mel_low_freq = MelScale(low_freq);
float mel_high_freq = MelScale(high_freq);
debug_ = opts.debug_mel;
// divide by num_bins+1 in next line because of end-effects where the bins
// spread out to the sides.
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
float vtln_low = opts.vtln_low, vtln_high = opts.vtln_high;
if (vtln_high < 0.0f) {
vtln_high += nyquist;
}
if (vtln_warp_factor != 1.0f &&
(vtln_low < 0.0f || vtln_low <= low_freq || vtln_low >= high_freq ||
vtln_high <= 0.0f || vtln_high >= high_freq ||
vtln_high <= vtln_low)) {
LOG(FATAL) << "Bad values in options: vtln-low " << vtln_low
<< " and vtln-high " << vtln_high << ", versus "
<< "low-freq " << low_freq << " and high-freq " << high_freq;
}
bins_.resize(num_bins);
center_freqs_.resize(num_bins);
for (int32_t bin = 0; bin < num_bins; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
if (vtln_warp_factor != 1.0f) {
left_mel = VtlnWarpMelFreq(vtln_low,
vtln_high,
low_freq,
high_freq,
vtln_warp_factor,
left_mel);
center_mel = VtlnWarpMelFreq(vtln_low,
vtln_high,
low_freq,
high_freq,
vtln_warp_factor,
center_mel);
right_mel = VtlnWarpMelFreq(vtln_low,
vtln_high,
low_freq,
high_freq,
vtln_warp_factor,
right_mel);
}
center_freqs_[bin] = InverseMelScale(center_mel);
// this_bin will be a vector of coefficients that is only
// nonzero where this mel bin is active.
std::vector<float> this_bin(num_fft_bins);
int32_t first_index = -1, last_index = -1;
for (int32_t i = 0; i < num_fft_bins; ++i) {
float freq = (fft_bin_width * i); // Center frequency of this fft
// bin.
float mel = MelScale(freq);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else
weight = (right_mel - mel) / (right_mel - center_mel);
this_bin[i] = weight;
if (first_index == -1) first_index = i;
last_index = i;
}
}
CHECK(first_index != -1 && last_index >= first_index &&
"You may have set num_mel_bins too large.");
bins_[bin].first = first_index;
int32_t size = last_index + 1 - first_index;
bins_[bin].second.insert(bins_[bin].second.end(),
this_bin.begin() + first_index,
this_bin.begin() + first_index + size);
// Replicate a bug in HTK, for testing purposes.
if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f) {
bins_[bin].second[0] = 0.0;
}
} // for (int32_t bin = 0; bin < num_bins; ++bin) {
if (debug_) {
std::ostringstream os;
for (size_t i = 0; i < bins_.size(); i++) {
os << "bin " << i << ", offset = " << bins_[i].first << ", vec = ";
for (auto k : bins_[i].second) os << k << ", ";
os << "\n";
}
LOG(INFO) << os.str();
}
}
// "power_spectrum" contains fft energies.
void MelBanks::Compute(const float *power_spectrum,
float *mel_energies_out) const {
int32_t num_bins = bins_.size();
for (int32_t i = 0; i < num_bins; i++) {
int32_t offset = bins_[i].first;
const auto &v = bins_[i].second;
float energy = 0;
for (int32_t k = 0; k != v.size(); ++k) {
energy += v[k] * power_spectrum[k + offset];
}
// HTK-like flooring- for testing purposes (we prefer dither)
if (htk_mode_ && energy < 1.0) {
energy = 1.0;
}
mel_energies_out[i] = energy;
// The following assert was added due to a problem with OpenBlas that
// we had at one point (it was a bug in that library). Just to detect
// it early.
CHECK_EQ(energy, energy); // check that energy is not nan
}
if (debug_) {
fprintf(stderr, "MEL BANKS:\n");
for (int32_t i = 0; i < num_bins; i++)
fprintf(stderr, " %f", mel_energies_out[i]);
fprintf(stderr, "\n");
}
}
} // namespace knf

@ -0,0 +1,120 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/mel-computations.h
#ifndef KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_
#define KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_
#include <cmath>
#include <string>
#include "frontend/audio/feature-window.h"
namespace knf {
struct MelBanksOptions {
int32_t num_bins = 25; // e.g. 25; number of triangular bins
float low_freq = 20; // e.g. 20; lower frequency cutoff
// an upper frequency cutoff; 0 -> no cutoff, negative
// ->added to the Nyquist frequency to get the cutoff.
float high_freq = 0;
float vtln_low = 100; // vtln lower cutoff of warping function.
// vtln upper cutoff of warping function: if negative, added
// to the Nyquist frequency to get the cutoff.
float vtln_high = -500;
bool debug_mel = false;
// htk_mode is a "hidden" config, it does not show up on command line.
// Enables more exact compatibility with HTK, for testing purposes. Affects
// mel-energy flooring and reproduces a bug in HTK.
bool htk_mode = false;
std::string ToString() const {
std::ostringstream os;
os << "num_bins: " << num_bins << "\n";
os << "low_freq: " << low_freq << "\n";
os << "high_freq: " << high_freq << "\n";
os << "vtln_low: " << vtln_low << "\n";
os << "vtln_high: " << vtln_high << "\n";
os << "debug_mel: " << debug_mel << "\n";
os << "htk_mode: " << htk_mode << "\n";
return os.str();
}
};
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
class MelBanks {
public:
static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
}
static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
}
static float VtlnWarpFreq(
float vtln_low_cutoff,
float vtln_high_cutoff, // discontinuities in warp func
float low_freq,
float high_freq, // upper+lower frequency cutoffs in
// the mel computation
float vtln_warp_factor,
float freq);
static float VtlnWarpMelFreq(float vtln_low_cutoff,
float vtln_high_cutoff,
float low_freq,
float high_freq,
float vtln_warp_factor,
float mel_freq);
// TODO(fangjun): Remove vtln_warp_factor
MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts,
float vtln_warp_factor);
/// Compute Mel energies (note: not log energies).
/// At input, "fft_energies" contains the FFT energies (not log).
///
/// @param fft_energies 1-D array of size num_fft_bins/2+1
/// @param mel_energies_out 1-D array of size num_mel_bins
void Compute(const float *fft_energies, float *mel_energies_out) const;
int32_t NumBins() const { return bins_.size(); }
private:
// center frequencies of bins, numbered from 0 ... num_bins-1.
// Needed by GetCenterFreqs().
std::vector<float> center_freqs_;
// the "bins_" vector is a vector, one for each bin, of a pair:
// (the first nonzero fft-bin), (the vector of weights).
std::vector<std::pair<int32_t, std::vector<float>>> bins_;
// TODO(fangjun): Remove debug_ and htk_mode_
bool debug_;
bool htk_mode_;
};
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_

@ -0,0 +1,66 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/audio/rfft.h"
#include <cmath>
#include <vector>
#include "base/log.h"
// see fftsg.c
#ifdef __cplusplus
extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w);
#else
void rdft(int n, int isgn, double *a, int *ip, double *w);
#endif
namespace knf {
class Rfft::RfftImpl {
public:
explicit RfftImpl(int32_t n) : n_(n), ip_(2 + std::sqrt(n / 2)), w_(n / 2) {
CHECK_EQ(n & (n - 1), 0);
}
void Compute(float *in_out) {
std::vector<double> d(in_out, in_out + n_);
Compute(d.data());
std::copy(d.begin(), d.end(), in_out);
}
void Compute(double *in_out) {
// 1 means forward fft
rdft(n_, 1, in_out, ip_.data(), w_.data());
}
private:
int32_t n_;
std::vector<int32_t> ip_;
std::vector<double> w_;
};
Rfft::Rfft(int32_t n) : impl_(std::make_unique<RfftImpl>(n)) {}
Rfft::~Rfft() = default;
void Rfft::Compute(float *in_out) { impl_->Compute(in_out); }
void Rfft::Compute(double *in_out) { impl_->Compute(in_out); }
} // namespace knf

@ -0,0 +1,56 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef KALDI_NATIVE_FBANK_CSRC_RFFT_H_
#define KALDI_NATIVE_FBANK_CSRC_RFFT_H_
#include <memory>
namespace knf {
// n-point Real discrete Fourier transform
// where n is a power of 2. n >= 2
//
// R[k] = sum_j=0^n-1 in[j]*cos(2*pi*j*k/n), 0<=k<=n/2
// I[k] = sum_j=0^n-1 in[j]*sin(2*pi*j*k/n), 0<k<n/2
class Rfft {
public:
// @param n Number of fft bins. it should be a power of 2.
explicit Rfft(int32_t n);
~Rfft();
/** @param in_out A 1-D array of size n.
* On return:
* in_out[0] = R[0]
* in_out[1] = R[n/2]
* for 1 < k < n/2,
* in_out[2*k] = R[k]
* in_out[2*k+1] = I[k]
*
*/
void Compute(float *in_out);
void Compute(double *in_out);
private:
class RfftImpl;
std::unique_ptr<RfftImpl> impl_;
};
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_RFFT_H_
Loading…
Cancel
Save