make kaldi-native-fbank work

pull/2794/head
YangZhou 3 years ago
parent d1551b1ed3
commit 84f624f304

@ -19,21 +19,12 @@ aishell_wav_scp=aishell_test.scp
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_fbank_main \ compute_fbank_main \
--num_bins 80 \ --num_bins 80 \
--cmvn_file=$exp/cmvn.ark \ --cmvn_file=$model_dir/mean_std.json \
--streaming_chunk=36 \ --streaming_chunk=36 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp

@ -19,7 +19,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
u2_recognizer_main \ u2_recognizer_main \
--use_fbank=true \ --use_fbank=true \
--num_bins=80 \ --num_bins=80 \
--cmvn_file=$exp/cmvn.ark \ --cmvn_file=$model_dir/mean_std.json \
--model_path=$model_dir/export.jit \ --model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \ --vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \ --nnet_decoder_chunk=16 \

@ -19,7 +19,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.quant.log \
u2_recognizer_main \ u2_recognizer_main \
--use_fbank=true \ --use_fbank=true \
--num_bins=80 \ --num_bins=80 \
--cmvn_file=$exp/cmvn.ark \ --cmvn_file=$model_dir/mean_std.json \
--model_path=$model_dir/export \ --model_path=$model_dir/export \
--vocab_path=$model_dir/unit.txt \ --vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \ --nnet_decoder_chunk=16 \

@ -22,7 +22,6 @@ if [ ! -d ${SPEECHX_BUILD} ]; then
popd popd
fi fi
ckpt_dir=$data/model ckpt_dir=$data/model
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
@ -72,7 +71,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# process cmvn and compute fbank feat # process compute fbank feat
./local/feat.sh ./local/feat.sh
fi fi

@ -16,12 +16,6 @@ add_library(frontend STATIC
) )
target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils) target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils)
#set(bin_name cmvn_json2kaldi_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
set(BINS set(BINS
compute_fbank_main compute_fbank_main
) )

@ -1,98 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note: Do not print/log ondemand object.
#include "base/common.h"
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/picojson.h"
// Command-line flags consumed by main():
//   json_file        - path of the CMVN statistics JSON file to read.
//   cmvn_write_path  - destination path for the converted CMVN matrix.
//   binary           - whether the output is written in binary or text mode.
DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
auto ifs = std::ifstream(FLAGS_json_file);
std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
picojson::value value;
std::string err;
const char* json_end = picojson::parse(
value, json_str.c_str(), json_str.c_str() + json_str.size(), &err);
if (!value.is<picojson::object>()) {
LOG(ERROR) << "Input json file format error.";
}
const picojson::value::object& obj = value.get<picojson::object>();
for (picojson::value::object::const_iterator elem = obj.begin();
elem != obj.end();
++elem) {
if (elem->first == "mean_stat") {
VLOG(2) << "mean_stat:" << elem->second;
// const picojson::value tmp =
// elem->second.get(0);//<picojson::array>();
double tmp =
elem->second.get(0).get<double>(); //<picojson::array>();
VLOG(2) << "tmp: " << tmp;
}
if (elem->first == "var_stat") {
VLOG(2) << "var_stat: " << elem->second;
}
if (elem->first == "frame_num") {
VLOG(2) << "frame_num: " << elem->second;
}
}
const picojson::value::array& mean_stat =
value.get("mean_stat").get<picojson::array>();
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
mean_stat_vec.push_back((*it).get<double>());
}
const picojson::value::array& var_stat =
value.get("var_stat").get<picojson::array>();
std::vector<kaldi::BaseFloat> var_stat_vec;
for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
var_stat_vec.push_back((*it).get<double>());
}
kaldi::int32 frame_num = value.get("frame_num").get<int64_t>();
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
VLOG(2) << cmvn_stats;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
return 0;
}

