u2 nnet get encoder out and align with py

pull/2524/head
Hui Zhang 3 years ago
parent a75abc1828
commit 5cc874e1c3

@@ -40,6 +40,7 @@ cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
compute_fbank_main \
@@ -47,6 +48,7 @@ compute_fbank_main \
--wav_rspecifier=scp:$data/wav.scp \
--cmvn_file=$exp/cmvn.ark \
--feature_wspecifier=ark,t:$exp/fbank.ark
echo "compute fbank feature."
u2_nnet_main \
@@ -56,4 +58,7 @@ u2_nnet_main \
--receptive_field_length=7 \
--downsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_prob_wspecifier=ark,t:$exp/probs.ark
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
echo "u2 nnet decode."

@@ -55,6 +55,8 @@ class Decodable : public kaldi::DecodableInterface {
int32 TokenId2NnetId(int32 token_id);
std::shared_ptr<NnetInterface> Nnet() { return nnet_; }
private:
bool AdvanceChunk();

@@ -96,16 +96,22 @@ class PaddleNnet : public NnetInterface {
public:
PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out);
void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) override;
void Dim();
virtual void Reset();
void Reset() override;
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
void InitCacheEncouts(const ModelOptions& opts);
void EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out)
const override {}
private:
paddle_infer::Predictor* GetPredictor();
int ReleasePredictor(paddle_infer::Predictor* predictor);

@@ -22,7 +22,7 @@
namespace ppspeech {
struct NnetOut {
// nnet out, maybe logprob or prob
// nnet out. maybe logprob or prob. Most of the time this is logprob.
kaldi::Vector<kaldi::BaseFloat> logprobs;
int32 vocab_dim;
@@ -35,11 +35,21 @@ struct NnetOut {
class NnetInterface {
public:
virtual ~NnetInterface() {}
// forward feats through the nnet.
// the nnet does not cache feats; feats are cached by the frontend.
// the nnet caches model outputs, i.e. logprobs/encoder_outs.
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) = 0;
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
virtual void Reset() = 0;
virtual ~NnetInterface() {}
// used to get encoder outs, e.g. for seq2seq models with attention.
virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};
} // namespace ppspeech
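To make the contract above concrete, here is a minimal sketch of an NnetInterface implementation. The class name, the encoder cache, and the header path are hypothetical, and the values it produces are meaningless; it only shows how FeedForward, Reset, and EncoderOuts are expected to interact:

#include <vector>

#include "nnet/nnet_itf.h"  // assumed header path for NnetInterface/NnetOut

namespace ppspeech {

class ToyNnet : public NnetInterface {
  public:
    void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
                     const int32& feature_dim,
                     NnetOut* out) override {
        int32 num_frames = features.Dim() / feature_dim;
        // fake logprobs: one value per frame, all zero
        out->logprobs.Resize(num_frames);
        out->vocab_dim = 1;
        // cache one "encoder" frame per input frame (hypothetical cache)
        for (int32 i = 0; i < num_frames; ++i) {
            encoder_outs_.emplace_back(feature_dim);
        }
    }

    // drop everything cached for the previous utterance
    void Reset() override { encoder_outs_.clear(); }

    // hand back the cached encoder frames, as u2_nnet_main expects
    void EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out)
        const override {
        *encoder_out = encoder_outs_;
    }

  private:
    std::vector<kaldi::Vector<kaldi::BaseFloat>> encoder_outs_;
};

}  // namespace ppspeech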

@@ -705,4 +705,30 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
}
void U2Nnet::EncoderOuts(
    std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
    // encoder_outs_ is a list of (B=1, T, D) tensors, one per forwarded chunk
    int size = encoder_outs_.size();
    VLOG(1) << "encoder_outs_ size: " << size;

    for (int i = 0; i < size; i++) {
        const paddle::Tensor& item = encoder_outs_[i];
        const std::vector<int64_t> shape = item.shape();
        CHECK(shape.size() == 3);
        const int& B = shape[0];
        const int& T = shape[1];
        const int& D = shape[2];
        CHECK(B == 1) << "Only support batch one.";
        VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << ","
                << D << ")";

        // copy each time step (a row of D floats) into its own kaldi vector
        const float* this_tensor_ptr = item.data<float>();
        for (int j = 0; j < T; j++) {
            const float* cur = this_tensor_ptr + j * D;
            kaldi::Vector<kaldi::BaseFloat> out(D);
            std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat));
            encoder_out->emplace_back(out);
        }
    }
}
} // namespace ppspeech
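For reference, the copy loop above is just row-major slicing: a (1, T, D) buffer is split into T rows of D floats, with row j starting at offset j * D. The same idea on a plain float buffer, independent of paddle::Tensor (names are illustrative):

#include <vector>

// split a row-major (1, T, D) buffer into T rows of D floats
std::vector<std::vector<float>> SplitRows(const float* buf, int T, int D) {
    std::vector<std::vector<float>> rows;
    rows.reserve(T);
    for (int j = 0; j < T; ++j) {
        const float* cur = buf + j * D;   // row j starts at offset j * D
        rows.emplace_back(cur, cur + D);  // copy D floats
    }
    return rows;
}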

@@ -137,9 +137,8 @@ class U2Nnet : public U2NnetBase {
// debug
void FeedEncoderOuts(paddle::Tensor& encoder_out);
const std::vector<paddle::Tensor>& EncoderOuts() const {
return encoder_outs_;
}
void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const;
private:
U2ModelOptions opts_;

@@ -21,6 +21,7 @@
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier");
DEFINE_string(model_path, "", "paddle nnet model");
@@ -52,9 +53,10 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "input rspecifier: " << FLAGS_feature_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier);
ppspeech::U2ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
@@ -97,6 +99,7 @@ int main(int argc, char* argv[]) {
int32 frame_idx = 0;
std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
std::vector<kaldi::Vector<kaldi::BaseFloat>> encoder_out_vec;
int32 ori_feature_len = feature.NumRows();
int32 num_chunks = feature.NumRows() / chunk_stride + 1;
LOG(INFO) << "num_chunks: " << num_chunks;
@@ -144,29 +147,51 @@ int main(int argc, char* argv[]) {
prob_vec.push_back(vec_tmp);
frame_idx++;
}
}
// get encoder out
decodable->Nnet()->EncoderOuts(&encoder_out_vec);
// after processing one utt, reset the decoder state.
decodable->Reset();
if (prob_vec.size() == 0) {
if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) {
// the TokenWriter can not write an empty string.
++num_err;
LOG(WARNING) << " the nnet prob of " << utt << " is empty";
LOG(WARNING) << " the nnet prob/encoder_out of " << utt << " is empty";
continue;
}
// writer nnet output
kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].Dim();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> result(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
{
    // write nnet output, i.e. logprobs
    kaldi::MatrixIndexT nrow = prob_vec.size();
    kaldi::MatrixIndexT ncol = prob_vec[0].Dim();
    LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
    kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
    for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
        for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
            nnet_out(row_idx, col_idx) = prob_vec[row_idx](col_idx);
        }
    }
    nnet_out_writer.Write(utt, nnet_out);
}

{
    // write nnet encoder outs
    kaldi::MatrixIndexT nrow = encoder_out_vec.size();
    kaldi::MatrixIndexT ncol = encoder_out_vec[0].Dim();
    LOG(INFO) << "nnet encoder outs shape: " << nrow << ", " << ncol;
    kaldi::Matrix<kaldi::BaseFloat> encoder_outs(nrow, ncol);
    for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
        for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
            encoder_outs(row_idx, col_idx) = encoder_out_vec[row_idx](col_idx);
        }
    }
    nnet_encoder_outs_writer.Write(utt, encoder_outs);
}
nnet_out_writer.Write(utt, result);
++num_done;
}
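The two blocks above duplicate the same rows-to-matrix copy. A possible helper that factors it out; the function name and header path are hypothetical, not part of this change:

#include <vector>

#include "kaldi/matrix/kaldi-matrix.h"  // assumed header path

// copy a list of equally sized row vectors into a kaldi::Matrix
static void VectorsToMatrix(
    const std::vector<kaldi::Vector<kaldi::BaseFloat>>& rows,
    kaldi::Matrix<kaldi::BaseFloat>* mat) {
    kaldi::MatrixIndexT nrow = rows.size();
    kaldi::MatrixIndexT ncol = rows[0].Dim();
    mat->Resize(nrow, ncol);
    for (kaldi::MatrixIndexT r = 0; r < nrow; ++r) {
        for (kaldi::MatrixIndexT c = 0; c < ncol; ++c) {
            (*mat)(r, c) = rows[r](c);
        }
    }
}

It could then be called once for prob_vec and once for encoder_out_vec before the two Write() calls.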
