parent
d14ee80065
commit
e90438289d
@ -0,0 +1,58 @@
|
||||
// TODO: refactor, replace with gtest
|
||||
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "base/log.h"
|
||||
#include "base/flags.h"
|
||||
|
||||
// gflags string flag, empty by default. main() passes its value to the
// kaldi::SequentialBaseFloatMatrixReader constructor to open the feature table.
// NOTE(review): "respecifier" looks like a misspelling of Kaldi's "rspecifier";
// renaming it would change the command-line interface, so it is left as is.
DEFINE_string(feature_respecifier, "", "test nnet prob");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
|
||||
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
|
||||
int32 chunk_size,
|
||||
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
|
||||
|
||||
// test nnet_output --> decoder result
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
CTCBeamSearchOptions opts;
|
||||
CTCBeamSearch decoder(opts);
|
||||
|
||||
ModelOptions model_opts;
|
||||
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
|
||||
|
||||
Decodable decodable();
|
||||
decodable.SetNnet(nnet);
|
||||
|
||||
int32 chunk_size = 0;
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
vector<Matrix<BaseFloat>> feature_chunks;
|
||||
SplitFeature(feature, chunk_size, &feature_chunks);
|
||||
for (auto feature_chunk : feature_chunks) {
|
||||
decodable.FeedFeatures(feature_chunk);
|
||||
decoder.InitDecoder();
|
||||
decoder.AdvanceDecode(decodable, chunk_size);
|
||||
}
|
||||
decodable.InputFinished();
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
decodable.Reset();
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
||||
@ -0,0 +1,38 @@
|
||||
#include "nnet/decodable.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
Decodable::Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood) {
|
||||
frames_ready_ += likelihood.NumRows();
|
||||
}
|
||||
|
||||
Decodable::Init(DecodableConfig config) {
|
||||
|
||||
}
|
||||
|
||||
Decodable::IsLastFrame(int32 frame) const {
|
||||
CHECK_LE(frame, frames_ready_);
|
||||
return finished_ && (frame == frames_ready_ - 1);
|
||||
}
|
||||
|
||||
int32 Decodable::NumIndices() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Decodable::LogLikelihood(int32 frame, int32 index) {
|
||||
return ;
|
||||
}
|
||||
|
||||
void Decodable::FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& features) {
|
||||
// skip frame ???
|
||||
nnet_->FeedForward(features, &nnet_cache_);
|
||||
frames_ready_ += nnet_cache_.NumRows();
|
||||
return ;
|
||||
}
|
||||
|
||||
void Decodable::Reset() {
|
||||
// frontend_.Reset();
|
||||
nnet_->Reset();
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
||||
Loading…
Reference in new issue