diff --git a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc index 38a1a0b3..3cd1ae61 100644 --- a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc +++ b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc @@ -7,20 +7,36 @@ #include "base/log.h" #include "base/flags.h" #include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" DEFINE_string(wav_rspecifier, "", "test wav path"); DEFINE_string(feature_wspecifier, "", "test wav ark"); +DEFINE_string(feature_check_wspecifier, "", "test wav ark"); +DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark"); + std::vector mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, -20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, 
-20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007}; std::vector variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 
745481288.0241544, 801121432.9925804};

 int count_ = 912592;

+void WriteMatrix() {
+    kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
+    for (size_t idx = 0; idx < mean_.size(); ++idx) {
+        cmvn_stats(0, idx) = mean_[idx];
+        cmvn_stats(1, idx) = variance_[idx];
+    }
+    cmvn_stats(0, mean_.size()) = count_;
+    kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
+}
+
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);

     kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
     kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
+    kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(FLAGS_feature_check_wspecifier);
+    WriteMatrix();

     // test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
     int32 num_done = 0, num_err = 0;
@@ -32,6 +48,8 @@ int main(int argc, char* argv[]) {
         new ppspeech::DecibelNormalizer(db_norm_opt));
     ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));

+    ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
+
     float streaming_chunk = 0.36;
     int sample_rate = 16000;
     int chunk_sample_size = streaming_chunk * sample_rate;
@@ -57,18 +75,18 @@ int main(int argc, char* argv[]) {
         std::vector<kaldi::Matrix<kaldi::BaseFloat>> feats;
         int feature_rows = 0;
         while (sample_offset < tot_samples) {
-          int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
-          kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
-          for (int i = 0; i < cur_chunk_size; ++i) {
-            wav_chunk(i) = waveform(sample_offset + i);
-          }
-          kaldi::Matrix<kaldi::BaseFloat> features;
-          linear_spectrogram.AcceptWaveform(wav_chunk);
-          linear_spectrogram.ReadFeats(&features);
-
-          feats.push_back(features);
-          sample_offset += cur_chunk_size;
-          feature_rows += features.NumRows();
+            int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
+            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+            for (int i = 0; i < cur_chunk_size; ++i) {
+                wav_chunk(i) = waveform(sample_offset + i);
+            }
+            kaldi::Matrix<kaldi::BaseFloat> features;
+            linear_spectrogram.AcceptWaveform(wav_chunk);
+            linear_spectrogram.ReadFeats(&features);
+
+            feats.push_back(features);
+            sample_offset += cur_chunk_size;
+            feature_rows += features.NumRows();
         }

         int cur_idx = 0;
@@ -81,8 +99,22 @@ int main(int argc, char* argv[]) {
                 ++cur_idx;
             }
         }
-        feat_writer.Write(utt, features);
+
+        cur_idx = 0;
+        kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows, feats[0].NumCols());
+        for (auto feat : feats) {
+            for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
+                for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
+                    features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
+                }
+                kaldi::SubVector<kaldi::BaseFloat> row_feat(features_check, cur_idx);
+                cmvn.ApplyCMVN(true, &row_feat);
+                ++cur_idx;
+            }
+        }
+        feat_cmvn_check_writer.Write(utt, features_check);
+
         if (num_done % 50 == 0 && num_done != 0)
             KALDI_VLOG(2) << "Processed " << num_done << " utterances";
         num_done++;
@@ -90,4 +122,4 @@ int main(int argc, char* argv[]) {
     KALDI_LOG << "Done " << num_done << " utterances, " << num_err
               << " with errors.";
     return (num_done != 0 ? 0 : 1);
-}
\ No newline at end of file
+}
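For reference, the ark written by WriteMatrix() follows the usual Kaldi CMVN-stats layout: a 2 x (dim+1) matrix whose first row holds the per-dimension sums with the frame count in the last column, and whose second row holds the per-dimension sums of squares. A minimal sketch (not part of the patch; the function and path names are illustrative) of reading the file back and recovering per-dimension mean and variance, mirroring the arithmetic that ApplyCmvn() performs:

    #include "kaldi/matrix/kaldi-matrix.h"
    #include "kaldi/util/kaldi-io.h"

    void PrintCmvnStats(const std::string& cmvn_path) {      // e.g. FLAGS_cmvn_write_path
        kaldi::Matrix<double> stats;
        kaldi::ReadKaldiObject(cmvn_path, &stats);           // reads what WriteKaldiObject() wrote
        int dim = stats.NumCols() - 1;
        double count = stats(0, dim);                        // frame count, last column of row 0
        for (int d = 0; d < dim; ++d) {
            double mean = stats(0, d) / count;               // row 0: per-dimension sums
            double var = stats(1, d) / count - mean * mean;  // row 1: per-dimension sums of squares
            KALDI_LOG << "dim " << d << ": mean=" << mean << " var=" << var;
        }
    }
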
diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc
index 04e88bf4..16fc09a8 100644
--- a/speechx/speechx/frontend/normalizer.cc
+++ b/speechx/speechx/frontend/normalizer.cc
@@ -1,5 +1,7 @@
 #include "frontend/normalizer.h"
+#include "kaldi/feat/cmvn.h"
+#include "kaldi/util/kaldi-io.h"

 namespace ppspeech {

@@ -7,6 +9,7 @@ using kaldi::Vector;
 using kaldi::VectorBase;
 using kaldi::BaseFloat;
 using std::vector;
+using kaldi::SubVector;

 DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
     opts_ = opts;
@@ -87,44 +90,91 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
     return true;
 }

-/*
-PPNormalizer::PPNormalizer(
-    const PPNormalizerOptions& opts,
-    const std::unique_ptr<FeatureExtractorInterface>& pre_extractor) {
-
-}
-
-void PPNormalizer::AcceptWavefrom(const Vector<BaseFloat>& input) {
-
-}
-
-void PPNormalizer::Read(Vector<BaseFloat>* feat) {
-
-}
-
-bool PPNormalizer::Compute(const Vector<BaseFloat>& input,
-                           Vector<Vector<BaseFloat>>* feat) {
-    if ((input.Dim() % mean_.Dim()) == 0) {
-        LOG(ERROR) << "CMVN dimension is wrong!";
-        return false;
-    }
-
-    try {
-        int32 size = mean_.Dim();
-        feat->Resize(input.Dim());
-        for (int32 row_idx = 0; row_idx < j; ++row_idx) {
-            int32 base_idx = row_idx * size;
-            for (int32 idx = 0; idx < mean_.Dim(); ++idx) {
-                (*feat)(base_idx + idx) = (input(base_dix + idx) - mean_(idx)) * variance_(idx);
-            }
-        }
-
-    } catch(const std::exception& e) {
-        std::cerr << e.what() << '\n';
-        return false;
-    }
-
-    return true;
-}*/
+CMVN::CMVN(std::string cmvn_file) : var_norm_(true) {
+    bool binary;
+    kaldi::Input ki(cmvn_file, &binary);
+    stats_.Read(ki.Stream(), binary);
+}
+
+void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
+    return;
+}
+
+void CMVN::Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) {
+    return;
+}
+
+// feats contains num_frames frames laid out back to back.
+void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
+    KALDI_ASSERT(feats != NULL);
+    int32 dim = stats_.NumCols() - 1;
+    if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
+        feats->Dim() % dim != 0) {
+        KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
+                  << stats_.NumCols() << ", feats " << feats->Dim() << 'x';
+    }
+    if (stats_.NumRows() == 1 && var_norm) {
+        KALDI_ERR << "You requested variance normalization but no variance stats_ "
+                  << "are supplied.";
+    }
+
+    double count = stats_(0, dim);
+    // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
+    // computing an offset and representing it as stats_, we use a count of one.
+    if (count < 1.0)
+        KALDI_ERR << "Insufficient stats_ for cepstral mean and variance normalization: "
+                  << "count = " << count;
+
+    if (!var_norm) {
+        Vector<BaseFloat> offset(feats->Dim());
+        SubVector<double> mean_stats(stats_.RowData(0), dim);
+        Vector<double> mean_stats_apply(feats->Dim());
+        // Fill mean_stats_apply with repeated copies of mean_stats so that its dim
+        // matches the feature vector: feats->Dim() == dim * num_frames.
+        for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
+            SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx, dim);
+            stats_tmp.CopyFromVec(mean_stats);
+        }
+        offset.AddVec(-1.0 / count, mean_stats_apply);
+        feats->AddVec(1.0, offset);
+        return;
+    }
+    // norm(0, d) = mean offset;
+    // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
+    kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
+    for (int32 d = 0; d < dim; d++) {
+        double mean, offset, scale;
+        mean = stats_(0, d) / count;
+        double var = (stats_(1, d) / count) - mean * mean,
+               floor = 1.0e-20;
+        if (var < floor) {
+            KALDI_WARN << "Flooring cepstral variance from " << var << " to "
+                       << floor;
+            var = floor;
+        }
+        scale = 1.0 / sqrt(var);
+        if (scale != scale || 1 / scale == 0.0)
+            KALDI_ERR << "NaN or infinity in cepstral mean/variance computation";
+        offset = -(mean * scale);
+        // Broadcast this dimension's offset/scale to every frame in feats.
+        for (int32 d_skip = d; d_skip < feats->Dim();) {
+            norm(0, d_skip) = offset;
+            norm(1, d_skip) = scale;
+            d_skip = d_skip + dim;
+        }
+    }
+    // Apply the normalization.
+    feats->MulElements(norm.Row(1));
+    feats->AddVec(1.0, norm.Row(0));
+}
+
+void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats) {
+    ApplyCmvn(stats_, var_norm, feats);
+}
+
+bool CMVN::Compute(const VectorBase<BaseFloat>& input,
+                   VectorBase<BaseFloat>* feat) const {
+    return false;
+}
+
-} // namespace ppspeech
\ No newline at end of file
+} // namespace ppspeech
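CMVN::ApplyCMVN() above takes a vector that may hold several frames laid out back to back (feats->Dim() == dim * num_frames) and normalizes every frame with the same global stats. Element-wise, the var_norm branch reduces to ordinary whitening; a scalar sketch (assumed helper, not part of the patch):

    #include <cmath>
    #include "kaldi/matrix/kaldi-matrix.h"

    // What the var_norm branch computes for one value x in dimension d of any frame.
    inline float NormalizeOne(float x, const kaldi::Matrix<double>& stats, int d) {
        int dim = stats.NumCols() - 1;
        double count = stats(0, dim);
        double mean = stats(0, d) / count;
        double var = stats(1, d) / count - mean * mean;
        if (var < 1.0e-20) var = 1.0e-20;                      // same floor as above
        double scale = 1.0 / std::sqrt(var);
        return static_cast<float>(x * scale - mean * scale);   // (x - mean) / stddev
    }
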
diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h
index 78670bb4..eea03fc1 100644
--- a/speechx/speechx/frontend/normalizer.h
+++ b/speechx/speechx/frontend/normalizer.h
@@ -4,10 +4,10 @@
 #include "base/common.h"
 #include "frontend/feature_extractor_interface.h"
 #include "kaldi/util/options-itf.h"
+#include "kaldi/matrix/kaldi-matrix.h"

 namespace ppspeech {

-
 struct DecibelNormalizerOptions {
     float target_db;
     float max_gain_db;
@@ -39,34 +39,23 @@ class DecibelNormalizer : public FeatureExtractorInterface {
     kaldi::Vector<kaldi::BaseFloat> waveform_;
 };

-/*
-struct NormalizerOptions {
-    std::string mean_std_path;
-    NormalizerOptions() :
-        mean_std_path("") {}
-
-    void Register(kaldi::OptionsItf* opts) {
-        opts->Register("mean-std", &mean_std_path, "mean std file");
-    }
-};
-
-// todo refactor later (SmileGoat)
-class PPNormalizer : public FeatureExtractorInterface {
+class CMVN : public FeatureExtractorInterface {
   public:
-    explicit PPNormalizer(const NormalizerOptions& opts,
-                          const std::unique_ptr<FeatureExtractorInterface>& pre_extractor);
-    ~PPNormalizer() {}
-    virtual void AcceptWavefrom(const kaldi::Vector<kaldi::BaseFloat>& input);
-    virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat);
-    virtual size_t Dim() const;
-    bool Compute(const kaldi::Vector<kaldi::BaseFloat>& input,
-                 kaldi::Vector<kaldi::Vector<kaldi::BaseFloat>>& feat);
-
+    explicit CMVN(std::string cmvn_file);
+    virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
+    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
+    virtual size_t Dim() const { return stats_.NumCols() - 1; }
+    bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
+                 kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
+    // for test
+    void ApplyCMVN(bool var_norm, kaldi::VectorBase<kaldi::BaseFloat>* feats);
+    void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<kaldi::BaseFloat>* feats);
+
   private:
-    bool _initialized;
-    kaldi::Vector<kaldi::BaseFloat> mean_;
-    kaldi::Vector<kaldi::BaseFloat> variance_;
-    NormalizerOptions _opts;
+    kaldi::Matrix<double> stats_;
+    std::shared_ptr<FeatureExtractorInterface> base_extractor_;
+    size_t dim_;
+    bool var_norm_;
 };
-*/
+
 } // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/kaldi/feat/CMakeLists.txt b/speechx/speechx/kaldi/feat/CMakeLists.txt
index 8b914962..c3a996ff 100644
--- a/speechx/speechx/kaldi/feat/CMakeLists.txt
+++ b/speechx/speechx/kaldi/feat/CMakeLists.txt
@@ -15,5 +15,6 @@ add_library(kaldi-feat-common
     feature-window.cc
     resample.cc
     mel-computations.cc
+    cmvn.cc
 )
-target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util)
\ No newline at end of file
+target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util)
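A usage sketch tying the pieces together: construct the new ppspeech::CMVN from the stats file the test program writes and apply it row by row, using only the interface declared above (the helper name and the "./cmvn.ark" path are illustrative, not part of the patch):

    #include "frontend/normalizer.h"
    #include "kaldi/matrix/kaldi-matrix.h"

    void ApplyGlobalCmvn(kaldi::Matrix<kaldi::BaseFloat>* feats) {
        ppspeech::CMVN cmvn("./cmvn.ark");      // stats written by WriteMatrix() in the test
        for (int r = 0; r < feats->NumRows(); ++r) {
            kaldi::SubVector<kaldi::BaseFloat> row(*feats, r);
            cmvn.ApplyCMVN(true, &row);         // true: normalize variance as well as mean
        }
        // Whole-matrix alternative that forwards to kaldi::ApplyCmvn():
        // cmvn.ApplyCMVNMatrix(true, feats);
    }
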
diff --git a/speechx/speechx/kaldi/feat/cmvn.cc b/speechx/speechx/kaldi/feat/cmvn.cc
new file mode 100644
index 00000000..b2aa46e4
--- /dev/null
+++ b/speechx/speechx/kaldi/feat/cmvn.cc
@@ -0,0 +1,183 @@
+// transform/cmvn.cc
+
+// Copyright 2009-2013  Microsoft Corporation
+//                      Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "feat/cmvn.h"
+
+namespace kaldi {
+
+void InitCmvnStats(int32 dim, Matrix<double> *stats) {
+  KALDI_ASSERT(dim > 0);
+  stats->Resize(2, dim+1);
+}
+
+void AccCmvnStats(const VectorBase<BaseFloat> &feats, BaseFloat weight, MatrixBase<double> *stats) {
+  int32 dim = feats.Dim();
+  KALDI_ASSERT(stats != NULL);
+  KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() == dim + 1);
+  // Remove these __restrict__ modifiers if they cause compilation problems.
+  // It's just an optimization.
+  double *__restrict__ mean_ptr = stats->RowData(0),
+      *__restrict__ var_ptr = stats->RowData(1),
+      *__restrict__ count_ptr = mean_ptr + dim;
+  const BaseFloat * __restrict__ feats_ptr = feats.Data();
+  *count_ptr += weight;
+  // Careful-- if we change the format of the matrix, the "mean_ptr < count_ptr"
+  // statement below might become wrong.
+  for (; mean_ptr < count_ptr; mean_ptr++, var_ptr++, feats_ptr++) {
+    *mean_ptr += *feats_ptr * weight;
+    *var_ptr += *feats_ptr * *feats_ptr * weight;
+  }
+}
+
+void AccCmvnStats(const MatrixBase<BaseFloat> &feats,
+                  const VectorBase<BaseFloat> *weights,
+                  MatrixBase<double> *stats) {
+  int32 num_frames = feats.NumRows();
+  if (weights != NULL) {
+    KALDI_ASSERT(weights->Dim() == num_frames);
+  }
+  for (int32 i = 0; i < num_frames; i++) {
+    SubVector<BaseFloat> this_frame = feats.Row(i);
+    BaseFloat weight = (weights == NULL ? 1.0 : (*weights)(i));
+    if (weight != 0.0)
+      AccCmvnStats(this_frame, weight, stats);
+  }
+}
+
+void ApplyCmvn(const MatrixBase<double> &stats,
+               bool var_norm,
+               MatrixBase<BaseFloat> *feats) {
+  KALDI_ASSERT(feats != NULL);
+  int32 dim = stats.NumCols() - 1;
+  if (stats.NumRows() > 2 || stats.NumRows() < 1 || feats->NumCols() != dim) {
+    KALDI_ERR << "Dim mismatch: cmvn "
+              << stats.NumRows() << 'x' << stats.NumCols()
+              << ", feats " << feats->NumRows() << 'x' << feats->NumCols();
+  }
+  if (stats.NumRows() == 1 && var_norm)
+    KALDI_ERR << "You requested variance normalization but no variance stats "
+              << "are supplied.";
+
+  double count = stats(0, dim);
+  // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
+  // computing an offset and representing it as stats, we use a count of one.
+  if (count < 1.0)
+    KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: "
+              << "count = " << count;
+
+  if (!var_norm) {
+    Vector<BaseFloat> offset(dim);
+    SubVector<double> mean_stats(stats.RowData(0), dim);
+    offset.AddVec(-1.0 / count, mean_stats);
+    feats->AddVecToRows(1.0, offset);
+    return;
+  }
+  // norm(0, d) = mean offset;
+  // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
+  Matrix<BaseFloat> norm(2, dim);
+  for (int32 d = 0; d < dim; d++) {
+    double mean, offset, scale;
+    mean = stats(0, d)/count;
+    double var = (stats(1, d)/count) - mean*mean,
+        floor = 1.0e-20;
+    if (var < floor) {
+      KALDI_WARN << "Flooring cepstral variance from " << var << " to "
+                 << floor;
+      var = floor;
+    }
+    scale = 1.0 / sqrt(var);
+    if (scale != scale || 1/scale == 0.0)
+      KALDI_ERR << "NaN or infinity in cepstral mean/variance computation";
+    offset = -(mean*scale);
+    norm(0, d) = offset;
+    norm(1, d) = scale;
+  }
+  // Apply the normalization.
+  feats->MulColsVec(norm.Row(1));
+  feats->AddVecToRows(1.0, norm.Row(0));
+}
+
+void ApplyCmvnReverse(const MatrixBase<double> &stats,
+                      bool var_norm,
+                      MatrixBase<BaseFloat> *feats) {
+  KALDI_ASSERT(feats != NULL);
+  int32 dim = stats.NumCols() - 1;
+  if (stats.NumRows() > 2 || stats.NumRows() < 1 || feats->NumCols() != dim) {
+    KALDI_ERR << "Dim mismatch: cmvn "
+              << stats.NumRows() << 'x' << stats.NumCols()
+              << ", feats " << feats->NumRows() << 'x' << feats->NumCols();
+  }
+  if (stats.NumRows() == 1 && var_norm)
+    KALDI_ERR << "You requested variance normalization but no variance stats "
+              << "are supplied.";
+
+  double count = stats(0, dim);
+  // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
+  // computing an offset and representing it as stats, we use a count of one.
+  if (count < 1.0)
+    KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: "
+              << "count = " << count;
+
+  Matrix<BaseFloat> norm(2, dim);  // norm(0, d) = mean offset
+  // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
+  for (int32 d = 0; d < dim; d++) {
+    double mean, offset, scale;
+    mean = stats(0, d) / count;
+    if (!var_norm) {
+      scale = 1.0;
+      offset = mean;
+    } else {
+      double var = (stats(1, d)/count) - mean*mean,
+          floor = 1.0e-20;
+      if (var < floor) {
+        KALDI_WARN << "Flooring cepstral variance from " << var << " to "
+                   << floor;
+        var = floor;
+      }
+      // we aim to transform zero-mean, unit-variance input into data
+      // with the given mean and variance.
+      scale = sqrt(var);
+      offset = mean;
+    }
+    norm(0, d) = offset;
+    norm(1, d) = scale;
+  }
+  if (var_norm)
+    feats->MulColsVec(norm.Row(1));
+  feats->AddVecToRows(1.0, norm.Row(0));
+}
+
+
+void FakeStatsForSomeDims(const std::vector<int32> &dims,
+                          MatrixBase<double> *stats) {
+  KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() > 1);
+  int32 dim = stats->NumCols() - 1;
+  double count = (*stats)(0, dim);
+  for (size_t i = 0; i < dims.size(); i++) {
+    int32 d = dims[i];
+    KALDI_ASSERT(d >= 0 && d < dim);
+    (*stats)(0, d) = 0.0;
+    (*stats)(1, d) = count;
+  }
+}
+
+
+
+}  // namespace kaldi
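The accumulate-then-apply flow for the functions above, as a minimal sketch (the helper name is an assumption; the feature matrix is taken as given):

    #include "feat/cmvn.h"

    void CmvnSelfTest(kaldi::Matrix<kaldi::BaseFloat>* feats) {
      kaldi::Matrix<double> stats;
      kaldi::InitCmvnStats(feats->NumCols(), &stats);  // 2 x (dim + 1), zero-initialized
      kaldi::AccCmvnStats(*feats, NULL, &stats);       // unweighted accumulation over all rows
      kaldi::ApplyCmvn(stats, true, feats);            // each column now has ~zero mean, unit variance
    }
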
diff --git a/speechx/speechx/kaldi/feat/cmvn.h b/speechx/speechx/kaldi/feat/cmvn.h
new file mode 100644
index 00000000..c6d1b7f7
--- /dev/null
+++ b/speechx/speechx/kaldi/feat/cmvn.h
@@ -0,0 +1,75 @@
+// transform/cmvn.h
+
+// Copyright 2009-2013  Microsoft Corporation
+//                      Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_TRANSFORM_CMVN_H_
+#define KALDI_TRANSFORM_CMVN_H_
+
+#include "base/kaldi-common.h"
+#include "matrix/matrix-lib.h"
+
+namespace kaldi {
+
+/// This function initializes the matrix to dimension 2 by (dim+1);
+/// 1st "dim" elements of 1st row are mean stats, 1st "dim" elements
+/// of 2nd row are var stats, last element of 1st row is count,
+/// last element of 2nd row is zero.
+void InitCmvnStats(int32 dim, Matrix<double> *stats);
+
+/// Accumulation from a single frame (weighted).
+void AccCmvnStats(const VectorBase<BaseFloat> &feat,
+                  BaseFloat weight,
+                  MatrixBase<double> *stats);
+
+/// Accumulation from a feature file (possibly weighted-- useful in excluding silence).
+void AccCmvnStats(const MatrixBase<BaseFloat> &feats,
+                  const VectorBase<BaseFloat> *weights,  // or NULL
+                  MatrixBase<double> *stats);
+
+/// Apply cepstral mean and variance normalization to a matrix of features.
+/// If norm_vars == true, expects stats to be of dimension 2 by (dim+1), but
+/// if norm_vars == false, will accept stats of dimension 1 by (dim+1); these
+/// are produced by the balanced-cmvn code when it computes an offset and
+/// represents it as "fake stats".
+void ApplyCmvn(const MatrixBase<double> &stats,
+               bool norm_vars,
+               MatrixBase<BaseFloat> *feats);
+
+/// This is as ApplyCmvn, but does so in the reverse sense, i.e. applies a transform
+/// that would take zero-mean, unit-variance input and turn it into output with the
+/// stats of "stats".  This can be useful if you trained without CMVN but later want
+/// to correct a mismatch, so you would first apply CMVN and then do the "reverse"
+/// CMVN with the summed stats of your training data.
+void ApplyCmvnReverse(const MatrixBase<double> &stats,
+                      bool norm_vars,
+                      MatrixBase<BaseFloat> *feats);
+
+
+/// Modify the stats so that for some dimensions (specified in "dims"), we
+/// replace them with "fake" stats that have zero mean and unit variance; this
+/// is done to disable CMVN for those dimensions.
+void FakeStatsForSomeDims(const std::vector<int32> &dims,
+                          MatrixBase<double> *stats);
+
+
+
+}  // namespace kaldi
+
+#endif  // KALDI_TRANSFORM_CMVN_H_
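ApplyCmvnReverse() is the inverse transform of ApplyCmvn() with the same stats, so applying one after the other restores the input (up to the variance floor). A short sketch (helper name illustrative, not part of the patch):

    #include "feat/cmvn.h"

    void CmvnRoundTrip(const kaldi::MatrixBase<double>& stats,
                       kaldi::Matrix<kaldi::BaseFloat>* feats) {
      kaldi::ApplyCmvn(stats, true, feats);         // x <- (x - mean) / stddev
      kaldi::ApplyCmvnReverse(stats, true, feats);  // x <- x * stddev + mean, i.e. back to the original
    }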