pull/1541/head
Hui Zhang 3 years ago
parent 11dc485d63
commit 41feecbd06

@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook entry: python .pre-commit-hooks/copyright-check.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False

@ -36,4 +36,4 @@ def repeat(N, fn):
Returns: Returns:
MultiSequential: Repeated model instance. MultiSequential: Repeated model instance.
""" """
return MultiSequential(*[fn(n) for n in range(N)]) return MultiSequential(* [fn(n) for n in range(N)])

@ -3,4 +3,3 @@
* decoder - offline decoder * decoder - offline decoder
* feat - mfcc, linear * feat - mfcc, linear
* nnet - ds2 nn * nnet - ds2 nn

@ -1,11 +1,25 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, replace with gtest // todo refactor, replace with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "nnet/paddle_nnet.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test nnet prob"); DEFINE_string(feature_respecifier, "", "test nnet prob");
@ -13,45 +27,48 @@ using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
//void SplitFeature(kaldi::Matrix<BaseFloat> feature, // void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size, // int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) { // std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//} //}
// Entry point of the CTC beam-search decoder smoke test.
//
// Parses gflags and initializes glog, then iterates over the feature matrices
// read from FLAGS_feature_respecifier. For each utterance the features are fed
// to a ppspeech::Decodable wrapping a ppspeech::PaddleNnet, the decoder is
// advanced, input is marked finished, and the final best path is logged.
// decodable is Reset() after every utterance; decoder.InitDecoder() is only
// called once, before the loop.
//
// Returns 0 when at least one utterance was processed, 1 otherwise.
// NOTE(review): num_err is only initialized and logged, never incremented
// here, so the reported error count is always 0.
// NOTE(review): the second argument 8 to AdvanceDecode is presumably a frame
// or chunk count -- confirm against ctc_beam_search_decoder.h.
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier); kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0; // test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts); ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(new ppspeech::PaddleNnet(model_opts)); ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
std::shared_ptr<ppspeech::Decodable> decodable(new ppspeech::Decodable(nnet)); new ppspeech::PaddleNnet(model_opts));
//int32 chunk_size = 35; std::shared_ptr<ppspeech::Decodable> decodable(
decoder.InitDecoder(); new ppspeech::Decodable(nnet));
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key(); // int32 chunk_size = 35;
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value(); decoder.InitDecoder();
decodable->FeedFeatures(feature); for (; !feature_reader.Done(); feature_reader.Next()) {
decoder.AdvanceDecode(decodable, 8); string utt = feature_reader.Key();
decodable->InputFinished(); const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
std::string result; decodable->FeedFeatures(feature);
result = decoder.GetFinalBestPath(); decoder.AdvanceDecode(decodable, 8);
KALDI_LOG << " the result of " << utt << " is " << result; decodable->InputFinished();
decodable->Reset(); std::string result;
++num_done; result = decoder.GetFinalBestPath();
} KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
KALDI_LOG << "Done " << num_done << " utterances, " << num_err ++num_done;
<< " with errors."; }
return (num_done != 0 ? 0 : 1);
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
} }

File diff suppressed because it is too large Load Diff

@ -1,13 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, replace with gtest // todo refactor, replace with gtest
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h" #include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h" #include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h" #include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav path"); DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark"); DEFINE_string(feature_wspecifier, "", "test wav ark");
@ -15,110 +29,228 @@ DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark"); DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, 
-20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007}; std::vector<float> mean_{
std::vector<float> variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 
472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804}; -13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592; int count_ = 912592;
void WriteMatrix() { void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size()+ 1); kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) { for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx]; cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx]; cmvn_stats(1, idx) = variance_[idx];
} }
cmvn_stats(0, mean_.size()) = count_; cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true); kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectrogram: wave --> decibel_normalizer --> hanning window --> linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0; kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows, feats[0].NumCols()); FLAGS_wav_rspecifier);
for (auto feat : feats) { kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) { kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) { FLAGS_feature_check_wspecifier);
features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx]; WriteMatrix();
}
++cur_idx; // test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
} // window -->linear_spectrogram --> cmvn
} int32 num_done = 0, num_err = 0;
feat_writer.Write(utt, features); ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
cur_idx = 0; opt.frame_opts.frame_shift_ms = 10;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows, feats[0].NumCols()); ppspeech::DecibelNormalizerOptions db_norm_opt;
for (auto feat : feats) { std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) { new ppspeech::DecibelNormalizer(db_norm_opt));
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) { ppspeech::LinearSpectrogram linear_spectrogram(
features_check(cur_idx, col_idx) = feat(row_idx, col_idx); opt, std::move(base_feature_extractor));
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx); ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx; float streaming_chunk = 0.36;
} int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
} }
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0) for (; !wav_reader.Done(); wav_reader.Next()) {
KALDI_VLOG(2) << "Processed " << num_done << " utterances"; std::string utt = wav_reader.Key();
num_done++; const kaldi::WaveData& wave_data = wav_reader.Value();
}
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) =
(feat(row_idx, col_idx) - mean_[col_idx]) *
variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);

