parent
d14ee80065
commit
e90438289d
@ -0,0 +1,58 @@
|
|||||||
|
// todo refactor, repalce with gtest
|
||||||
|
|
||||||
|
#include "decoder/ctc_beam_search_decoder.h"
|
||||||
|
#include "kaldi/util/table-types.h"
|
||||||
|
#include "base/log.h"
|
||||||
|
#include "base/flags.h"
|
||||||
|
|
||||||
|
DEFINE_string(feature_respecifier, "", "test nnet prob");
|
||||||
|
|
||||||
|
using kaldi::BaseFloat;
|
||||||
|
|
||||||
|
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
|
||||||
|
int32 chunk_size,
|
||||||
|
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
|
||||||
|
|
||||||
|
// test nnet_output --> decoder result
|
||||||
|
int32 num_done = 0, num_err = 0;
|
||||||
|
|
||||||
|
CTCBeamSearchOptions opts;
|
||||||
|
CTCBeamSearch decoder(opts);
|
||||||
|
|
||||||
|
ModelOptions model_opts;
|
||||||
|
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
|
||||||
|
|
||||||
|
Decodable decodable();
|
||||||
|
decodable.SetNnet(nnet);
|
||||||
|
|
||||||
|
int32 chunk_size = 0;
|
||||||
|
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||||
|
string utt = feature_reader.Key();
|
||||||
|
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||||
|
vector<Matrix<BaseFloat>> feature_chunks;
|
||||||
|
SplitFeature(feature, chunk_size, &feature_chunks);
|
||||||
|
for (auto feature_chunk : feature_chunks) {
|
||||||
|
decodable.FeedFeatures(feature_chunk);
|
||||||
|
decoder.InitDecoder();
|
||||||
|
decoder.AdvanceDecode(decodable, chunk_size);
|
||||||
|
}
|
||||||
|
decodable.InputFinished();
|
||||||
|
std::string result;
|
||||||
|
result = decoder.GetFinalBestPath();
|
||||||
|
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||||
|
decodable.Reset();
|
||||||
|
++num_done;
|
||||||
|
}
|
||||||
|
|
||||||
|
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||||
|
<< " with errors.";
|
||||||
|
return (num_done != 0 ? 0 : 1);
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
#include "nnet/decodable.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
Decodable::Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood) {
|
||||||
|
frames_ready_ += likelihood.NumRows();
|
||||||
|
}
|
||||||
|
|
||||||
|
Decodable::Init(DecodableConfig config) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
Decodable::IsLastFrame(int32 frame) const {
|
||||||
|
CHECK_LE(frame, frames_ready_);
|
||||||
|
return finished_ && (frame == frames_ready_ - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32 Decodable::NumIndices() const {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Decodable::LogLikelihood(int32 frame, int32 index) {
|
||||||
|
return ;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Decodable::FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& features) {
|
||||||
|
// skip frame ???
|
||||||
|
nnet_->FeedForward(features, &nnet_cache_);
|
||||||
|
frames_ready_ += nnet_cache_.NumRows();
|
||||||
|
return ;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Decodable::Reset() {
|
||||||
|
// frontend_.Reset();
|
||||||
|
nnet_->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
Loading…
Reference in new issue