diff --git a/speechx/examples/codelab/u2nnet/run.sh b/speechx/examples/codelab/u2nnet/run.sh
index b309bc6f2..704653e7f 100755
--- a/speechx/examples/codelab/u2nnet/run.sh
+++ b/speechx/examples/codelab/u2nnet/run.sh
@@ -40,6 +40,7 @@ cmvn_json2kaldi_main \
     --json_file $model_dir/mean_std.json \
     --cmvn_write_path $exp/cmvn.ark \
     --binary=false
 
+echo "convert json cmvn to kaldi ark."
 
 compute_fbank_main \
@@ -47,6 +48,7 @@ compute_fbank_main \
     --wav_rspecifier=scp:$data/wav.scp \
     --cmvn_file=$exp/cmvn.ark \
     --feature_wspecifier=ark,t:$exp/fbank.ark
 
+echo "compute fbank feature."
 
 u2_nnet_main \
@@ -56,4 +58,7 @@ u2_nnet_main \
     --receptive_field_length=7 \
     --downsampling_rate=4 \
     --acoustic_scale=1.0 \
-    --nnet_prob_wspecifier=ark,t:$exp/probs.ark
+    --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
+    --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
+
+echo "u2 nnet decode."
diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h
index 8786e4f20..39b38dc11 100644
--- a/speechx/speechx/nnet/decodable.h
+++ b/speechx/speechx/nnet/decodable.h
@@ -55,6 +55,8 @@ class Decodable : public kaldi::DecodableInterface {
     int32 TokenId2NnetId(int32 token_id);
 
+    std::shared_ptr<NnetInterface> Nnet() { return nnet_; }
+
   private:
     bool AdvanceChunk();
diff --git a/speechx/speechx/nnet/ds2_nnet.h b/speechx/speechx/nnet/ds2_nnet.h
index 717bdb721..80be69271 100644
--- a/speechx/speechx/nnet/ds2_nnet.h
+++ b/speechx/speechx/nnet/ds2_nnet.h
@@ -96,16 +96,22 @@ class PaddleNnet : public NnetInterface {
   public:
     PaddleNnet(const ModelOptions& opts);
 
-    virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
-                             const int32& feature_dim,
-                             NnetOut* out);
+    void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
+                     const int32& feature_dim,
+                     NnetOut* out) override;
 
     void Dim();
 
-    virtual void Reset();
+
+    void Reset() override;
+
     std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
         const std::string& name);
+
     void InitCacheEncouts(const ModelOptions& opts);
+    void EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out)
+        const override {}
+
   private:
     paddle_infer::Predictor* GetPredictor();
     int ReleasePredictor(paddle_infer::Predictor* predictor);
diff --git a/speechx/speechx/nnet/nnet_itf.h b/speechx/speechx/nnet/nnet_itf.h
index b98f5ebd0..5dde72a81 100644
--- a/speechx/speechx/nnet/nnet_itf.h
+++ b/speechx/speechx/nnet/nnet_itf.h
@@ -22,7 +22,7 @@ namespace ppspeech {
 
 struct NnetOut {
-    // nnet out, maybe logprob or prob
+    // nnet out, either logprob or prob; most of the time this is logprob.
     kaldi::Vector<kaldi::BaseFloat> logprobs;
     int32 vocab_dim;
 
@@ -35,11 +35,21 @@ struct NnetOut {
 
 class NnetInterface {
   public:
+    virtual ~NnetInterface() {}
+
+    // Forward feats through the nnet.
+    // The nnet does not cache feats; feats are cached by the frontend.
+    // The nnet caches model outputs, i.e. logprobs/encoder_outs.
     virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
                              const int32& feature_dim,
                              NnetOut* out) = 0;
+
+    // Reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
     virtual void Reset() = 0;
-    virtual ~NnetInterface() {}
+
+    // Used to get encoder outs, e.g. for seq2seq models with attention.
+    virtual void EncoderOuts(
+        std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
 };
 
 }  // namespace ppspeech
diff --git a/speechx/speechx/nnet/u2_nnet.cc b/speechx/speechx/nnet/u2_nnet.cc
index ddb815d20..74f8cf788 100644
--- a/speechx/speechx/nnet/u2_nnet.cc
+++ b/speechx/speechx/nnet/u2_nnet.cc
@@ -705,4 +705,30 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
     }
 }
 
+
+void U2Nnet::EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
+    // list of (B=1, T, D)
+    int size = encoder_outs_.size();
+    VLOG(1) << "encoder_outs_ size: " << size;
+
+    for (int i = 0; i < size; i++) {
+        const paddle::Tensor& item = encoder_outs_[i];
+        const std::vector<int64_t> shape = item.shape();
+        CHECK(shape.size() == 3);
+        const int& B = shape[0];
+        const int& T = shape[1];
+        const int& D = shape[2];
+        CHECK(B == 1) << "Only support batch one.";
+        VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << "," << D << ")";
+
+        const float* this_tensor_ptr = item.data<float>();
+        for (int j = 0; j < T; j++) {
+            const float* cur = this_tensor_ptr + j * D;
+            kaldi::Vector<kaldi::BaseFloat> out(D);
+            std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat));
+            encoder_out->emplace_back(out);
+        }
+    }
+}
+
 }  // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/nnet/u2_nnet.h
index 775a078a5..8ce45f43a 100644
--- a/speechx/speechx/nnet/u2_nnet.h
+++ b/speechx/speechx/nnet/u2_nnet.h
@@ -137,9 +137,8 @@ class U2Nnet : public U2NnetBase {
     // debug
     void FeedEncoderOuts(paddle::Tensor& encoder_out);
 
-    const std::vector<paddle::Tensor>& EncoderOuts() const {
-        return encoder_outs_;
-    }
+    void EncoderOuts(
+        std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const;
 
   private:
     U2ModelOptions opts_;
diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/speechx/speechx/nnet/u2_nnet_main.cc
index b602ac4db..fb9fec230 100644
--- a/speechx/speechx/nnet/u2_nnet_main.cc
+++ b/speechx/speechx/nnet/u2_nnet_main.cc
@@ -21,6 +21,7 @@
 
 DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
 DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
+DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier");
 DEFINE_string(model_path, "", "paddle nnet model");
 
@@ -52,9 +53,10 @@ int main(int argc, char* argv[]) {
     LOG(INFO) << "input rspecifier: " << FLAGS_feature_rspecifier;
     LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
     LOG(INFO) << "model path: " << FLAGS_model_path;
-    kaldi::SequentialBaseFloatMatrixReader feature_reader(
-        FLAGS_feature_rspecifier);
+
+    kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_rspecifier);
     kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
+    kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier);
 
     ppspeech::U2ModelOptions model_opts;
     model_opts.model_path = FLAGS_model_path;
@@ -97,6 +99,7 @@ int main(int argc, char* argv[]) {
 
         int32 frame_idx = 0;
         std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
+        std::vector<kaldi::Vector<kaldi::BaseFloat>> encoder_out_vec;
         int32 ori_feature_len = feature.NumRows();
         int32 num_chunks = feature.NumRows() / chunk_stride + 1;
         LOG(INFO) << "num_chunks: " << num_chunks;
@@ -144,29 +147,51 @@ int main(int argc, char* argv[]) {
                 prob_vec.push_back(vec_tmp);
                 frame_idx++;
             }
+
+
         }
+        // get encoder out
+        decodable->Nnet()->EncoderOuts(&encoder_out_vec);
+
         // after process one utt, then reset decoder state.
         decodable->Reset();
 
-        if (prob_vec.size() == 0) {
+        if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) {
             // the TokenWriter can not write empty string.
             ++num_err;
-            LOG(WARNING) << " the nnet prob of " << utt << " is empty";
+            LOG(WARNING) << " the nnet prob/encoder_out of " << utt << " is empty";
             continue;
         }
 
-        // writer nnet output
-        kaldi::MatrixIndexT nrow = prob_vec.size();
-        kaldi::MatrixIndexT ncol = prob_vec[0].Dim();
-        LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
-        kaldi::Matrix<kaldi::BaseFloat> result(nrow, ncol);
-        for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
-            for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
-                result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
+        {
+            // write nnet output
+            kaldi::MatrixIndexT nrow = prob_vec.size();
+            kaldi::MatrixIndexT ncol = prob_vec[0].Dim();
+            LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
+            kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
+            for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
+                for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
+                    nnet_out(row_idx, col_idx) = prob_vec[row_idx](col_idx);
+                }
+            }
+            nnet_out_writer.Write(utt, nnet_out);
+        }
+
+
+        {
+            // write nnet encoder outs
+            kaldi::MatrixIndexT nrow = encoder_out_vec.size();
+            kaldi::MatrixIndexT ncol = encoder_out_vec[0].Dim();
+            LOG(INFO) << "nnet encoder outs shape: " << nrow << ", " << ncol;
+            kaldi::Matrix<kaldi::BaseFloat> encoder_outs(nrow, ncol);
+            for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
+                for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
+                    encoder_outs(row_idx, col_idx) = encoder_out_vec[row_idx](col_idx);
+                }
             }
+            nnet_encoder_outs_writer.Write(utt, encoder_outs);
         }
-        nnet_out_writer.Write(utt, result);
 
         ++num_done;
     }
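
Note: with this patch, run.sh dumps two text-format Kaldi tables, $exp/logprobs.ark and $exp/encoder_outs.ark, one (num_frames x dim) matrix per utterance key. Below is a minimal read-back sketch, not part of the patch; the include paths ("base/common.h", "kaldi/util/table-types.h") and the assumption that $exp resolves to ./exp are mine, so adjust them to your setup.

    #include "base/common.h"
    #include "kaldi/util/table-types.h"

    int main(int argc, char* argv[]) {
        // rspecifiers assume $exp == ./exp; change to match run.sh.
        kaldi::SequentialBaseFloatMatrixReader logprob_reader(
            "ark:exp/logprobs.ark");
        kaldi::SequentialBaseFloatMatrixReader encoder_reader(
            "ark:exp/encoder_outs.ark");

        // Each utterance key maps to a (num_frames x dim) matrix.
        for (; !logprob_reader.Done(); logprob_reader.Next()) {
            const kaldi::Matrix<kaldi::BaseFloat>& m = logprob_reader.Value();
            LOG(INFO) << logprob_reader.Key() << " logprobs: " << m.NumRows()
                      << " x " << m.NumCols();
        }
        for (; !encoder_reader.Done(); encoder_reader.Next()) {
            const kaldi::Matrix<kaldi::BaseFloat>& m = encoder_reader.Value();
            LOG(INFO) << encoder_reader.Key() << " encoder_outs: "
                      << m.NumRows() << " x " << m.NumCols();
        }
        return 0;
    }

The encoder_outs rows come from U2Nnet::EncoderOuts (per-chunk encoder outputs concatenated over time), while the logprobs rows are the per-frame model log-probabilities that u2_nnet_main already wrote before this change.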