@ -1,12 +1,26 @@
#include "paddle_inference_api.h" // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <iostream> #include <algorithm>
#include <thread>
#include <fstream> #include <fstream>
#include <functional>
#include <iostream>
#include <iterator> #include <iterator>
#include <algorithm>
#include <numeric> #include <numeric>
#include <functional> #include <thread>
#include "paddle_inference_api.h"
using std::cout; using std::cout;
using std::endl; using std::endl;
@ -19,19 +33,19 @@ void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test(); void model_forward_test();
// Fills *data with a synthetic feature chunk for the model forward test.
//
// Appends chunk_size (35) rows to *data, each holding col_size (161) copies
// of 0.201 (converted to float), and prints both dimensions to stdout.
//
// data: output container; must not be null. Rows are appended after any
//       content it already holds.
//
// The previous version called data->back().reserve(col_size) before any row
// had been pushed, which is undefined behaviour on an empty vector. Each row
// is now constructed at its final size instead.
void produce_data(std::vector<std::vector<float>>* data) {
    const int chunk_size = 35;  // chunk_size in frame
    const int col_size = 161;   // feat dim
    std::cout << "chunk size: " << chunk_size << std::endl;
    std::cout << "feat dim: " << col_size << std::endl;

    // Same value the original produced by pushing the double literal 0.201
    // into a std::vector<float>.
    const float fill_value = static_cast<float>(0.201);

    // Reserve relative to the current size so pre-existing rows are counted.
    data->reserve(data->size() + chunk_size);
    for (int row = 0; row < chunk_size; ++row) {
        data->push_back(std::vector<float>(col_size, fill_value));
    }
}
void model_forward_test() { void model_forward_test() {
@ -39,7 +53,8 @@ void model_forward_test() {
std::vector<std::vector<float>> feats; std::vector<std::vector<float>> feats;
produce_data(&feats); produce_data(&feats);
std::cout << "2. load the model" << std::endl;; std::cout << "2. load the model" << std::endl;
;
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
cout << "model path: " << model_graph << endl; cout << "model path: " << model_graph << endl;
@ -53,9 +68,10 @@ void model_forward_test() {
cout << "DisableFCPadding: " << endl; cout << "DisableFCPadding: " << endl;
auto predictor = paddle_infer::CreatePredictor(config); auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl; std::cout << "3. feat shape, row=" << feats.size()
<< ",col=" << feats[0].size() << std::endl;
std::vector<float> pp_input_mat; std::vector<float> pp_input_mat;
for(const auto& item : feats) { for (const auto& item : feats) {
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end()); pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
} }
@ -64,11 +80,11 @@ void model_forward_test() {
int col = feats[0].size(); int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames(); std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames(); std::vector<std::string> output_names = predictor->GetOutputNames();
for (auto name : input_names){ for (auto name : input_names) {
cout << "model input names: " << name << endl; cout << "model input names: " << name << endl;
} }
for (auto name : output_names){ for (auto name : output_names) {
cout << "model output names: " << name << endl; cout << "model output names: " << name << endl;
} }
// input // input
@ -79,7 +95,8 @@ void model_forward_test() {
input_tensor->CopyFromCpu(pp_input_mat.data()); input_tensor->CopyFromCpu(pp_input_mat.data());
// input length // input length
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]); std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1}; std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size); input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len; std::vector<int64_t> audio_len;
@ -87,20 +104,28 @@ void model_forward_test() {
input_len->CopyFromCpu(audio_len.data()); input_len->CopyFromCpu(audio_len.data());
// state_h // state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]); std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024}; std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape); chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(), int chunk_state_h_box_size =
1, std::multiplies<int>()); std::accumulate(chunk_state_h_box_shape.begin(),
chunk_state_h_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f); std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data()); chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
// state_c // state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]); std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024}; std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape); chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(), int chunk_state_c_box_size =
1, std::multiplies<int>()); std::accumulate(chunk_state_c_box_shape.begin(),
chunk_state_c_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f); std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data()); chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
@ -108,18 +133,20 @@ void model_forward_test() {
bool success = predictor->Run(); bool success = predictor->Run();
// state_h out // state_h out
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]); std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape(); std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(), int h_out_size = std::accumulate(
1, std::multiplies<int>()); h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size); std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data()); h_out->CopyToCpu(h_out_data.data());
// state_c out // state_c out
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]); std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape(); std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(), int c_out_size = std::accumulate(
1, std::multiplies<int>()); c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size); std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data()); c_out->CopyToCpu(c_out_data.data());
@ -128,8 +155,8 @@ void model_forward_test() {
predictor->GetOutputHandle(output_names[0]); predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape(); std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs; std::vector<float> output_probs;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(), int output_size = std::accumulate(
1, std::multiplies<int>()); output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_probs.resize(output_size); output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data()); output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1]; row = output_shape[1];
@ -148,12 +175,14 @@ void model_forward_test() {
} }
std::vector<std::vector<float>> log_feat = probs; std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl; std::cout << "probs, row: " << log_feat.size()
<< " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) { for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) { for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
std::cout << log_feat[row_idx][col_idx] << " "; ++col_idx) {
} std::cout << log_feat[row_idx][col_idx] << " ";
std::cout << std::endl; }
std::cout << std::endl;
} }
} }

@ -18,22 +18,22 @@
#include <limits> #include <limits>
typedef float BaseFloat; typedef float BaseFloat;
typedef double double64; typedef double double64;
typedef signed char int8; typedef signed char int8;
typedef short int16; typedef short int16;
typedef int int32; typedef int int32;
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef long int64; typedef long int64;
#else #else
typedef long long int64; typedef long long int64;
#endif #endif
typedef unsigned char uint8; typedef unsigned char uint8;
typedef unsigned short uint16; typedef unsigned short uint16;
typedef unsigned int uint32; typedef unsigned int uint32;
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64; typedef unsigned long uint64;
@ -41,20 +41,20 @@ typedef unsigned long uint64;
typedef unsigned long long uint64; typedef unsigned long long uint64;
#endif #endif
typedef signed int char32; typedef signed int char32;
const uint8 kuint8max = (( uint8) 0xFF); const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16) 0xFFFF); const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32) 0xFFFFFFFF); const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL)); const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = (( int8) 0x80); const int8 kint8min = ((int8)0x80);
const int8 kint8max = (( int8) 0x7F); const int8 kint8max = ((int8)0x7F);
const int16 kint16min = (( int16) 0x8000); const int16 kint16min = ((int16)0x8000);
const int16 kint16max = (( int16) 0x7FFF); const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = (( int32) 0x80000000); const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = (( int32) 0x7FFFFFFF); const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = (( int64) (0x8000000000000000LL)); const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = (( int64) (0x7FFFFFFFFFFFFFFFLL)); const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max(); const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min(); const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();

@ -15,22 +15,22 @@
#pragma once #pragma once
#include <deque> #include <deque>
#include <fstream>
#include <iostream> #include <iostream>
#include <istream> #include <istream>
#include <fstream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex>
#include <ostream> #include <ostream>
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <stack> #include <stack>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <mutex> #include <vector>
#include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h" #include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h" #include "base/macros.h"

@ -18,8 +18,8 @@ namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN #ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \ TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete void operator=(const TypeName&) = delete
#endif #endif
} // namespace ppspeech } // namespace ppspeech

@ -23,28 +23,29 @@
#ifndef BASE_THREAD_POOL_H #ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H #define BASE_THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable> #include <condition_variable>
#include <future>
#include <functional> #include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept> #include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool { class ThreadPool {
public: public:
ThreadPool(size_t); ThreadPool(size_t);
template<class F, class... Args> template <class F, class... Args>
auto enqueue(F&& f, Args&&... args) auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>; -> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool(); ~ThreadPool();
private:
private:
// need to keep track of threads so we can join them // need to keep track of threads so we can join them
std::vector< std::thread > workers; std::vector<std::thread> workers;
// the task queue // the task queue
std::queue< std::function<void()> > tasks; std::queue<std::function<void()>> tasks;
// synchronization // synchronization
std::mutex queue_mutex; std::mutex queue_mutex;
@ -53,68 +54,57 @@ private:
}; };
// the constructor just launches some amount of workers // the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
: stop(false) for (size_t i = 0; i < threads; ++i)
{ workers.emplace_back([this] {
for(size_t i = 0;i<threads;++i) for (;;) {
workers.emplace_back( std::function<void()> task;
[this]
{
for(;;)
{ {
std::function<void()> task; std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock, [this] {
{ return this->stop || !this->tasks.empty();
std::unique_lock<std::mutex> lock(this->queue_mutex); });
this->condition.wait(lock, if (this->stop && this->tasks.empty()) return;
[this]{ return this->stop || !this->tasks.empty(); }); task = std::move(this->tasks.front());
if(this->stop && this->tasks.empty()) this->tasks.pop();
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
} }
task();
} }
); });
} }
// add new work item to the pool // add new work item to the pool
template<class F, class... Args> template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args) auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> -> std::future<typename std::result_of<F(Args...)>::type> {
{
using return_type = typename std::result_of<F(Args...)>::type; using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >( auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...) std::bind(std::forward<F>(f), std::forward<Args>(args)...));
);
std::future<return_type> res = task->get_future(); std::future<return_type> res = task->get_future();
{ {
std::unique_lock<std::mutex> lock(queue_mutex); std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool // don't allow enqueueing after stopping the pool
if(stop) if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); }); tasks.emplace([task]() { (*task)(); });
} }
condition.notify_one(); condition.notify_one();
return res; return res;
} }
// the destructor joins all threads // the destructor joins all threads
inline ThreadPool::~ThreadPool() inline ThreadPool::~ThreadPool() {
{
{ {
std::unique_lock<std::mutex> lock(queue_mutex); std::unique_lock<std::mutex> lock(queue_mutex);
stop = true; stop = true;
} }
condition.notify_all(); condition.notify_all();
for(std::thread &worker: workers) for (std::thread& worker : workers) worker.join();
worker.join();
} }
#endif #endif

