fix nnet model inputs and output name

pull/1740/head
Hui Zhang 3 years ago
parent a655fb69ef
commit caf7225892

@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate, DEFINE_int32(downsampling_rate,
4, 4,
"two CNN(kernel=5) module downsampling rate."); "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names, DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/" "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names"); "model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -76,6 +78,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_path; model_opts.model_path = model_path;
model_opts.params_path = model_params; model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names; model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names; model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));

@ -48,7 +48,6 @@ if [ ! -f $lm ]; then
popd popd
fi fi
feat_wspecifier=$exp_dir/feats.ark feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark cmvn=$exp_dir/cmvn.ark
@ -57,7 +56,7 @@ export GLOG_logtostderr=1
# dump json cmvn to kaldi # dump json cmvn to kaldi
cmvn-json2kaldi \ cmvn-json2kaldi \
--json_file $ckpt_dir/data/mean_std.json \ --json_file $ckpt_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \ --cmvn_write_path $cmvn \
--binary=false --binary=false
echo "convert json cmvn to kaldi ark." echo "convert json cmvn to kaldi ark."
@ -66,7 +65,7 @@ echo "convert json cmvn to kaldi ark."
linear-spectrogram-wo-db-norm-ol \ linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/wav.scp \ --wav_rspecifier=scp:$data/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \ --feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_file=$exp_dir/cmvn.ark --cmvn_file=$cmvn
echo "compute linear spectrogram feature." echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming # run ctc beam search decoder as streaming

@ -37,10 +37,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate, DEFINE_int32(downsampling_rate,
4, 4,
"two CNN(kernel=5) module downsampling rate."); "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names, DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/" "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names"); "model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -79,6 +81,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_graph; model_opts.model_path = model_graph;
model_opts.params_path = model_params; model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names; model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names; model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));

@ -9,4 +9,4 @@ target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags g
set(bin_name cmvn-json2kaldi) set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog) target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog ${DEPS})

@ -14,18 +14,20 @@
// Note: Do not print/log ondemand object. // Note: Do not print/log ondemand object.
#include "base/common.h"
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h" #include "utils/file_utils.h"
#include "utils/simdjson.h" // #include "boost/json.hpp"
#include <boost/json/src.hpp>
DEFINE_string(json_file, "", "cmvn json file"); DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)"); DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace simdjson; using namespace boost::json; // from <boost/json.hpp>
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -33,49 +35,51 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file; LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
try { auto ifs = std::ifstream(FLAGS_json_file);
padded_string json = padded_string::load(FLAGS_json_file); std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
auto value = boost::json::parse(json_str);
ondemand::parser parser; if (!value.is_object()) {
ondemand::document doc = parser.iterate(json); LOG(ERROR) << "Input json file format error.";
ondemand::value val = doc; }
ondemand::array mean_stat = val["mean_stat"]; for (auto obj : value.as_object()) {
std::vector<kaldi::BaseFloat> mean_stat_vec; if (obj.key() == "mean_stat") {
for (double x : mean_stat) { LOG(INFO) << "mean_stat:" << obj.value();
mean_stat_vec.push_back(x);
} }
// LOG(INFO) << mean_stat; this line will casue if (obj.key() == "var_stat") {
// simdjson::simdjson_error("Objects and arrays can only be iterated LOG(INFO) << "var_stat: " << obj.value();
// when
// they are first encountered")
ondemand::array var_stat = val["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) {
var_stat_vec.push_back(x);
} }
if (obj.key() == "frame_num") {
kaldi::int32 frame_num = uint64_t(val["frame_num"]); LOG(INFO) << "frame_num: " << obj.value();
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
} }
cmvn_stats(0, mean_size) = frame_num; }
LOG(INFO) << cmvn_stats;
boost::json::array mean_stat = value.at("mean_stat").as_array();
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
mean_stat_vec.push_back(it->as_double());
}
kaldi::WriteKaldiObject( boost::json::array var_stat = value.at("var_stat").as_array();
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); std::vector<kaldi::BaseFloat> var_stat_vec;
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
LOG(INFO) << "Binary: " << FLAGS_binary; var_stat_vec.push_back(it->as_double());
} catch (simdjson::simdjson_error& err) {
LOG(ERROR) << err.what();
} }
kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64());
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
return 0; return 0;
} }

Loading…
Cancel
Save