|
|
@ -25,65 +25,6 @@ using paddle::platform::TracerEventType;
|
|
|
|
|
|
|
|
|
|
|
|
namespace ppspeech {
|
|
|
|
namespace ppspeech {
|
|
|
|
|
|
|
|
|
|
|
|
int U2NnetBase::num_frames_for_chunk(bool start) const {
|
|
|
|
|
|
|
|
int num_needed_frames = 0; // num feat frames
|
|
|
|
|
|
|
|
bool first = !start; // start == false is first
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (chunk_size_ > 0) {
|
|
|
|
|
|
|
|
// streaming mode
|
|
|
|
|
|
|
|
if (first) {
|
|
|
|
|
|
|
|
// first chunk
|
|
|
|
|
|
|
|
// 1 decoder frame need `context` feat frames
|
|
|
|
|
|
|
|
int context = this->context();
|
|
|
|
|
|
|
|
num_needed_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// after first chunk, we need stride this num frames.
|
|
|
|
|
|
|
|
num_needed_frames = chunk_size_ * subsampling_rate_;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// non-streaming mode. feed all feats once.
|
|
|
|
|
|
|
|
num_needed_frames = std::numeric_limits<int>::max();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return num_needed_frames;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// cache feats for next chunk
|
|
|
|
|
|
|
|
void U2NnetBase::CacheFeature(const std::vector<kaldi::BaseFloat>& chunk_feats,
|
|
|
|
|
|
|
|
int32 feat_dim) {
|
|
|
|
|
|
|
|
// chunk_feats is nframes*feat_dim
|
|
|
|
|
|
|
|
const int chunk_size = chunk_feats.size() / feat_dim;
|
|
|
|
|
|
|
|
const int cached_feat_size = this->context() - subsampling_rate_;
|
|
|
|
|
|
|
|
if (chunk_size >= cached_feat_size) {
|
|
|
|
|
|
|
|
cached_feats_.resize(cached_feat_size);
|
|
|
|
|
|
|
|
for (int i = 0; i < cached_feat_size; ++i) {
|
|
|
|
|
|
|
|
auto start =
|
|
|
|
|
|
|
|
chunk_feats.begin() + chunk_size - cached_feat_size + i;
|
|
|
|
|
|
|
|
auto end = start + feat_dim;
|
|
|
|
|
|
|
|
cached_feats_[i] = std::vector<float>(start, end);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void U2NnetBase::ForwardEncoderChunk(
|
|
|
|
|
|
|
|
const std::vector<kaldi::BaseFloat>& chunk_feats,
|
|
|
|
|
|
|
|
const int32& feat_dim,
|
|
|
|
|
|
|
|
std::vector<kaldi::BaseFloat>* ctc_probs,
|
|
|
|
|
|
|
|
int32* vocab_dim) {
|
|
|
|
|
|
|
|
ctc_probs->clear();
|
|
|
|
|
|
|
|
// int num_frames = cached_feats_.size() + chunk_feats.size();
|
|
|
|
|
|
|
|
int num_frames = chunk_feats.size() / feat_dim;
|
|
|
|
|
|
|
|
VLOG(3) << "foward encoder chunk: " << num_frames << " frames";
|
|
|
|
|
|
|
|
VLOG(3) << "context: " << this->context() << " frames";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (num_frames >= this->context()) {
|
|
|
|
|
|
|
|
this->ForwardEncoderChunkImpl(
|
|
|
|
|
|
|
|
chunk_feats, feat_dim, ctc_probs, vocab_dim);
|
|
|
|
|
|
|
|
VLOG(3) << "after forward chunk";
|
|
|
|
|
|
|
|
this->CacheFeature(chunk_feats, feat_dim);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
|
|
|
|
void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
|
|
|
|
paddle::jit::utils::InitKernelSignatureMap();
|
|
|
|
paddle::jit::utils::InitKernelSignatureMap();
|
|
|
@ -188,7 +129,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
|
|
|
|
forward_attention_decoder_ = other.forward_attention_decoder_;
|
|
|
|
forward_attention_decoder_ = other.forward_attention_decoder_;
|
|
|
|
ctc_activation_ = other.ctc_activation_;
|
|
|
|
ctc_activation_ = other.ctc_activation_;
|
|
|
|
|
|
|
|
|
|
|
|
// offset_ = other.offset_; // TODO: not used in nnets
|
|
|
|
offset_ = other.offset_;
|
|
|
|
|
|
|
|
|
|
|
|
// copy model ptr
|
|
|
|
// copy model ptr
|
|
|
|
model_ = other.model_;
|
|
|
|
model_ = other.model_;
|
|
|
@ -204,8 +145,7 @@ std::shared_ptr<NnetBase> U2Nnet::Copy() const {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void U2Nnet::Reset() {
|
|
|
|
void U2Nnet::Reset() {
|
|
|
|
// offset_ = 0;
|
|
|
|
offset_ = 0;
|
|
|
|
// cached_feats_.clear(); // TODO: not used in nnets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
att_cache_ =
|
|
|
|
att_cache_ =
|
|
|
|
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
|
|
|
|
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
|
|
|
@ -263,16 +203,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
|
|
|
|
paddle::zeros({1, num_frames, feat_dim}, paddle::DataType::FLOAT32);
|
|
|
|
paddle::zeros({1, num_frames, feat_dim}, paddle::DataType::FLOAT32);
|
|
|
|
float* feats_ptr = feats.mutable_data<float>();
|
|
|
|
float* feats_ptr = feats.mutable_data<float>();
|
|
|
|
|
|
|
|
|
|
|
|
// for (size_t i = 0; i < cached_feats_.size(); ++i) {
|
|
|
|
|
|
|
|
// float* row = feats_ptr + i * feat_dim;
|
|
|
|
|
|
|
|
// std::memcpy(row, cached_feats_[i].data(), feat_dim * sizeof(float));
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// for (size_t i = 0; i < chunk_feats.size(); ++i) {
|
|
|
|
|
|
|
|
// float* row = feats_ptr + (cached_feats_.size() + i) * feat_dim;
|
|
|
|
|
|
|
|
// std::memcpy(row, chunk_feats[i].data(), feat_dim * sizeof(float));
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// not cache feature in nnet
|
|
|
|
// not cache feature in nnet
|
|
|
|
CHECK(cached_feats_.size() == 0);
|
|
|
|
CHECK(cached_feats_.size() == 0);
|
|
|
|
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
|
|
|
|
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
|
|
|
@ -427,15 +357,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
|
|
|
|
|
|
|
|
|
|
|
|
float* ctc_log_probs_ptr = ctc_log_probs.data<float>();
|
|
|
|
float* ctc_log_probs_ptr = ctc_log_probs.data<float>();
|
|
|
|
|
|
|
|
|
|
|
|
// // vector<vector<float>>
|
|
|
|
|
|
|
|
// out_prob->resize(T);
|
|
|
|
|
|
|
|
// for (int i = 0; i < T; i++) {
|
|
|
|
|
|
|
|
// (*out_prob)[i].resize(D);
|
|
|
|
|
|
|
|
// float* dst_ptr = (*out_prob)[i].data();
|
|
|
|
|
|
|
|
// float* src_ptr = ctc_log_probs_ptr + (i * D);
|
|
|
|
|
|
|
|
// std::memcpy(dst_ptr, src_ptr, D * sizeof(float));
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// CHECK(std::is_same<float, kaldi::BaseFloat>::value);
|
|
|
|
|
|
|
|
out_prob->resize(T * D);
|
|
|
|
out_prob->resize(T * D);
|
|
|
|
std::memcpy(
|
|
|
|
std::memcpy(
|
|
|
|
out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat));
|
|
|
|
out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat));
|
|
|
|