|
|
|
@ -14,11 +14,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "base/common.h"
|
|
|
|
|
#include "decoder/param.h"
|
|
|
|
|
#include "frontend/audio/assembler.h"
|
|
|
|
|
#include "frontend/audio/data_cache.h"
|
|
|
|
|
#include "kaldi/util/table-types.h"
|
|
|
|
|
#include "nnet/decodable.h"
|
|
|
|
|
#include "decoder/param.h"
|
|
|
|
|
#include "nnet/u2_nnet.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -46,14 +46,15 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
|
|
|
|
|
LOG(INFO) << "model path: " << FLAGS_model_path;
|
|
|
|
|
|
|
|
|
|
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_rspecifier);
|
|
|
|
|
kaldi::SequentialBaseFloatMatrixReader feature_reader(
|
|
|
|
|
FLAGS_feature_rspecifier);
|
|
|
|
|
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
|
|
|
|
|
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier);
|
|
|
|
|
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(
|
|
|
|
|
FLAGS_nnet_encoder_outs_wspecifier);
|
|
|
|
|
|
|
|
|
|
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
|
|
|
|
|
|
|
|
|
|
int32 chunk_size =
|
|
|
|
|
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
|
|
|
|
|
int32 chunk_size = (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
|
|
|
|
|
FLAGS_receptive_field_length;
|
|
|
|
|
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
|
|
|
|
|
int32 receptive_field_length = FLAGS_receptive_field_length;
|
|
|
|
@ -92,9 +93,9 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
|
|
|
|
|
}
|
|
|
|
|
if (this_chunk_size < receptive_field_length) {
|
|
|
|
|
LOG(WARNING) << "utt: " << utt << " skip last "
|
|
|
|
|
<< this_chunk_size << " frames, expect is "
|
|
|
|
|
<< receptive_field_length;
|
|
|
|
|
LOG(WARNING)
|
|
|
|
|
<< "utt: " << utt << " skip last " << this_chunk_size
|
|
|
|
|
<< " frames, expect is " << receptive_field_length;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -123,13 +124,17 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
kaldi::Vector<kaldi::BaseFloat> logprobs;
|
|
|
|
|
bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim);
|
|
|
|
|
CHECK(isok == true);
|
|
|
|
|
for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; row_idx ++) {
|
|
|
|
|
for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim;
|
|
|
|
|
row_idx++) {
|
|
|
|
|
kaldi::Vector<kaldi::BaseFloat> vec_tmp(vocab_dim);
|
|
|
|
|
std::memcpy(vec_tmp.Data(), logprobs.Data() + row_idx*vocab_dim, sizeof(kaldi::BaseFloat) * vocab_dim);
|
|
|
|
|
std::memcpy(vec_tmp.Data(),
|
|
|
|
|
logprobs.Data() + row_idx * vocab_dim,
|
|
|
|
|
sizeof(kaldi::BaseFloat) * vocab_dim);
|
|
|
|
|
prob_vec.push_back(vec_tmp);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "frame_idx: " << frame_idx << " elapsed: " << timer.Elapsed() << " sec.";
|
|
|
|
|
VLOG(2) << "frame_idx: " << frame_idx
|
|
|
|
|
<< " elapsed: " << timer.Elapsed() << " sec.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// get encoder out
|
|
|
|
@ -141,7 +146,8 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) {
|
|
|
|
|
// the TokenWriter can not write empty string.
|
|
|
|
|
++num_err;
|
|
|
|
|
LOG(WARNING) << " the nnet prob/encoder_out of " << utt << " is empty";
|
|
|
|
|
LOG(WARNING) << " the nnet prob/encoder_out of " << utt
|
|
|
|
|
<< " is empty";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -168,7 +174,8 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
kaldi::Matrix<kaldi::BaseFloat> encoder_outs(nrow, ncol);
|
|
|
|
|
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
|
|
|
|
|
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
|
|
|
|
|
encoder_outs(row_idx, col_idx) = encoder_out_vec[row_idx](col_idx);
|
|
|
|
|
encoder_outs(row_idx, col_idx) =
|
|
|
|
|
encoder_out_vec[row_idx](col_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
nnet_encoder_outs_writer.Write(utt, encoder_outs);
|
|
|
|
|