@@ -101,10 +101,10 @@ void U2Nnet::Warmup() {
         auto encoder_out = paddle::ones(
             {1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace());
 
-        std::vector<paddle::experimental::Tensor> inputs{
+        std::vector<paddle::Tensor> inputs{
             hyps, hyps_lens, encoder_out};
 
-        std::vector<paddle::experimental::Tensor> outputs =
+        std::vector<paddle::Tensor> outputs =
             forward_attention_decoder_(inputs);
     }
@@ -523,7 +523,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
     }
 #endif // end TEST_DEBUG
 
-    std::vector<paddle::experimental::Tensor> inputs{
+    std::vector<paddle::Tensor> inputs{
         hyps_tensor, hyps_lens, encoder_out};
     std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs);
     CHECK_EQ(outputs.size(), 2);
@@ -594,7 +594,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
     } else {
         // dump r_probs
         CHECK_EQ(r_probs_shape.size(), 1);
-        CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
+        //CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
     }
 
     // compute rescoring score
@@ -604,15 +604,15 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
     VLOG(2) << "split prob: " << probs_v.size() << " "
             << probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0]
             << ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2];
-    CHECK(static_cast<int>(probs_v.size()) == num_hyps)
-        << ": is " << probs_v.size() << " expect: " << num_hyps;
+    //CHECK(static_cast<int>(probs_v.size()) == num_hyps)
+    // << ": is " << probs_v.size() << " expect: " << num_hyps;
 
     std::vector<paddle::Tensor> r_probs_v;
     if (is_bidecoder_ && reverse_weight > 0) {
         r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0);
-        CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
-            << "r_probs_v size: is " << r_probs_v.size()
-            << " expect: " << num_hyps;
+        //CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
+        // << "r_probs_v size: is " << r_probs_v.size()
+        // << " expect: " << num_hyps;
     }
 
     for (int i = 0; i < num_hyps; ++i) {
@@ -654,7 +654,7 @@ void U2Nnet::EncoderOuts(
         const int& B = shape[0];
         const int& T = shape[1];
         const int& D = shape[2];
-        CHECK(B == 1) << "Only support batch one.";
+        //CHECK(B == 1) << "Only support batch one.";
         VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << ","
                 << D << ")";
@@ -668,4 +668,4 @@ void U2Nnet::EncoderOuts(
     }
 }
 
-} // namespace ppspeech
+} // namespace ppspeech