You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/runtime/engine/asr/nnet/u2_nnet_main.cc

195 lines
7.6 KiB

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
2 years ago
#include "nnet/u2_nnet.h"
#include "base/common.h"
2 years ago
#include "decoder/param.h"
#include "frontend/assembler.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
2 years ago
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int32 num_done = 0, num_err = 0;
2 years ago
CHECK_GT(FLAGS_feature_rspecifier.size(), 0);
CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
CHECK_GT(FLAGS_model_path.size(), 0);
LOG(INFO) << "input rspecifier: " << FLAGS_feature_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
2 years ago
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
2 years ago
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(
FLAGS_nnet_encoder_outs_wspecifier);
2 years ago
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
2 years ago
int32 chunk_size = (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
FLAGS_receptive_field_length;
2 years ago
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
int nframes = feature.NumRows();
int feat_dim = feature.NumCols();
raw_data->SetDim(feat_dim);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim;
int32 frame_idx = 0;
int vocab_dim = 0;
std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
std::vector<kaldi::Vector<kaldi::BaseFloat>> encoder_out_vec;
int32 ori_feature_len = feature.NumRows();
int32 num_chunks = feature.NumRows() / chunk_stride + 1;
LOG(INFO) << "num_chunks: " << num_chunks;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
int32 this_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
this_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (this_chunk_size < receptive_field_length) {
LOG(WARNING) << "utt: " << utt << " skip last "
<< this_chunk_size << " frames, expect is "
<< receptive_field_length;
break;
}
kaldi::Vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim);
feature_chunk_row.CopyFromVec(feat_row);
++start;
}
// feat to frontend pipeline cache
raw_data->Accept(feature_chunk);
// send data finish signal
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// get nnet outputs
kaldi::Timer timer;
kaldi::Vector<kaldi::BaseFloat> logprobs;
bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim);
CHECK(isok == true);
2 years ago
for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim;
row_idx++) {
kaldi::Vector<kaldi::BaseFloat> vec_tmp(vocab_dim);
2 years ago
std::memcpy(vec_tmp.Data(),
logprobs.Data() + row_idx * vocab_dim,
sizeof(kaldi::BaseFloat) * vocab_dim);
prob_vec.push_back(vec_tmp);
}
2 years ago
VLOG(2) << "frame_idx: " << frame_idx
<< " elapsed: " << timer.Elapsed() << " sec.";
}
// get encoder out
decodable->Nnet()->EncoderOuts(&encoder_out_vec);
// after process one utt, then reset decoder state.
decodable->Reset();
if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) {
// the TokenWriter can not write empty string.
++num_err;
2 years ago
LOG(WARNING) << " the nnet prob/encoder_out of " << utt
<< " is empty";
continue;
}
{
// writer nnet output
kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].Dim();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
nnet_out(row_idx, col_idx) = prob_vec[row_idx](col_idx);
}
}
nnet_out_writer.Write(utt, nnet_out);
}
{
// writer nnet encoder outs
kaldi::MatrixIndexT nrow = encoder_out_vec.size();
kaldi::MatrixIndexT ncol = encoder_out_vec[0].Dim();
LOG(INFO) << "nnet encoder outs shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> encoder_outs(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
2 years ago
encoder_outs(row_idx, col_idx) =
encoder_out_vec[row_idx](col_idx);
}
}
nnet_encoder_outs_writer.Write(utt, encoder_outs);
}
++num_done;
}
double elapsed = timer.Elapsed();
LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}