enable loading language model in multiple format

pull/2/head
Yibing Liu 8 years ago
parent 5bfa066920
commit ccea7c0150

@@ -14,6 +14,7 @@ from swig_ctc_beam_search_decoder import *
 from swig_scorer import Scorer
 from error_rate import wer
 import utils
+import time
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
@@ -74,7 +75,7 @@ parser.add_argument(
 )
 parser.add_argument(
     "--beam_size",
-    default=500,
+    default=200,
     type=int,
     help="Width for beam search decoding. (default: %(default)d)")
 parser.add_argument(
@@ -166,6 +167,7 @@ def infer():
     ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
 
     ## decode and print
+    time_begin = time.time()
     wer_sum, wer_counter = 0, 0
     for i, probs in enumerate(probs_split):
         beam_result = ctc_beam_search_decoder(
@@ -183,6 +185,8 @@ def infer():
         wer_counter += 1
         print("cur wer = %f , average wer = %f" %
               (wer_cur, wer_sum / wer_counter))
+    time_end = time.time()
+    print("total time = %f" % (time_end - time_begin))
 
 
 def main():

@@ -1,4 +1,5 @@
 #include <iostream>
+#include <unistd.h>
 #include "scorer.h"
 #include "lm/model.hh"
 #include "util/tokenize_piece.hh"
@@ -9,11 +10,16 @@ using namespace lm::ngram;
 Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
     this->_alpha = alpha;
     this->_beta = beta;
-    this->_language_model = new Model(lm_model_path.c_str());
+
+    if (access(lm_model_path.c_str(), F_OK) != 0) {
+        std::cout<<"Invalid language model path!"<<std::endl;
+        exit(1);
+    }
+    this->_language_model = LoadVirtual(lm_model_path.c_str());
 }
 
 Scorer::~Scorer(){
-    delete (Model *)this->_language_model;
+    delete (lm::base::Model *)this->_language_model;
 }
 
 /* Strip a input sentence
@@ -63,14 +69,14 @@ int Scorer::word_count(std::string sentence) {
 }
 
 double Scorer::language_model_score(std::string sentence) {
-    Model *model = (Model *)this->_language_model;
+    lm::base::Model *model = (lm::base::Model *)this->_language_model;
     State state, out_state;
     lm::FullScoreReturn ret;
-    state = model->BeginSentenceState();
+    model->BeginSentenceWrite(&state);
 
     for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
-        lm::WordIndex vocab = model->GetVocabulary().Index(*it);
-        ret = model->FullScore(state, vocab, out_state);
+        lm::WordIndex wid = model->BaseVocabulary().Index(*it);
+        ret = model->BaseFullScore(&state, wid, &out_state);
         state = out_state;
     }
     //log10 prob
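
For reference, below is a minimal standalone sketch of the KenLM code path this commit switches to: lm::ngram::LoadVirtual() inspects the file itself, so the same Scorer can load either an ARPA text model or a KenLM binary model, and scoring goes through the format-agnostic lm::base::Model calls (BaseVocabulary, BeginSentenceWrite, BaseFullScore). This is a sketch under assumptions, not part of the commit: it assumes the KenLM headers and library are available, and the model path "lm.bin" and the example words are placeholders.

// Sketch only: load a language model in any KenLM-supported format and
// score a short word sequence through the virtual (format-agnostic) API.
#include <iostream>
#include <memory>

#include "lm/model.hh"              // lm::ngram::LoadVirtual
#include "lm/virtual_interface.hh"  // lm::base::Model

int main() {
    // LoadVirtual() checks the file header, so "lm.bin" (placeholder path)
    // may be either an ARPA text model or a KenLM binary model.
    std::unique_ptr<lm::base::Model> model(lm::ngram::LoadVirtual("lm.bin"));

    // Same state handling as in scorer.cpp: reuse lm::ngram::State buffers.
    lm::ngram::State state, out_state;
    model->BeginSentenceWrite(&state);

    double log10_prob = 0.0;
    const char *words[] = {"hello", "world"};  // placeholder sentence
    for (const char *word : words) {
        lm::WordIndex wid = model->BaseVocabulary().Index(word);
        lm::FullScoreReturn ret = model->BaseFullScore(&state, wid, &out_state);
        log10_prob += ret.prob;  // KenLM reports log10 probabilities
        state = out_state;
    }
    std::cout << "log10 P(sentence) = " << log10_prob << std::endl;
    return 0;
}

Such a snippet would be compiled and linked against KenLM the same way scorer.cpp is in the scorer setup script.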

@@ -3,9 +3,9 @@ echo "Run decoder setup ..."
 python decoder_setup.py install
 rm -r ./build
 
-echo "\nRun scorer setup ..."
+echo "Run scorer setup ..."
 python scorer_setup.py install
 rm -r ./build
 
-echo "\nFinish the installation of decoder and scorer."
+echo "Finish the installation of decoder and scorer."
