pull/1541/head
Hui Zhang 3 years ago
parent 11dc485d63
commit 41feecbd06

@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False

@ -36,4 +36,4 @@ def repeat(N, fn):
Returns:
MultiSequential: Repeated model instance.
"""
return MultiSequential(*[fn(n) for n in range(N)])
return MultiSequential(* [fn(n) for n in range(N)])

@ -3,4 +3,3 @@
* decoder - offline decoder
* feat - mfcc, linear
* nnet - ds2 nn

@ -1,11 +1,25 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "nnet/paddle_nnet.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
@ -13,45 +27,48 @@ using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
//void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::Decodable> decodable(new ppspeech::Decodable(nnet));
//int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet));
// int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

File diff suppressed because it is too large Load Diff

@ -1,13 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
@ -15,110 +29,228 @@ DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, -20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007};
std::vector<float> variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804};
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size()+ 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(
opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) =
(feat(row_idx, col_idx) - mean_[col_idx]) *
variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);

@ -1,12 +1,26 @@
#include "paddle_inference_api.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <iostream>
#include <thread>
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <algorithm>
#include <numeric>
#include <functional>
#include <thread>
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
@ -19,27 +33,28 @@ void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame
int col_size = 161; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
int chunk_size = 35; // chunk_size in frame
int col_size = 161; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;;
std::cout << "2. load the model" << std::endl;
;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
cout << "model path: " << model_graph << endl;
@ -48,14 +63,15 @@ void model_forward_test() {
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
cout << "SwitchIrOptim: " << false << endl;
cout << "SwitchIrOptim: " << false << endl;
config.DisableFCPadding();
cout << "DisableFCPadding: " << endl;
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl;
std::cout << "3. feat shape, row=" << feats.size()
<< ",col=" << feats[0].size() << std::endl;
std::vector<float> pp_input_mat;
for(const auto& item : feats) {
for (const auto& item : feats) {
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
}
@ -64,22 +80,23 @@ void model_forward_test() {
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
for (auto name : input_names){
cout << "model input names: " << name << endl;
for (auto name : input_names) {
cout << "model input names: " << name << endl;
}
for (auto name : output_names){
cout << "model output names: " << name << endl;
for (auto name : output_names) {
cout << "model output names: " << name << endl;
}
// input
std::unique_ptr<paddle_infer::Tensor> input_tensor =
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(pp_input_mat.data());
// input length
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
@ -87,20 +104,28 @@ void model_forward_test() {
input_len->CopyFromCpu(audio_len.data());
// state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]);
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(),
1, std::multiplies<int>());
int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(),
chunk_state_h_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
// state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]);
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(),
1, std::multiplies<int>());
int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(),
chunk_state_c_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
@ -108,18 +133,20 @@ void model_forward_test() {
bool success = predictor->Run();
// state_h out
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(),
1, std::multiplies<int>());
int h_out_size = std::accumulate(
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
// stage_c out
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(),
1, std::multiplies<int>());
int c_out_size = std::accumulate(
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
@ -128,8 +155,8 @@ void model_forward_test() {
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
int output_size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
@ -148,12 +175,14 @@ void model_forward_test() {
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl;
std::cout << "probs, row: " << log_feat.size()
<< " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}

@ -18,22 +18,22 @@
#include <limits>
typedef float BaseFloat;
typedef double double64;
typedef float BaseFloat;
typedef double double64;
typedef signed char int8;
typedef short int16;
typedef int int32;
typedef signed char int8;
typedef short int16;
typedef int int32;
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef long int64;
typedef long int64;
#else
typedef long long int64;
typedef long long int64;
#endif
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64;
@ -41,20 +41,20 @@ typedef unsigned long uint64;
typedef unsigned long long uint64;
#endif
typedef signed int char32;
const uint8 kuint8max = (( uint8) 0xFF);
const uint16 kuint16max = ((uint16) 0xFFFF);
const uint32 kuint32max = ((uint32) 0xFFFFFFFF);
const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = (( int8) 0x80);
const int8 kint8max = (( int8) 0x7F);
const int16 kint16min = (( int16) 0x8000);
const int16 kint16max = (( int16) 0x7FFF);
const int32 kint32min = (( int32) 0x80000000);
const int32 kint32max = (( int32) 0x7FFFFFFF);
const int64 kint64min = (( int64) (0x8000000000000000LL));
const int64 kint64max = (( int64) (0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
typedef signed int char32;
const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = ((int8)0x80);
const int8 kint8max = ((int8)0x7F);
const int16 kint16min = ((int16)0x8000);
const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();

@ -15,22 +15,22 @@
#pragma once
#include <deque>
#include <fstream>
#include <iostream>
#include <istream>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <mutex>
#include <vector>
#include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h"

@ -18,8 +18,8 @@ namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
#endif
} // namespace pp_speech

@ -23,98 +23,88 @@
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool {
public:
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
template <class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
std::vector<std::thread> workers;
// the task queue
std::queue< std::function<void()> > tasks;
std::queue<std::function<void()>> tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for(size_t i = 0;i<threads;++i)
workers.emplace_back(
[this]
{
for(;;)
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i)
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock, [this] {
return this->stop || !this->tasks.empty();
});
if (this->stop && this->tasks.empty()) return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
);
});
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
inline ThreadPool::~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
for (std::thread& worker : workers) worker.join();
}
#endif

@ -1,7 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h"
struct DecoderResult {
BaseFloat acoustic_score;
std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp;
BaseFloat acoustic_score;
std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp;
};

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
@ -9,292 +23,290 @@ namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
opts_(opts),
init_ext_scorer_(nullptr),
blank_id(-1),
space_id(-1),
num_frame_decoded_(0),
root(nullptr) {
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts),
init_ext_scorer_(nullptr),
blank_id(-1),
space_id(-1),
num_frame_decoded_(0),
root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";
}
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
LOG(INFO) << "read the vocabulary success, dict size: "
<< vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
opts_.beta,
opts_.lm_path,
vocabulary_);
init_ext_scorer_ = std::make_shared<Scorer>(
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}
void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0;
ResetPrefixes();
num_frame_decoded_ = 0;
ResetPrefixes();
}
void CTCBeamSearch::InitDecoder() {
blank_id = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id = it - vocabulary_.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary_.size()) {
space_id = -2;
}
}
ResetPrefixes();
root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
}
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
void CTCBeamSearch::Decode(
std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
}
int32 CTCBeamSearch::NumFrameDecoded() {
return num_frame_decoded_;
}
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(likelihood);
max_frames--;
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(
decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(likelihood);
max_frames--;
}
}
void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) {
delete prefixes[i];
prefixes[i] = nullptr;
for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) {
delete prefixes[i];
prefixes[i] = nullptr;
}
}
}
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
vector<string>& nbest_words) {
kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) {
kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs);
LOG(INFO) << "ctc decoding elapsed time(s) "
<< static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return result[0].second;
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return result[0].second;
}
string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore();
LMRescore();
return GetBestPath();
CalculateApproxScore();
LMRescore();
return GetBestPath();
}
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for(int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(),
vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
}
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur,
log_prob_c +
prefix->score);
continue;
}
vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur,
log_prob_c +
prefix->log_prob_nb_prev);
}
// loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
int32 CTCBeamSearch::SearchOneChar(
const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur,
log_p);
}
} // end of loop over prefix
return 0;
if (c == blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie* prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
return 0;
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -=
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
init_ext_scorer_->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
prefixes[i]->approx_ctc = approx_ctc;
}
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
}
}
}
}
} // namespace ppspeech
} // namespace ppspeech

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once
@ -17,63 +31,66 @@ struct CTCBeamSearchOptions {
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions() :
dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(0) {
}
CTCBeamSearchOptions()
: dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(0) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register("beam-size", &beam_size, "beam size for beam search method");
opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
class CTCBeamSearch {
public:
public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {}
void InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
//std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
// std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};
} // namespace basr
} // namespace basr

@ -22,15 +22,16 @@ namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface {
public:
explicit FbankExtractor(const FbankOptions& opts,
explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
kaldi::Vector<kaldi::BaseFloat>* feat) const;
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
kaldi::Vector<kaldi::BaseFloat>* feat) const;
};
} // namespace ppspeech

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -21,7 +21,8 @@ namespace ppspeech {
class FeatureExtractorInterface {
public:
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
};

@ -25,97 +25,97 @@ using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
//todo remove later
// todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector_(const vector<BaseFloat>& input,
Vector<BaseFloat>* output) {
if (input.empty()) return;
output->Resize(input.size());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
Vector<BaseFloat>* output) {
if (input.empty()) return;
output->Resize(input.size());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
opts_ = opts;
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size;
hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1);
hanning_window_energy_ = 0;
for (int i = 0; i < window_size; ++i) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
}
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
opts_ = opts;
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size;
hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1);
hanning_window_energy_ = 0;
for (int i = 0; i < window_size; ++i) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
}
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
}
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
base_extractor_->AcceptWaveform(input);
base_extractor_->AcceptWaveform(input);
}
void LinearSpectrogram::Hanning(vector<float>* data) const {
CHECK_GE(data->size(), hanning_window_.size());
CHECK_GE(data->size(), hanning_window_.size());
for (size_t i = 0; i < hanning_window_.size(); ++i) {
data->at(i) *= hanning_window_[i];
}
for (size_t i = 0; i < hanning_window_.size(); ++i) {
data->at(i) *= hanning_window_[i];
}
}
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
vector<BaseFloat>* real,
vector<BaseFloat>* img) const {
Vector<BaseFloat> v_tmp;
CopyStdVector2Vector_(*v, &v_tmp);
RealFft(&v_tmp, true);
CopyVector2StdVector_(v_tmp, v);
real->push_back(v->at(0));
img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) {
real->push_back(v->at(2 * i));
img->push_back(v->at(2 * i + 1));
}
real->push_back(v->at(1));
img->push_back(0);
return true;
Vector<BaseFloat> v_tmp;
CopyStdVector2Vector_(*v, &v_tmp);
RealFft(&v_tmp, true);
CopyVector2StdVector_(v_tmp, v);
real->push_back(v->at(0));
img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) {
real->push_back(v->at(2 * i));
img->push_back(v->at(2 * i + 1));
}
real->push_back(v->at(1));
img->push_back(0);
return true;
}
// todo remove later
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Vector<BaseFloat> tmp;
waveform_.Resize(base_extractor_->Dim());
Compute(tmp, &waveform_);
vector<vector<BaseFloat>> result;
vector<BaseFloat> feats_vec;
CopyVector2StdVector_(waveform_, &feats_vec);
Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx];
Vector<BaseFloat> tmp;
waveform_.Resize(base_extractor_->Dim());
Compute(tmp, &waveform_);
vector<vector<BaseFloat>> result;
vector<BaseFloat> feats_vec;
CopyVector2StdVector_(waveform_, &feats_vec);
Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx];
}
}
}
waveform_.Resize(0);
waveform_.Resize(0);
}
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
// todo
return;
// todo
return;
}
// only for test, remove later
@ -129,49 +129,49 @@ void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
// todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& wave,
vector<vector<float>>& feat) {
int num_samples = wave.size();
const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * sample_rate;
if (num_samples < frame_length) {
return true;
}
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feat.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) {
vector<float> data(wave.data() + i * frame_shift,
wave.data() + i * frame_shift + frame_length);
Hanning(&data);
fft_img.clear();
fft_real.clear();
v.assign(data.begin(), data.end());
NumpyFft(&v, &fft_real, &fft_img);
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feat[i][j] = power[j];
if (j == 0 || j == feat[0].size() - 1) {
feat[i][j] /= scale;
} else {
feat[i][j] *= (2.0 / scale);
}
// log added eps=1e-14
feat[i][j] = std::log(feat[i][j] + 1e-14);
int num_samples = wave.size();
const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * sample_rate;
if (num_samples < frame_length) {
return true;
}
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feat.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) {
vector<float> data(wave.data() + i * frame_shift,
wave.data() + i * frame_shift + frame_length);
Hanning(&data);
fft_img.clear();
fft_real.clear();
v.assign(data.begin(), data.end());
NumpyFft(&v, &fft_real, &fft_img);
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feat[i][j] = power[j];
if (j == 0 || j == feat[0].size() - 1) {
feat[i][j] /= scale;
} else {
feat[i][j] *= (2.0 / scale);
}
// log added eps=1e-14
feat[i][j] = std::log(feat[i][j] + 1e-14);
}
}
}
return true;
return true;
}
} // namespace ppspeech

@ -1,32 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h"
#include "base/common.h"
namespace ppspeech {
struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions():
frame_opts() {}
LinearSpectrogramOptions() : frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
frame_opts.Register(opts);
}
void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
};
class LinearSpectrogram : public FeatureExtractorInterface {
public:
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
private:
private:
void Hanning(std::vector<kaldi::BaseFloat>* data) const;
bool Compute(const std::vector<kaldi::BaseFloat>& wave,
std::vector<std::vector<kaldi::BaseFloat>>& feat);
@ -41,7 +54,7 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h"
@ -12,169 +26,173 @@ using std::vector;
using kaldi::SubVector;
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
opts_ = opts;
dim_ = 0;
opts_ = opts;
dim_ = 0;
}
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) {
dim_ = input.Dim();
waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
void DecibelNormalizer::AcceptWaveform(
const kaldi::VectorBase<BaseFloat>& input) {
dim_ = input.Dim();
waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
}
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
if (waveform_.Dim() == 0) return;
Compute(waveform_, feat);
if (waveform_.Dim() == 0) return;
Compute(waveform_, feat);
}
//todo remove later
// todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector(const vector<BaseFloat>& input,
VectorBase<BaseFloat>* output) {
if (input.empty()) return;
assert(input.size() == output->Dim());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
if (input.empty()) return;
assert(input.size() == output->Dim());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const {
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(input.Dim());
for (int32 i = 0; i < samples.size(); ++i) {
samples[i] = input(i);
}
// square
for (auto &d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(input.Dim());
for (int32 i = 0; i < samples.size(); ++i) {
samples[i] = input(i);
}
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto &item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, feat);
return true;
// square
for (auto& d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
}
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR)
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, feat);
return true;
}
CMVN::CMVN(std::string cmvn_file) : var_norm_(true) {
bool binary;
kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary);
bool binary;
kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary);
}
void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
return;
return;
}
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) {
return;
}
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { return; }
// feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || feats->Dim() % dim != 0) {
KALDI_ERR << "Dim mismatch: cmvn "
<< stats_.NumRows() << 'x' << stats_.NumCols()
<< ", feats " << feats->Dim() << 'x';
}
if (stats_.NumRows() == 1 && var_norm) {
KALDI_ERR << "You requested variance normalization but no variance stats_ "
<< "are supplied.";
}
double count = stats_(0, dim);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance normalization: "
<< "count = " << count;
if (!var_norm) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
//fill the datat of mean_stats in mean_stats_appy whose dim is equal with the dim of feature.
//the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim*idx, dim);
stats_tmp.CopyFromVec(mean_stats);
KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x';
}
offset.AddVec(-1.0 / count, mean_stats_apply);
feats->AddVec(1.0, offset);
return;
}
// norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) {
double mean, offset, scale;
mean = stats_(0, d)/count;
double var = (stats_(1, d)/count) - mean*mean,
floor = 1.0e-20;
if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
if (stats_.NumRows() == 1 && var_norm) {
KALDI_ERR
<< "You requested variance normalization but no variance stats_ "
<< "are supplied.";
}
double count = stats_(0, dim);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
"normalization: "
<< "count = " << count;
if (!var_norm) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
// fill the datat of mean_stats in mean_stats_appy whose dim is equal
// with the dim of feature.
// the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
dim);
stats_tmp.CopyFromVec(mean_stats);
}
offset.AddVec(-1.0 / count, mean_stats_apply);
feats->AddVec(1.0, offset);
return;
}
scale = 1.0 / sqrt(var);
if (scale != scale || 1/scale == 0.0)
KALDI_ERR << "NaN or infinity in cepstral mean/variance computation";
offset = -(mean*scale);
for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset;
norm(1, d_skip) = scale;
d_skip = d_skip + dim;
// norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) {
double mean, offset, scale;
mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
}
scale = 1.0 / sqrt(var);
if (scale != scale || 1 / scale == 0.0)
KALDI_ERR
<< "NaN or infinity in cepstral mean/variance computation";
offset = -(mean * scale);
for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset;
norm(1, d_skip) = scale;
d_skip = d_skip + dim;
}
}
}
// Apply the normalization.
feats->MulElements(norm.Row(1));
feats->AddVec(1.0, norm.Row(0));
// Apply the normalization.
feats->MulElements(norm.Row(1));
feats->AddVec(1.0, norm.Row(0));
}
void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm, feats);
ApplyCmvn(stats_, var_norm, feats);
}
bool CMVN::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const {
return false;
return false;
}
} // namespace ppspeech
} // namespace ppspeech

@ -1,37 +1,55 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/options-itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
struct DecibelNormalizerOptions {
float target_db;
float max_gain_db;
bool convert_int_float;
DecibelNormalizerOptions() :
target_db(-20),
max_gain_db(300.0),
convert_int_float(false) {}
float target_db;
float max_gain_db;
bool convert_int_float;
DecibelNormalizerOptions()
: target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("target-db", &target_db, "target db for db normalization");
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float");
opts->Register(
"target-db", &target_db, "target db for db normalization");
opts->Register(
"max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float",
&convert_int_float,
"if convert int samples to float");
}
};
class DecibelNormalizer : public FeatureExtractorInterface {
public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private:
DecibelNormalizerOptions opts_;
size_t dim_;
@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface {
public:
explicit CMVN(std::string cmvn_file);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return stats_.NumCols() - 1; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface {
// for test
void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);
private:
kaldi::Matrix<double> stats_;
std::shared_ptr<FeatureExtractorInterface> base_extractor_;

@ -13,4 +13,3 @@
// limitations under the License.
// extract the window of kaldi feat.

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
@ -42,8 +56,10 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
There are two ways these are used: the old online-decoding code, in
../online/,
and the new online-decoding code, in ../online2/. In the old
online-decoding
code, the decoder would do:
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
@ -52,13 +68,16 @@ namespace kaldi {
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for
online-decoding mode is still supported, it is what happens when you call,
for
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
control endpointing. In the "new" decoding code, you don't call (for
example)
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
@ -68,7 +87,8 @@ namespace kaldi {
}
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will
you are starting with a matrix of features, the NumFramesReady() function
will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
@ -80,43 +100,52 @@ namespace kaldi {
frame of the file once we've decided to terminate decoding.
*/
class DecodableInterface {
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online setting
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we decide
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently available
/// for this decodable object. This is for use in setups where you don't want the
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// know when to stop decoding.
virtual int32 NumFramesReady() const {
KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
return -1;
}
/// Returns the number of states in the acoustic model
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;
virtual ~DecodableInterface() {}
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online
/// setting
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we
/// decide
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or
/// LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently
/// available
/// for this decodable object. This is for use in setups where you don't
/// want the
/// decoder to block while waiting for input. This is newly added as of Jan
/// 2014,
/// and I hope, going forward, to rely on this mechanism more than
/// IsLastFrame to
/// know when to stop decoding.
virtual int32 NumFramesReady() const {
KALDI_ERR
<< "NumFramesReady() not implemented for this decodable type.";
return -1;
}
/// Returns the number of states in the acoustic model
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;
virtual ~DecodableInterface() {}
};
/// @}
} // namespace Kaldi

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h"
namespace ppspeech {
@ -5,51 +19,43 @@ namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet):
frontend_(NULL),
nnet_(nnet),
finished_(false),
frames_ready_(0) {
}
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
: frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
frames_ready_ += likelihood.NumRows();
}
//Decodable::Init(DecodableConfig config) {
// Decodable::Init(DecodableConfig config) {
//}
bool Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1);
CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1);
}
int32 Decodable::NumIndices() const {
return 0;
}
int32 Decodable::NumIndices() const { return 0; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return 0;
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows();
return ;
nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows();
return;
}
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
std::vector<BaseFloat> result;
result.reserve(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
result[idx] = nnet_cache_(frame, idx);
}
return result;
std::vector<BaseFloat> result;
result.reserve(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
result[idx] = nnet_cache_(frame, idx);
}
return result;
}
void Decodable::Reset() {
// frontend_.Reset();
nnet_->Reset();
// frontend_.Reset();
nnet_->Reset();
}
} // namespace ppspeech
} // namespace ppspeech

@ -1,7 +1,21 @@
#include "nnet/decodable-itf.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h"
namespace ppspeech {
@ -9,17 +23,20 @@ namespace ppspeech {
struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config);
// void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
void Acceptlikelihood(
const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
feature); // only for test, todo remove later
void Reset();
void InputFinished() { finished_ = true; }
private:
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
@ -10,10 +24,9 @@ namespace ppspeech {
class NnetInterface {
public:
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences)= 0;
kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
virtual void Reset() = 0;
virtual ~NnetInterface() {}
};
} // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h"
@ -9,43 +23,44 @@ using std::shared_ptr;
using kaldi::Matrix;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ",");
std::vector<std::string> cache_shapes;
cache_shapes = absl::StrSplit(opts.cache_shape, ",");
assert(cache_shapes.size() == cache_names.size());
cache_encouts_.clear();
cache_names_idx_.clear();
for (size_t i = 0; i < cache_shapes.size(); i++) {
std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape;
std::transform(tmp_shape.begin(), tmp_shape.end(),
std::back_inserter(cur_shape),
[](const std::string& s) {
return atoi(s.c_str());
});
cache_names_idx_[cache_names[i]] = i;
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout);
}
std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ",");
std::vector<std::string> cache_shapes;
cache_shapes = absl::StrSplit(opts.cache_shape, ",");
assert(cache_shapes.size() == cache_names.size());
cache_encouts_.clear();
cache_names_idx_.clear();
for (size_t i = 0; i < cache_shapes.size(); i++) {
std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape;
std::transform(tmp_shape.begin(),
tmp_shape.end(),
std::back_inserter(cur_shape),
[](const std::string& s) { return atoi(s.c_str()); });
cache_names_idx_[cache_names[i]] = i;
std::shared_ptr<Tensor<BaseFloat>> cache_eout =
std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout);
}
}
PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) {
config.EnableUseGpu(500, 0);
config.EnableUseGpu(500, 0);
}
config.SwitchIrOptim(opts.switch_ir_optim);
if (opts.enable_fc_padding == false) {
config.DisableFCPadding();
config.DisableFCPadding();
}
if (opts.enable_profile) {
config.EnableProfile();
config.EnableProfile();
}
pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num));
pool.reset(
new paddle_infer::services::PredictorPool(config, opts.thread_num));
if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed";
}
@ -59,7 +74,7 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
vector<string> input_names_vec = absl::StrSplit(opts.input_names, ",");
vector<string> output_names_vec = absl::StrSplit(opts.output_names, ",");
paddle_infer::Predictor* predictor = GetPredictor();
std::vector<std::string> model_input_names = predictor->GetInputNames();
assert(input_names_vec.size() == model_input_names.size());
for (size_t i = 0; i < model_input_names.size(); i++) {
@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
std::vector<std::string> model_output_names = predictor->GetOutputNames();
assert(output_names_vec.size() == model_output_names.size());
for (size_t i = 0;i < output_names_vec.size(); i++) {
for (size_t i = 0; i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]);
}
ReleasePredictor(predictor);
InitCacheEncouts(opts);
}
void PaddleNnet::Reset() {
InitCacheEncouts(opts_);
}
void PaddleNnet::Reset() { InitCacheEncouts(opts_); }
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
@ -122,80 +135,88 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
}
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
auto iter = cache_names_idx_.find(name);
if (iter == cache_names_idx_.end()) {
return nullptr;
}
assert(iter->second < cache_encouts_.size());
return cache_encouts_[iter->second];
auto iter = cache_names_idx_.find(name);
if (iter == cache_names_idx_.end()) {
return nullptr;
}
assert(iter->second < cache_encouts_.size());
return cache_encouts_[iter->second];
}
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {
paddle_infer::Predictor* predictor = GetPredictor();
int row = features.NumRows();
int col = features.NumCols();
std::vector<BaseFloat> feed_feature;
// todo refactor feed feature: SmileGoat
feed_feature.reserve(row*col);
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
Matrix<BaseFloat>* inferences) {
paddle_infer::Predictor* predictor = GetPredictor();
int row = features.NumRows();
int col = features.NumCols();
std::vector<BaseFloat> feed_feature;
// todo refactor feed feature: SmileGoat
feed_feature.reserve(row * col);
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
}
}
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data());
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box =
predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box =
predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data());
bool success = predictor->Run();
if (success == false) {
LOG(INFO) << "predictor run occurs error";
}
}
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data());
bool success = predictor->Run();
if (success == false) {
LOG(INFO) << "predictor run occurs error";
}
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
// get result
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
col = output_shape[2];
vector<float> inferences_result;
inferences->Resize(row, col);
inferences_result.resize(row*col);
output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx];
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
// get result
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
col = output_shape[2];
vector<float> inferences_result;
inferences->Resize(row, col);
inferences_result.resize(row * col);
output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) =
inferences_result[col * row_idx + col_idx];
}
}
}
}
} // namespace ppspeech
} // namespace ppspeech

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "nnet/nnet_interface.h"
#include "base/common.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h"
@ -13,71 +27,79 @@
namespace ppspeech {
struct ModelOptions {
std::string model_path;
std::string params_path;
int thread_num;
bool use_gpu;
bool switch_ir_optim;
std::string input_names;
std::string output_names;
std::string cache_names;
std::string cache_shape;
bool enable_fc_padding;
bool enable_profile;
ModelOptions() :
model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"),
params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"),
thread_num(2),
use_gpu(false),
input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"),
output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"),
cache_names("chunk_state_h_box,chunk_state_c_box"),
cache_shape("3-1-1024,3-1-1024"),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {
}
std::string model_path;
std::string params_path;
int thread_num;
bool use_gpu;
bool switch_ir_optim;
std::string input_names;
std::string output_names;
std::string cache_names;
std::string cache_shape;
bool enable_fc_padding;
bool enable_profile;
ModelOptions()
: model_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdmodel"),
params_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdiparams"),
thread_num(2),
use_gpu(false),
input_names(
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
"box"),
output_names(
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1"),
cache_names("chunk_state_h_box,chunk_state_c_box"),
cache_shape("3-1-1024,3-1-1024"),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
opts->Register("model-params", &params_path, "params model file path");
opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu");
opts->Register("input-names", &input_names, "paddle input names");
opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option");
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option");
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option");
}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
opts->Register("model-params", &params_path, "params model file path");
opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu");
opts->Register("input-names", &input_names, "paddle input names");
opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom",
&switch_ir_optim,
"paddle SwitchIrOptim option");
opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
}
};
template<typename T>
template <typename T>
class Tensor {
public:
Tensor() {
}
Tensor(const std::vector<int>& shape) :
_shape(shape) {
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
public:
Tensor() {}
Tensor(const std::vector<int>& shape) : _shape(shape) {
int data_size = std::accumulate(
_shape.begin(), _shape.end(), 1, std::multiplies<int>());
LOG(INFO) << "data size: " << data_size;
_data.resize(data_size, 0);
}
void reshape(const std::vector<int>& shape) {
_shape = shape;
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
int data_size = std::accumulate(
_shape.begin(), _shape.end(), 1, std::multiplies<int>());
_data.resize(data_size, 0);
}
const std::vector<int>& get_shape() const {
return _shape;
}
std::vector<T>& get_data() {
return _data;
}
private:
const std::vector<int>& get_shape() const { return _shape; }
std::vector<T>& get_data() { return _data; }
private:
std::vector<int> _shape;
std::vector<T> _data;
};
@ -85,15 +107,16 @@ private:
class PaddleNnet : public NnetInterface {
public:
PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
void InitCacheEncouts(const ModelOptions& opts);
private:
paddle_infer::Predictor* GetPredictor();
int ReleasePredictor(paddle_infer::Predictor* predictor);
int ReleasePredictor(paddle_infer::Predictor* predictor);
std::unique_ptr<paddle_infer::services::PredictorPool> pool;
std::vector<bool> pool_usages;
@ -107,4 +130,4 @@ class PaddleNnet : public NnetInterface {
DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
};
} // namespace ppspeech
} // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "utils/file_utils.h"
namespace ppspeech {
@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename,
return true;
}
}

@ -1,8 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
namespace ppspeech {
bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
}

Loading…
Cancel
Save