@ -1,7 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h" #include "base/basic_types.h"
struct DecoderResult { struct DecoderResult {
BaseFloat acoustic_score; BaseFloat acoustic_score;
std::vector<int32> words_idx; std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp; std::vector<pair<int32, int32>> time_stamp;
}; };

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h" #include "base/basic_types.h"
@ -9,34 +23,31 @@ namespace ppspeech {
using std::vector; using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_(opts), : opts_(opts),
init_ext_scorer_(nullptr), init_ext_scorer_(nullptr),
blank_id(-1), blank_id(-1),
space_id(-1), space_id(-1),
num_frame_decoded_(0), num_frame_decoded_(0),
root(nullptr) { root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
} }
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size(); LOG(INFO) << "read the vocabulary success, dict size: "
<< vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path; LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha, init_ext_scorer_ = std::make_shared<Scorer>(
opts_.beta, opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
opts_.lm_path,
vocabulary_);
} }
void CTCBeamSearch::Reset() { void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0; num_frame_decoded_ = 0;
ResetPrefixes(); ResetPrefixes();
} }
void CTCBeamSearch::InitDecoder() { void CTCBeamSearch::InitDecoder() {
blank_id = 0; blank_id = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
@ -51,10 +62,11 @@ void CTCBeamSearch::InitDecoder() {
root = std::make_shared<PathTrie>(); root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0; root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get()); prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict = auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary); static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr); root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
@ -62,239 +74,239 @@ void CTCBeamSearch::InitDecoder() {
} }
} }
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) { void CTCBeamSearch::Decode(
return; std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
} }
int32 CTCBeamSearch::NumFrameDecoded() { int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
return num_frame_decoded_;
}
// todo rename, refactor // todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, void CTCBeamSearch::AdvanceDecode(
int max_frames) { const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) { while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break; break;
} }
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); likelihood.push_back(
AdvanceDecoding(likelihood); decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
max_frames--; AdvanceDecoding(likelihood);
max_frames--;
} }
} }
void CTCBeamSearch::ResetPrefixes() { void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) { for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) { if (prefixes[i] != nullptr) {
delete prefixes[i]; delete prefixes[i];
prefixes[i] = nullptr; prefixes[i] = nullptr;
}
} }
}
} }
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs, int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) { vector<string>& nbest_words) {
kaldi::Timer timer; kaldi::Timer timer;
timer.Reset(); timer.Reset();
AdvanceDecoding(probs); AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f; LOG(INFO) << "ctc decoding elapsed time(s) "
return 0; << static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
} }
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() { vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
} }
string CTCBeamSearch::GetBestPath() { string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result; std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return result[0].second; return result[0].second;
} }
string CTCBeamSearch::GetFinalBestPath() { string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore(); CalculateApproxScore();
LMRescore(); LMRescore();
return GetBestPath(); return GetBestPath();
} }
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) { void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs.size(); size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob; double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n; size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0)); vector<vector<double>> probs_seq(probs.size(),
vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size(); int row = probs.size();
for(int i = 0; i < row; i++) { int col = probs[0].size();
for (int j = 0; j < col; j++){ for (int i = 0; i < row; i++) {
probs_seq[i][j] = static_cast<double>(probs[i][j]); for (int j = 0; j < col; j++) {
} probs_seq[i][j] = static_cast<double>(probs[i][j]);
} }
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
} }
vector<std::pair<size_t, float>> log_prob_idx = for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
}
// loop over chars vector<std::pair<size_t, float>> log_prob_idx =
size_t log_prob_idx_len = log_prob_idx.size(); get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear(); // loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
// update log probs for (size_t index = 0; index < log_prob_idx_len; index++) {
root->iterate_to_vec(prefixes); SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
// only preserve top beam_size prefixes }
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam, prefixes.clear();
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) { // update log probs
size_t beam_size = opts_.beam_size; root->iterate_to_vec(prefixes);
const auto& c = log_prob_idx.first; // only preserve top beam_size prefixes
const auto& log_prob_c = log_prob_idx.second; if (prefixes.size() >= beam_size) {
size_t prefixes_len = std::min(prefixes.size(), beam_size); std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
for (size_t i = 0; i < prefixes_len; ++i) { prefixes.end(),
auto prefix = prefixes[i]; prefix_compare);
if (full_beam && log_prob_c + prefix->score < min_cutoff) { for (size_t i = beam_size; i < prefixes.size(); ++i) {
break; prefixes[i]->remove();
} }
} // if
num_frame_decoded_++;
} // for probs_seq
}
if (c == blank_id) { int32 CTCBeamSearch::SearchOneChar(
prefix->log_prob_b_cur = log_sum_exp( const bool& full_beam,
prefix->log_prob_b_cur, const std::pair<size_t, BaseFloat>& log_prob_idx,
log_prob_c + const BaseFloat& min_cutoff) {
prefix->score); size_t beam_size = opts_.beam_size;
continue; const auto& c = log_prob_idx.first;
} const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// repeated character if (c == blank_id) {
if (c == prefix->character) { prefix->log_prob_b_cur =
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
prefix->log_prob_nb_cur = log_sum_exp( continue;
prefix->log_prob_nb_cur, }
log_prob_c +
prefix->log_prob_nb_prev);
}
// get new prefix // repeated character
auto prefix_new = prefix->get_path_trie(c); if (c == prefix->character) {
if (prefix_new != nullptr) { // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
float log_p = -NUM_FLT_INF; prefix->log_prob_nb_cur = log_sum_exp(
if (c == prefix->character && prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
} }
float score = 0.0; // get new prefix
vector<string> ngram; auto prefix_new = prefix->get_path_trie(c);
ngram = init_ext_scorer_->make_ngram(prefix_to_score); if (prefix_new != nullptr) {
// lm score: p_{lm}(W)^{\alpha} + \beta float log_p = -NUM_FLT_INF;
score = init_ext_scorer_->get_log_cond_prob(ngram) * if (c == prefix->character &&
init_ext_scorer_->alpha; prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p += score; // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p += init_ext_scorer_->beta; log_p = log_prob_c + prefix->log_prob_b_prev;
} } else if (c != prefix->character) {
// p_{nb}(l;x_{1:t}) // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
prefix_new->log_prob_nb_cur = log_p = log_prob_c + prefix->score;
log_sum_exp(prefix_new->log_prob_nb_cur, }
log_p);
} // language model scoring
} // end of loop over prefix if (init_ext_scorer_ != nullptr &&
return 0; (c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie* prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
return 0;
} }
void CTCBeamSearch::CalculateApproxScore() { void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(
prefixes.begin(), prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
prefixes.begin() + num_prefixes,
prefix_compare); // compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
// compute approximate ctc score as the return score, without affecting the for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
// return order of decoding result. To delete when decoder gets stable. double approx_ctc = prefixes[i]->score;
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { if (init_ext_scorer_ != nullptr) {
double approx_ctc = prefixes[i]->score; vector<int> output;
if (init_ext_scorer_ != nullptr) { prefixes[i]->get_path_vec(output);
vector<int> output; auto prefix_length = output.size();
prefixes[i]->get_path_vec(output); auto words = init_ext_scorer_->split_labels(output);
auto prefix_length = output.size(); // remove word insert
auto words = init_ext_scorer_->split_labels(output); approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove word insert // remove language model weight:
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
// remove language model weight: init_ext_scorer_->alpha;
approx_ctc -= }
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha; prefixes[i]->approx_ctc = approx_ctc;
} }
prefixes[i]->approx_ctc = approx_ctc;
}
} }
void CTCBeamSearch::LMRescore() { void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { if (init_ext_scorer_ != nullptr &&
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { !init_ext_scorer_->is_character_based()) {
auto prefix = prefixes[i]; for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
if (!prefix->is_empty() && prefix->character != space_id) { auto prefix = prefixes[i];
float score = 0.0; if (!prefix->is_empty() && prefix->character != space_id) {
vector<string> ngram = init_ext_scorer_->make_ngram(prefix); float score = 0.0;
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha; vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score += init_ext_scorer_->beta; score = init_ext_scorer_->get_log_cond_prob(ngram) *
prefix->score += score; init_ext_scorer_->alpha;
} score += init_ext_scorer_->beta;
prefix->score += score;
}
}
} }
}
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h" #include "nnet/decodable-itf.h"
#include "util/parse-options.h" #include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once #pragma once
@ -17,31 +31,32 @@ struct CTCBeamSearchOptions {
int beam_size; int beam_size;
int cutoff_top_n; int cutoff_top_n;
int num_proc_bsearch; int num_proc_bsearch;
CTCBeamSearchOptions() : CTCBeamSearchOptions()
dict_file("./model/words.txt"), : dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"), lm_path("./model/lm.arpa"),
alpha(1.9f), alpha(1.9f),
beta(5.0), beta(5.0),
beam_size(300), beam_size(300),
cutoff_prob(0.99f), cutoff_prob(0.99f),
cutoff_top_n(40), cutoff_top_n(40),
num_proc_bsearch(0) { num_proc_bsearch(0) {}
}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file "); opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file"); opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha"); opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta"); opts->Register("beta", &beta, "beta");
opts->Register("beam-size", &beam_size, "beam size for beam search method"); opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs"); opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n"); opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
} }
}; };
class CTCBeamSearch { class CTCBeamSearch {
public: public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {} ~CTCBeamSearch() {}
void InitDecoder(); void InitDecoder();
@ -50,30 +65,32 @@ class CTCBeamSearch {
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath(); std::string GetFinalBestPath();
int NumFrameDecoded(); int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, void AdvanceDecode(
int max_frames); const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void Reset(); void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_; private:
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later void ResetPrefixes();
//std::vector<DecodeResult> decoder_results_; int32 SearchOneChar(const bool& full_beam,
std::vector<std::string> vocabulary_; // todo remove later const std::pair<size_t, BaseFloat>& log_prob_idx,
size_t blank_id; const BaseFloat& min_cutoff);
int space_id; void CalculateApproxScore();
std::shared_ptr<PathTrie> root; void LMRescore();
std::vector<PathTrie*> prefixes; void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
// std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
}; };
} // namespace basr } // namespace basr