@ -131,12 +131,14 @@ int main(int argc, char* argv[]) {
} }
// read feat // read feat
kaldi::Vector<BaseFloat> features; kaldi::Vector<BaseFloat> features(feature_cache.Dim());
bool flag = true; bool flag = true;
do { do {
std::vector<BaseFloat> tmp; std::vector<BaseFloat> tmp;
flag = feature_cache.Read(&tmp); flag = feature_cache.Read(&tmp);
std::memcpy(features.Data(), tmp.data(), tmp.size()*sizeof(BaseFloat)); std::memcpy(features.Data(),
tmp.data(),
tmp.size() * sizeof(BaseFloat));
if (flag && features.Dim() != 0) { if (flag && features.Dim() != 0) {
feats.push_back(features); feats.push_back(features);
feature_rows += features.Dim() / feature_cache.Dim(); feature_rows += features.Dim() / feature_cache.Dim();

@ -38,9 +38,12 @@ bool StreamingFeatureTpl<F>::Read(std::vector<kaldi::BaseFloat>* feats) {
int32 wav_len = wav.size(); int32 wav_len = wav.size();
int32 left_len = remained_wav_.size(); int32 left_len = remained_wav_.size();
std::vector<kaldi::BaseFloat> waves(left_len + wav_len); std::vector<kaldi::BaseFloat> waves(left_len + wav_len);
std::memcpy(waves.data(), remained_wav_.data(), left_len*sizeof(kaldi::BaseFloat)); std::memcpy(waves.data(),
remained_wav_.data(),
left_len * sizeof(kaldi::BaseFloat));
std::memcpy(waves.data() + left_len, std::memcpy(waves.data() + left_len,
wav.data(), wav_len*sizeof(kaldi::BaseFloat)); wav.data(),
wav_len * sizeof(kaldi::BaseFloat));
// compute speech feature // compute speech feature
Compute(waves, feats); Compute(waves, feats);
@ -52,18 +55,16 @@ bool StreamingFeatureTpl<F>::Read(std::vector<kaldi::BaseFloat>* feats) {
int32 left_samples = waves.size() - frame_shift * num_frames; int32 left_samples = waves.size() - frame_shift * num_frames;
remained_wav_.resize(left_samples); remained_wav_.resize(left_samples);
std::memcpy(remained_wav_.data(), std::memcpy(remained_wav_.data(),
waves.data() + frame_shift*num_frames, waves.data() + frame_shift * num_frames,
left_samples*sizeof(BaseFloat)); left_samples * sizeof(BaseFloat));
return true; return true;
} }
// Compute feat // Compute feat
template <class F> template <class F>
bool StreamingFeatureTpl<F>::Compute( bool StreamingFeatureTpl<F>::Compute(const std::vector<kaldi::BaseFloat>& waves,
const std::vector<kaldi::BaseFloat>& waves, std::vector<kaldi::BaseFloat>* feats) {
std::vector<kaldi::BaseFloat>* feats) { const knf::FrameExtractionOptions& frame_opts = computer_.GetFrameOptions();
const knf::FrameExtractionOptions& frame_opts =
computer_.GetFrameOptions();
int32 num_samples = waves.size(); int32 num_samples = waves.size();
int32 frame_length = frame_opts.WindowSize(); int32 frame_length = frame_opts.WindowSize();
int32 sample_rate = frame_opts.samp_freq; int32 sample_rate = frame_opts.samp_freq;
@ -77,18 +78,20 @@ bool StreamingFeatureTpl<F>::Compute(
std::vector<kaldi::BaseFloat> window; std::vector<kaldi::BaseFloat> window;
bool need_raw_log_energy = computer_.NeedRawLogEnergy(); bool need_raw_log_energy = computer_.NeedRawLogEnergy();
for (int32 frame = 0; frame < num_frames; frame++) { for (int32 frame = 0; frame < num_frames; frame++) {
std::fill(window.begin(), window.end(), 0);
kaldi::BaseFloat raw_log_energy = 0.0; kaldi::BaseFloat raw_log_energy = 0.0;
kaldi::BaseFloat vtln_warp = 1.0; kaldi::BaseFloat vtln_warp = 1.0;
knf::ExtractWindow(0, knf::ExtractWindow(0,
waves, waves,
frame, frame,
frame_opts, frame_opts,
window_function_, window_function_,
&window, &window,
need_raw_log_energy ? &raw_log_energy : NULL); need_raw_log_energy ? &raw_log_energy : NULL);
std::vector<kaldi::BaseFloat> this_feature(computer_.Dim()); std::vector<kaldi::BaseFloat> this_feature(computer_.Dim());
computer_.Compute(raw_log_energy, vtln_warp, &window, this_feature.data()); computer_.Compute(
raw_log_energy, vtln_warp, &window, this_feature.data());
std::memcpy(feats->data() + frame * Dim(), std::memcpy(feats->data() + frame * Dim(),
this_feature.data(), this_feature.data(),
sizeof(BaseFloat) * Dim()); sizeof(BaseFloat) * Dim());

Loading…
Cancel
Save