// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "base/common.h" #include "decoder/param.h" #include "frontend/assembler.h" #include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/u2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier"); using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); google::InstallFailureSignalHandler(); FLAGS_logtostderr = 1; int32 num_done = 0, num_err = 0; CHECK_GT(FLAGS_feature_rspecifier.size(), 0); CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0); CHECK_GT(FLAGS_model_path.size(), 0); LOG(INFO) << "input rspecifier: " << FLAGS_feature_rspecifier; LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; LOG(INFO) << "model path: " << FLAGS_model_path; kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer( FLAGS_nnet_encoder_outs_wspecifier); ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); int32 chunk_size = (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate + FLAGS_receptive_field_length; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "receptive field (frame): " << receptive_field_length; std::shared_ptr nnet(new ppspeech::U2Nnet(model_opts)); std::shared_ptr raw_data(new ppspeech::DataCache()); std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); kaldi::Timer timer; for (; !feature_reader.Done(); feature_reader.Next()) { string utt = feature_reader.Key(); kaldi::Matrix feature = feature_reader.Value(); int nframes = feature.NumRows(); int feat_dim = feature.NumCols(); raw_data->SetDim(feat_dim); LOG(INFO) << "utt: " << utt; LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim; int32 frame_idx = 0; int vocab_dim = 0; std::vector> prob_vec; std::vector> encoder_out_vec; int32 ori_feature_len = feature.NumRows(); int32 num_chunks = feature.NumRows() / chunk_stride + 1; LOG(INFO) << "num_chunks: " << num_chunks; for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { int32 this_chunk_size = 0; if (ori_feature_len > chunk_idx * chunk_stride) { this_chunk_size = std::min( ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { LOG(WARNING) << "utt: " << utt << " skip last " << this_chunk_size << " frames, expect is " << receptive_field_length; break; } kaldi::Vector feature_chunk(this_chunk_size * feat_dim); int32 start = chunk_idx * chunk_stride; for (int row_id = 0; row_id < this_chunk_size; ++row_id) { kaldi::SubVector feat_row(feature, start); kaldi::SubVector feature_chunk_row( feature_chunk.Data() + row_id * feat_dim, feat_dim); feature_chunk_row.CopyFromVec(feat_row); ++start; } // feat to frontend pipeline cache raw_data->Accept(feature_chunk); // send data finish signal if (chunk_idx == num_chunks - 1) { raw_data->SetFinished(); } // get nnet outputs kaldi::Timer timer; kaldi::Vector logprobs; bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim); CHECK(isok == true); for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; row_idx++) { kaldi::Vector vec_tmp(vocab_dim); std::memcpy(vec_tmp.Data(), logprobs.Data() + row_idx * vocab_dim, sizeof(kaldi::BaseFloat) * vocab_dim); prob_vec.push_back(vec_tmp); } VLOG(2) << "frame_idx: " << frame_idx << " elapsed: " << timer.Elapsed() << " sec."; } // get encoder out decodable->Nnet()->EncoderOuts(&encoder_out_vec); // after process one utt, then reset decoder state. decodable->Reset(); if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) { // the TokenWriter can not write empty string. ++num_err; LOG(WARNING) << " the nnet prob/encoder_out of " << utt << " is empty"; continue; } { // writer nnet output kaldi::MatrixIndexT nrow = prob_vec.size(); kaldi::MatrixIndexT ncol = prob_vec[0].Dim(); LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol; kaldi::Matrix nnet_out(nrow, ncol); for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { nnet_out(row_idx, col_idx) = prob_vec[row_idx](col_idx); } } nnet_out_writer.Write(utt, nnet_out); } { // writer nnet encoder outs kaldi::MatrixIndexT nrow = encoder_out_vec.size(); kaldi::MatrixIndexT ncol = encoder_out_vec[0].Dim(); LOG(INFO) << "nnet encoder outs shape: " << nrow << ", " << ncol; kaldi::Matrix encoder_outs(nrow, ncol); for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { encoder_outs(row_idx, col_idx) = encoder_out_vec[row_idx](col_idx); } } nnet_encoder_outs_writer.Write(utt, encoder_outs); } ++num_done; } double elapsed = timer.Elapsed(); LOG(INFO) << "Program cost:" << elapsed << " sec"; LOG(INFO) << "Done " << num_done << " utterances, " << num_err << " with errors."; return (num_done != 0 ? 0 : 1); }