@ -24,13 +24,14 @@ class FbankExtractor : FeatureExtractorInterface {
public: public:
explicit FbankExtractor(const FbankOptions& opts, explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor); share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0; virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0; virtual size_t Dim() const = 0;
private: private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave, bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
kaldi::Vector<kaldi::BaseFloat>* feat) const; kaldi::Vector<kaldi::BaseFloat>* feat) const;
}; };
} // namespace ppspeech } // namespace ppspeech

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -21,7 +21,8 @@ namespace ppspeech {
class FeatureExtractorInterface { class FeatureExtractorInterface {
public: public:
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0; virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0; virtual size_t Dim() const = 0;
}; };

@ -25,97 +25,97 @@ using kaldi::VectorBase;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
//todo remove later // todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input, void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) { vector<BaseFloat>* output) {
if (input.Dim() == 0) return; if (input.Dim() == 0) return;
output->resize(input.Dim()); output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) { for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx); (*output)[idx] = input(idx);
} }
} }
void CopyStdVector2Vector_(const vector<BaseFloat>& input, void CopyStdVector2Vector_(const vector<BaseFloat>& input,
Vector<BaseFloat>* output) { Vector<BaseFloat>* output) {
if (input.empty()) return; if (input.empty()) return;
output->Resize(input.size()); output->Resize(input.size());
for (size_t idx = 0; idx < input.size(); ++idx) { for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx]; (*output)(idx) = input[idx];
} }
} }
LinearSpectrogram::LinearSpectrogram( LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts, const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) { std::unique_ptr<FeatureExtractorInterface> base_extractor) {
opts_ = opts; opts_ = opts;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize(); int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift(); int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size; fft_points_ = window_size;
hanning_window_.resize(window_size); hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1); double a = M_2PI / (window_size - 1);
hanning_window_energy_ = 0; hanning_window_energy_ = 0;
for (int i = 0; i < window_size; ++i) { for (int i = 0; i < window_size; ++i) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i); hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
} }
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
} }
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) { void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
base_extractor_->AcceptWaveform(input); base_extractor_->AcceptWaveform(input);
} }
void LinearSpectrogram::Hanning(vector<float>* data) const { void LinearSpectrogram::Hanning(vector<float>* data) const {
CHECK_GE(data->size(), hanning_window_.size()); CHECK_GE(data->size(), hanning_window_.size());
for (size_t i = 0; i < hanning_window_.size(); ++i) { for (size_t i = 0; i < hanning_window_.size(); ++i) {
data->at(i) *= hanning_window_[i]; data->at(i) *= hanning_window_[i];
} }
} }
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v, bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
vector<BaseFloat>* real, vector<BaseFloat>* real,
vector<BaseFloat>* img) const { vector<BaseFloat>* img) const {
Vector<BaseFloat> v_tmp; Vector<BaseFloat> v_tmp;
CopyStdVector2Vector_(*v, &v_tmp); CopyStdVector2Vector_(*v, &v_tmp);
RealFft(&v_tmp, true); RealFft(&v_tmp, true);
CopyVector2StdVector_(v_tmp, v); CopyVector2StdVector_(v_tmp, v);
real->push_back(v->at(0)); real->push_back(v->at(0));
img->push_back(0); img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) { for (int i = 1; i < v->size() / 2; i++) {
real->push_back(v->at(2 * i)); real->push_back(v->at(2 * i));
img->push_back(v->at(2 * i + 1)); img->push_back(v->at(2 * i + 1));
} }
real->push_back(v->at(1)); real->push_back(v->at(1));
img->push_back(0); img->push_back(0);
return true; return true;
} }
// todo remove later // todo remove later
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) { void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Vector<BaseFloat> tmp; Vector<BaseFloat> tmp;
waveform_.Resize(base_extractor_->Dim()); waveform_.Resize(base_extractor_->Dim());
Compute(tmp, &waveform_); Compute(tmp, &waveform_);
vector<vector<BaseFloat>> result; vector<vector<BaseFloat>> result;
vector<BaseFloat> feats_vec; vector<BaseFloat> feats_vec;
CopyVector2StdVector_(waveform_, &feats_vec); CopyVector2StdVector_(waveform_, &feats_vec);
Compute(feats_vec, result); Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size()); feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) { for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) { for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx]; (*feats)(row_idx, col_idx) = result[row_idx][col_idx];
}
} }
} waveform_.Resize(0);
waveform_.Resize(0);
} }
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) { void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
// todo // todo
return; return;
} }
// only for test, remove later // only for test, remove later
@ -129,49 +129,49 @@ void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
// todo: refactor later (SmileGoat) // todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& wave, bool LinearSpectrogram::Compute(const vector<float>& wave,
vector<vector<float>>& feat) { vector<vector<float>>& feat) {
int num_samples = wave.size(); int num_samples = wave.size();
const int& frame_length = opts_.frame_opts.WindowSize(); const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq; const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift(); const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_; const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * sample_rate; const float scale = hanning_window_energy_ * sample_rate;
if (num_samples < frame_length) { if (num_samples < frame_length) {
return true; return true;
} }
int num_frames = 1 + ((num_samples - frame_length) / frame_shift); int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feat.resize(num_frames); feat.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0); vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0); vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0); vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1)); vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) { for (int i = 0; i < num_frames; ++i) {
vector<float> data(wave.data() + i * frame_shift, vector<float> data(wave.data() + i * frame_shift,
wave.data() + i * frame_shift + frame_length); wave.data() + i * frame_shift + frame_length);
Hanning(&data); Hanning(&data);
fft_img.clear(); fft_img.clear();
fft_real.clear(); fft_real.clear();
v.assign(data.begin(), data.end()); v.assign(data.begin(), data.end());
NumpyFft(&v, &fft_real, &fft_img); NumpyFft(&v, &fft_real, &fft_img);
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) { for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feat[i][j] = power[j]; feat[i][j] = power[j];
if (j == 0 || j == feat[0].size() - 1) { if (j == 0 || j == feat[0].size() - 1) {
feat[i][j] /= scale; feat[i][j] /= scale;
} else { } else {
feat[i][j] *= (2.0 / scale); feat[i][j] *= (2.0 / scale);
} }
// log added eps=1e-14 // log added eps=1e-14
feat[i][j] = std::log(feat[i][j] + 1e-14); feat[i][j] = std::log(feat[i][j] + 1e-14);
}
} }
} return true;
return true;
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,27 +1,40 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h" #include "kaldi/feat/feature-window.h"
#include "base/common.h"
namespace ppspeech { namespace ppspeech {
struct LinearSpectrogramOptions { struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts; kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions(): LinearSpectrogramOptions() : frame_opts() {}
frame_opts() {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
frame_opts.Register(opts);
}
}; };
class LinearSpectrogram : public FeatureExtractorInterface { class LinearSpectrogram : public FeatureExtractorInterface {
public: public:
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts, explicit LinearSpectrogram(
std::unique_ptr<FeatureExtractorInterface> base_extractor); const LinearSpectrogramOptions& opts,
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats); void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
@ -41,7 +54,7 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_; std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_; kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_; LinearSpectrogramOptions opts_;
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat) kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
std::unique_ptr<FeatureExtractorInterface> base_extractor_; std::unique_ptr<FeatureExtractorInterface> base_extractor_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
}; };

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h" #include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
@ -12,169 +26,173 @@ using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) { DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
opts_ = opts; opts_ = opts;
dim_ = 0; dim_ = 0;
} }
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) { void DecibelNormalizer::AcceptWaveform(
dim_ = input.Dim(); const kaldi::VectorBase<BaseFloat>& input) {
waveform_.Resize(input.Dim()); dim_ = input.Dim();
waveform_.CopyFromVec(input); waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
} }
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) { void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
if (waveform_.Dim() == 0) return; if (waveform_.Dim() == 0) return;
Compute(waveform_, feat); Compute(waveform_, feat);
} }
//todo remove later // todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input, void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) { vector<BaseFloat>* output) {
if (input.Dim() == 0) return; if (input.Dim() == 0) return;
output->resize(input.Dim()); output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) { for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx); (*output)[idx] = input(idx);
} }
} }
void CopyStdVector2Vector(const vector<BaseFloat>& input, void CopyStdVector2Vector(const vector<BaseFloat>& input,
VectorBase<BaseFloat>* output) { VectorBase<BaseFloat>* output) {
if (input.empty()) return; if (input.empty()) return;
assert(input.size() == output->Dim()); assert(input.size() == output->Dim());
for (size_t idx = 0; idx < input.size(); ++idx) { for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx]; (*output)(idx) = input[idx];
} }
} }
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input, bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const { VectorBase<BaseFloat>* feat) const {
// calculate db rms // calculate db rms
BaseFloat rms_db = 0.0; BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0; BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0; BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1)); BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples; vector<BaseFloat> samples;
samples.resize(input.Dim()); samples.resize(input.Dim());
for (int32 i = 0; i < samples.size(); ++i) { for (int32 i = 0; i < samples.size(); ++i) {
samples[i] = input(i); samples[i] = input(i);
}
// square
for (auto &d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
} }
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation. // square
for (auto &item : samples) { for (auto& d : samples) {
// python item *= 10.0 ** (gain / 20.0) if (opts_.convert_int_float) {
item *= std::pow(10.0, gain / 20.0); d = d * wave_float_normlization;
} }
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR)
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
CopyStdVector2Vector(samples, feat); // Note that this is an in-place transformation.
return true; for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, feat);
return true;
} }
CMVN::CMVN(std::string cmvn_file) : var_norm_(true) { CMVN::CMVN(std::string cmvn_file) : var_norm_(true) {
bool binary; bool binary;
kaldi::Input ki(cmvn_file, &binary); kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary); stats_.Read(ki.Stream(), binary);
} }
void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) { void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
return; return;
} }
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { return; }
return;
}
// feats contain num_frames feature. // feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) { void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
KALDI_ASSERT(feats != NULL); KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1; int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || feats->Dim() % dim != 0) { if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
KALDI_ERR << "Dim mismatch: cmvn " feats->Dim() % dim != 0) {
<< stats_.NumRows() << 'x' << stats_.NumCols() KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
<< ", feats " << feats->Dim() << 'x'; << stats_.NumCols() << ", feats " << feats->Dim() << 'x';
}
if (stats_.NumRows() == 1 && var_norm) {
KALDI_ERR << "You requested variance normalization but no variance stats_ "
<< "are supplied.";
}
double count = stats_(0, dim);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance normalization: "
<< "count = " << count;
if (!var_norm) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
//fill the datat of mean_stats in mean_stats_appy whose dim is equal with the dim of feature.
//the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim*idx, dim);
stats_tmp.CopyFromVec(mean_stats);
} }
offset.AddVec(-1.0 / count, mean_stats_apply); if (stats_.NumRows() == 1 && var_norm) {
feats->AddVec(1.0, offset); KALDI_ERR
return; << "You requested variance normalization but no variance stats_ "
} << "are supplied.";
// norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) {
double mean, offset, scale;
mean = stats_(0, d)/count;
double var = (stats_(1, d)/count) - mean*mean,
floor = 1.0e-20;
if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
} }
scale = 1.0 / sqrt(var);
if (scale != scale || 1/scale == 0.0) double count = stats_(0, dim);
KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
offset = -(mean*scale); // computing an offset and representing it as stats_, we use a count of one.
for (int32 d_skip = d; d_skip < feats->Dim();) { if (count < 1.0)
norm(0, d_skip) = offset; KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
norm(1, d_skip) = scale; "normalization: "
d_skip = d_skip + dim; << "count = " << count;
if (!var_norm) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
// fill the datat of mean_stats in mean_stats_appy whose dim is equal
// with the dim of feature.
// the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
dim);
stats_tmp.CopyFromVec(mean_stats);
}
offset.AddVec(-1.0 / count, mean_stats_apply);
feats->AddVec(1.0, offset);
return;
}
// norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) {
double mean, offset, scale;
mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
}
scale = 1.0 / sqrt(var);
if (scale != scale || 1 / scale == 0.0)
KALDI_ERR
<< "NaN or infinity in cepstral mean/variance computation";
offset = -(mean * scale);
for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset;
norm(1, d_skip) = scale;
d_skip = d_skip + dim;
}
} }
} // Apply the normalization.
// Apply the normalization. feats->MulElements(norm.Row(1));
feats->MulElements(norm.Row(1)); feats->AddVec(1.0, norm.Row(0));
feats->AddVec(1.0, norm.Row(0));
} }
void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats) { void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm, feats); ApplyCmvn(stats_, var_norm, feats);
} }
bool CMVN::Compute(const VectorBase<BaseFloat>& input, bool CMVN::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const { VectorBase<BaseFloat>* feat) const {
return false; return false;
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,37 +1,55 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/util/options-itf.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech { namespace ppspeech {
struct DecibelNormalizerOptions { struct DecibelNormalizerOptions {
float target_db; float target_db;
float max_gain_db; float max_gain_db;
bool convert_int_float; bool convert_int_float;
DecibelNormalizerOptions() : DecibelNormalizerOptions()
target_db(-20), : target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
max_gain_db(300.0),
convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("target-db", &target_db, "target db for db normalization"); opts->Register(
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization"); "target-db", &target_db, "target db for db normalization");
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float"); opts->Register(
"max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float",
&convert_int_float,
"if convert int samples to float");
} }
}; };
class DecibelNormalizer : public FeatureExtractorInterface { class DecibelNormalizer : public FeatureExtractorInterface {
public: public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts); explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input, bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const; kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private: private:
DecibelNormalizerOptions opts_; DecibelNormalizerOptions opts_;
size_t dim_; size_t dim_;
@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface { class CMVN : public FeatureExtractorInterface {
public: public:
explicit CMVN(std::string cmvn_file); explicit CMVN(std::string cmvn_file);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return stats_.NumCols() - 1; } virtual size_t Dim() const { return stats_.NumCols() - 1; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input, bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface {
// for test // for test
void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats); void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats); void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);
private: private:
kaldi::Matrix<double> stats_; kaldi::Matrix<double> stats_;
std::shared_ptr<FeatureExtractorInterface> base_extractor_; std::shared_ptr<FeatureExtractorInterface> base_extractor_;

@ -13,4 +13,3 @@
// limitations under the License. // limitations under the License.
// extract the window of kaldi feat. // extract the window of kaldi feat.

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h // itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University; // Copyright 2009-2011 Microsoft Corporation; Saarland University;
@ -42,8 +56,10 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions. important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/, There are two ways these are used: the old online-decoding code, in
and the new online-decoding code, in ../online2/. In the old online-decoding ../online/,
and the new online-decoding code, in ../online2/. In the old
online-decoding
code, the decoder would do: code, the decoder would do:
\code{.cc} \code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
@ -52,13 +68,16 @@ namespace kaldi {
\endcode \endcode
and the call to IsLastFrame would block if the features had not arrived yet. and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for online-decoding mode is still supported, it is what happens when you call,
for
example, LatticeFasterDecoder::Decode(). example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example) control endpointing. In the "new" decoding code, you don't call (for
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), example)
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this: something like this:
@ -68,7 +87,8 @@ namespace kaldi {
} }
\endcode \endcode
So the decodable object never has IsLastFrame() called. For decoding where So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will you are starting with a matrix of features, the NumFramesReady() function
will
always just return the number of frames in the file, and IsLastFrame() will always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame. return true for the last frame.
@ -80,43 +100,52 @@ namespace kaldi {
frame of the file once we've decided to terminate decoding. frame of the file once we've decided to terminate decoding.
*/ */
class DecodableInterface { class DecodableInterface {
public: public:
/// Returns the log likelihood, which will be negated in the decoder. /// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame /// The "frame" starts from zero. You should verify that NumFramesReady() >
/// before calling this. /// frame
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; /// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file /// Returns true if this is the last frame. Frames are zero-based, so the
/// is empty (which is a case that I'm not sure all the code will handle, so /// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// be careful). Caution: the behavior of this function in an online setting /// is empty (which is a case that I'm not sure all the code will handle, so
/// is being changed somewhat. In future it may return false in cases where /// be careful). Caution: the behavior of this function in an online
/// we haven't yet decided to terminate decoding, but later true if we decide /// setting
/// to terminate decoding. The plan in future is to rely more on /// is being changed somewhat. In future it may return false in cases where
/// NumFramesReady(), and in future, IsLastFrame() would always return false /// we haven't yet decided to terminate decoding, but later true if we
/// in an online-decoding setting, and would only return true in a /// decide
/// decoding-from-matrix setting where we want to allow the last delta or LDA /// to terminate decoding. The plan in future is to rely more on
/// features to be flushed out for compatibility with the baseline setup. /// NumFramesReady(), and in future, IsLastFrame() would always return false
virtual bool IsLastFrame(int32 frame) const = 0; /// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or
/// The call NumFramesReady() will return the number of frames currently available /// LDA
/// for this decodable object. This is for use in setups where you don't want the /// features to be flushed out for compatibility with the baseline setup.
/// decoder to block while waiting for input. This is newly added as of Jan 2014, virtual bool IsLastFrame(int32 frame) const = 0;
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// know when to stop decoding. /// The call NumFramesReady() will return the number of frames currently
virtual int32 NumFramesReady() const { /// available
KALDI_ERR << "NumFramesReady() not implemented for this decodable type."; /// for this decodable object. This is for use in setups where you don't
return -1; /// want the
} /// decoder to block while waiting for input. This is newly added as of Jan
/// 2014,
/// Returns the number of states in the acoustic model /// and I hope, going forward, to rely on this mechanism more than
/// (they will be indexed one-based, i.e. from 1 to NumIndices(); /// IsLastFrame to
/// this is for compatibility with OpenFst). /// know when to stop decoding.
virtual int32 NumIndices() const = 0; virtual int32 NumFramesReady() const {
KALDI_ERR
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0; << "NumFramesReady() not implemented for this decodable type.";
return -1;
virtual ~DecodableInterface() {} }
/// Returns the number of states in the acoustic model
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;
virtual ~DecodableInterface() {}
}; };
/// @} /// @}
} // namespace Kaldi } // namespace Kaldi

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h" #include "nnet/decodable.h"
namespace ppspeech { namespace ppspeech {
@ -5,51 +19,43 @@ namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
// Builds a decodable on top of `nnet`. No frontend is attached, the input is
// not marked as finished and no frames are counted as ready yet.
// `frontend_` is a std::shared_ptr, so it is initialised with nullptr rather
// than the NULL macro, matching the nullptr usage elsewhere in the code base.
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
    : frontend_(nullptr), nnet_(nnet), finished_(false), frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows(); frames_ready_ += likelihood.NumRows();
} }
//Decodable::Init(DecodableConfig config) { // Decodable::Init(DecodableConfig config) {
//} //}
// True only once the input has been marked finished and `frame` is the index
// of the last counted frame. `frame` is checked to be <= frames_ready_.
bool Decodable::IsLastFrame(int32 frame) const {
    CHECK_LE(frame, frames_ready_);
    if (!finished_) {
        return false;
    }
    return frame == frames_ready_ - 1;
}
// Always reports 0.
// NOTE(review): looks like a stub - confirm whether a real index count is
// expected by callers.
int32 Decodable::NumIndices() const {
    return 0;
}
// Always returns 0; neither `frame` nor `index` is consulted.
// NOTE(review): looks like a stub - confirm before relying on the value.
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
    return 0;
}
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) { void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_); nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return ; return;
} }
// Copies row `frame` of the cached network output into a vector holding one
// value per column of nnet_cache_.
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
    const int32 num_cols = nnet_cache_.NumCols();
    // The vector is created with its final size. reserve() only allocates
    // capacity and leaves size() at 0, so indexed writes would be out of
    // bounds and the caller would receive an empty vector.
    std::vector<BaseFloat> result(num_cols);
    for (int32 idx = 0; idx < num_cols; ++idx) {
        result[idx] = nnet_cache_(frame, idx);
    }
    return result;
}
void Decodable::Reset() { void Decodable::Reset() {
// frontend_.Reset(); // frontend_.Reset();
nnet_->Reset(); nnet_->Reset();
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,7 +1,21 @@
#include "nnet/decodable-itf.h" // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h" #include "nnet/nnet_interface.h"
namespace ppspeech { namespace ppspeech {
@ -11,15 +25,18 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet); explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config); // void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const; virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const; virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame); virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later void Acceptlikelihood(
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
feature); // only for test, todo remove later
void Reset(); void Reset();
void InputFinished() { finished_ = true; } void InputFinished() { finished_ = true; }
private: private:
std::shared_ptr<FeatureExtractorInterface> frontend_; std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
@ -10,10 +24,9 @@ namespace ppspeech {
class NnetInterface { class NnetInterface {
public: public:
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences)= 0; kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
virtual void Reset() = 0; virtual void Reset() = 0;
virtual ~NnetInterface() {} virtual ~NnetInterface() {}
}; };
} // namespace ppspeech } // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h" #include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
@ -9,43 +23,44 @@ using std::shared_ptr;
using kaldi::Matrix; using kaldi::Matrix;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names; std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ","); cache_names = absl::StrSplit(opts.cache_names, ",");
std::vector<std::string> cache_shapes; std::vector<std::string> cache_shapes;
cache_shapes = absl::StrSplit(opts.cache_shape, ","); cache_shapes = absl::StrSplit(opts.cache_shape, ",");
assert(cache_shapes.size() == cache_names.size()); assert(cache_shapes.size() == cache_names.size());
cache_encouts_.clear(); cache_encouts_.clear();
cache_names_idx_.clear(); cache_names_idx_.clear();
for (size_t i = 0; i < cache_shapes.size(); i++) { for (size_t i = 0; i < cache_shapes.size(); i++) {
std::vector<std::string> tmp_shape; std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "-"); tmp_shape = absl::StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape; std::vector<int> cur_shape;
std::transform(tmp_shape.begin(), tmp_shape.end(), std::transform(tmp_shape.begin(),
std::back_inserter(cur_shape), tmp_shape.end(),
[](const std::string& s) { std::back_inserter(cur_shape),
return atoi(s.c_str()); [](const std::string& s) { return atoi(s.c_str()); });
}); cache_names_idx_[cache_names[i]] = i;
cache_names_idx_[cache_names[i]] = i; std::shared_ptr<Tensor<BaseFloat>> cache_eout =
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape); std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout); cache_encouts_.push_back(cache_eout);
} }
} }
PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) { PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
paddle_infer::Config config; paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path); config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) { if (opts.use_gpu) {
config.EnableUseGpu(500, 0); config.EnableUseGpu(500, 0);
} }
config.SwitchIrOptim(opts.switch_ir_optim); config.SwitchIrOptim(opts.switch_ir_optim);
if (opts.enable_fc_padding == false) { if (opts.enable_fc_padding == false) {
config.DisableFCPadding(); config.DisableFCPadding();
} }
if (opts.enable_profile) { if (opts.enable_profile) {
config.EnableProfile(); config.EnableProfile();
} }
pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num)); pool.reset(
new paddle_infer::services::PredictorPool(config, opts.thread_num));
if (pool == nullptr) { if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed"; LOG(ERROR) << "create the predictor pool failed";
} }
@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
std::vector<std::string> model_output_names = predictor->GetOutputNames(); std::vector<std::string> model_output_names = predictor->GetOutputNames();
assert(output_names_vec.size() == model_output_names.size()); assert(output_names_vec.size() == model_output_names.size());
for (size_t i = 0;i < output_names_vec.size(); i++) { for (size_t i = 0; i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]); assert(output_names_vec[i] == model_output_names[i]);
} }
ReleasePredictor(predictor); ReleasePredictor(predictor);
InitCacheEncouts(opts); InitCacheEncouts(opts);
} }
// Re-creates the encoder caches from the options stored at construction.
void PaddleNnet::Reset() {
    InitCacheEncouts(opts_);
}
paddle_infer::Predictor* PaddleNnet::GetPredictor() { paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl; LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
@ -122,80 +135,88 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
} }
// Looks up the cache tensor registered under `name`.
// Returns nullptr when the name is unknown.
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
    const auto pos = cache_names_idx_.find(name);
    if (pos != cache_names_idx_.end()) {
        assert(pos->second < cache_encouts_.size());
        return cache_encouts_[pos->second];
    }
    return nullptr;
}
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) { void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
paddle_infer::Predictor* predictor = GetPredictor(); Matrix<BaseFloat>* inferences) {
int row = features.NumRows(); paddle_infer::Predictor* predictor = GetPredictor();
int col = features.NumCols(); int row = features.NumRows();
std::vector<BaseFloat> feed_feature; int col = features.NumCols();
// todo refactor feed feature: SmileGoat std::vector<BaseFloat> feed_feature;
feed_feature.reserve(row*col); // todo refactor feed feature: SmileGoat
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { feed_feature.reserve(row * col);
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
feed_feature.push_back(features(row_idx, col_idx)); for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
}
} }
} std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> input_names = predictor->GetInputNames(); std::vector<std::string> output_names = predictor->GetOutputNames();
std::vector<std::string> output_names = predictor->GetOutputNames(); LOG(INFO) << "feat info: row=" << row << ", col= " << col;
LOG(INFO) << "feat info: row=" << row << ", col= " << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor =
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]); predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col}; std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE); input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data()); input_tensor->CopyFromCpu(feed_feature.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]); std::unique_ptr<paddle_infer::Tensor> input_len =
std::vector<int> input_len_size = {1}; predictor->GetInputHandle(input_names[1]);
input_len->Reshape(input_len_size); std::vector<int> input_len_size = {1};
std::vector<int64_t> audio_len; input_len->Reshape(input_len_size);
audio_len.push_back(row); std::vector<int64_t> audio_len;
input_len->CopyFromCpu(audio_len.data()); audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]); std::unique_ptr<paddle_infer::Tensor> h_box =
h_box->Reshape(h_cache->get_shape()); predictor->GetInputHandle(input_names[2]);
h_box->CopyFromCpu(h_cache->get_data().data()); shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]); h_box->Reshape(h_cache->get_shape());
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]); h_box->CopyFromCpu(h_cache->get_data().data());
c_box->Reshape(c_cache->get_shape()); std::unique_ptr<paddle_infer::Tensor> c_box =
c_box->CopyFromCpu(c_cache->get_data().data()); predictor->GetInputHandle(input_names[3]);
bool success = predictor->Run(); shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
if (success == false) { c_box->CopyFromCpu(c_cache->get_data().data());
LOG(INFO) << "predictor run occurs error"; bool success = predictor->Run();
}
if (success == false) {
LOG(INFO) << "get the model success"; LOG(INFO) << "predictor run occurs error";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]); }
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data()); LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]); std::unique_ptr<paddle_infer::Tensor> h_out =
assert(c_cache->get_shape() == c_out->shape()); predictor->GetOutputHandle(output_names[2]);
c_out->CopyToCpu(c_cache->get_data().data()); assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
// get result std::unique_ptr<paddle_infer::Tensor> c_out =
std::unique_ptr<paddle_infer::Tensor> output_tensor = predictor->GetOutputHandle(output_names[3]);
predictor->GetOutputHandle(output_names[0]); assert(c_cache->get_shape() == c_out->shape());
std::vector<int> output_shape = output_tensor->shape(); c_out->CopyToCpu(c_cache->get_data().data());
row = output_shape[1];
col = output_shape[2]; // get result
vector<float> inferences_result; std::unique_ptr<paddle_infer::Tensor> output_tensor =
inferences->Resize(row, col); predictor->GetOutputHandle(output_names[0]);
inferences_result.resize(row*col); std::vector<int> output_shape = output_tensor->shape();
output_tensor->CopyToCpu(inferences_result.data()); row = output_shape[1];
ReleasePredictor(predictor); col = output_shape[2];
vector<float> inferences_result;
for (int row_idx = 0; row_idx < row; ++row_idx) { inferences->Resize(row, col);
for (int col_idx = 0; col_idx < col; ++col_idx) { inferences_result.resize(row * col);
(*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx]; output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) =
inferences_result[col * row_idx + col_idx];
}
} }
}
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "nnet/nnet_interface.h"
#include "base/common.h" #include "base/common.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
@ -13,71 +27,79 @@
namespace ppspeech { namespace ppspeech {
struct ModelOptions { struct ModelOptions {
std::string model_path; std::string model_path;
std::string params_path; std::string params_path;
int thread_num; int thread_num;
bool use_gpu; bool use_gpu;
bool switch_ir_optim; bool switch_ir_optim;
std::string input_names; std::string input_names;
std::string output_names; std::string output_names;
std::string cache_names; std::string cache_names;
std::string cache_shape; std::string cache_shape;
bool enable_fc_padding; bool enable_fc_padding;
bool enable_profile; bool enable_profile;
ModelOptions() : ModelOptions()
model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"), : model_path(
params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"), "../../../../model/paddle_online_deepspeech/model/"
thread_num(2), "avg_1.jit.pdmodel"),
use_gpu(false), params_path(
input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"), "../../../../model/paddle_online_deepspeech/model/"
output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"), "avg_1.jit.pdiparams"),
cache_names("chunk_state_h_box,chunk_state_c_box"), thread_num(2),
cache_shape("3-1-1024,3-1-1024"), use_gpu(false),
switch_ir_optim(false), input_names(
enable_fc_padding(false), "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
enable_profile(false) { "box"),
} output_names(
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1"),
cache_names("chunk_state_h_box,chunk_state_c_box"),
cache_shape("3-1-1024,3-1-1024"),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path"); opts->Register("model-path", &model_path, "model file path");
opts->Register("model-params", &params_path, "params model file path"); opts->Register("model-params", &params_path, "params model file path");
opts->Register("thread-num", &thread_num, "thread num"); opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu"); opts->Register("use-gpu", &use_gpu, "if use gpu");
opts->Register("input-names", &input_names, "paddle input names"); opts->Register("input-names", &input_names, "paddle input names");
opts->Register("output-names", &output_names, "paddle output names"); opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names"); opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape"); opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option"); opts->Register("switch-ir-optiom",
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option"); &switch_ir_optim,
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option"); "paddle SwitchIrOptim option");
} opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
}
}; };
template<typename T> template <typename T>
class Tensor { class Tensor {
public: public:
Tensor() { Tensor() {}
} Tensor(const std::vector<int>& shape) : _shape(shape) {
Tensor(const std::vector<int>& shape) : int data_size = std::accumulate(
_shape(shape) { _shape.begin(), _shape.end(), 1, std::multiplies<int>());
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
LOG(INFO) << "data size: " << data_size; LOG(INFO) << "data size: " << data_size;
_data.resize(data_size, 0); _data.resize(data_size, 0);
} }
void reshape(const std::vector<int>& shape) { void reshape(const std::vector<int>& shape) {
_shape = shape; _shape = shape;
int data_size = std::accumulate(_shape.begin(), _shape.end(), int data_size = std::accumulate(
1, std::multiplies<int>()); _shape.begin(), _shape.end(), 1, std::multiplies<int>());
_data.resize(data_size, 0); _data.resize(data_size, 0);
} }
const std::vector<int>& get_shape() const { const std::vector<int>& get_shape() const { return _shape; }
return _shape; std::vector<T>& get_data() { return _data; }
}
std::vector<T>& get_data() { private:
return _data;
}
private:
std::vector<int> _shape; std::vector<int> _shape;
std::vector<T> _data; std::vector<T> _data;
}; };
@ -88,7 +110,8 @@ class PaddleNnet : public NnetInterface {
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences); kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset(); virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name); std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
void InitCacheEncouts(const ModelOptions& opts); void InitCacheEncouts(const ModelOptions& opts);
private: private:
@ -107,4 +130,4 @@ class PaddleNnet : public NnetInterface {
DISALLOW_COPY_AND_ASSIGN(PaddleNnet); DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
}; };
} // namespace ppspeech } // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "utils/file_utils.h" #include "utils/file_utils.h"
namespace ppspeech { namespace ppspeech {
@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename,
return true; return true;
} }
} }

@ -1,8 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
namespace ppspeech { namespace ppspeech {
bool ReadFileToVector(const std::string& filename, bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data); std::vector<std::string>* data);
} }

Loading…
Cancel
Save