commit 2f15a78707
@@ -0,0 +1,143 @@
#include <iostream>
#include <map>
#include <algorithm>
#include <utility>
#include <cstdlib>
#include <cmath>
#include "ctc_beam_search_decoder.h"

// Comparators for sorting pairs in descending order of the first / second element.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) {
    return a.first > b.first;
}

template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) {
    return a.second > b.second;
}

/* CTC beam search decoder in C++; the interface is kept consistent with the
 * original Python decoder.
 */
std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                        int beam_size,
                        std::vector<std::string> vocabulary,
                        int blank_id,
                        double cutoff_prob,
                        Scorer *ext_scorer,
                        bool nproc) {
    int num_time_steps = probs_seq.size();

    // assign space ID
    std::vector<std::string>::iterator it =
        std::find(vocabulary.begin(), vocabulary.end(), " ");
    int space_id = it - vocabulary.begin();
    if (space_id >= (int)vocabulary.size()) {
        std::cout << "The character space is not in the vocabulary!" << std::endl;
        exit(1);
    }

    // initialize
    // two sets containing selected and candidate prefixes respectively
    std::map<std::string, double> prefix_set_prev, prefix_set_next;
    // probability of prefixes ending with blank and non-blank
    std::map<std::string, double> probs_b_prev, probs_nb_prev;
    std::map<std::string, double> probs_b_cur, probs_nb_cur;
    prefix_set_prev["\t"] = 1.0;   // "\t" marks the empty (root) prefix
    probs_b_prev["\t"] = 1.0;
    probs_nb_prev["\t"] = 0.0;

    for (int time_step = 0; time_step < num_time_steps; time_step++) {
        prefix_set_next.clear();
        probs_b_cur.clear();
        probs_nb_cur.clear();
        std::vector<double> prob = probs_seq[time_step];

        std::vector<std::pair<int, double> > prob_idx;
        for (int i = 0; i < prob.size(); i++) {
            prob_idx.push_back(std::pair<int, double>(i, prob[i]));
        }
        // pruning of vocabulary: keep only the most probable characters whose
        // cumulative probability reaches cutoff_prob
        if (cutoff_prob < 1.0) {
            std::sort(prob_idx.begin(), prob_idx.end(),
                      pair_comp_second_rev<int, double>);
            float cum_prob = 0.0;
            int cutoff_len = 0;
            for (int i = 0; i < prob_idx.size(); i++) {
                cum_prob += prob_idx[i].second;
                cutoff_len += 1;
                if (cum_prob >= cutoff_prob) break;
            }
            prob_idx = std::vector<std::pair<int, double> >(
                prob_idx.begin(), prob_idx.begin() + cutoff_len);
        }
        // extend each prefix in the previous beam
        for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
             it != prefix_set_prev.end(); it++) {
            std::string l = it->first;
            if (prefix_set_next.find(l) == prefix_set_next.end()) {
                probs_b_cur[l] = probs_nb_cur[l] = 0.0;
            }

            for (int index = 0; index < prob_idx.size(); index++) {
                int c = prob_idx[index].first;
                double prob_c = prob_idx[index].second;
                if (c == blank_id) {
                    probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l]);
                } else {
                    std::string last_char = l.substr(l.size() - 1, 1);
                    std::string new_char = vocabulary[c];
                    std::string l_plus = l + new_char;

                    if (prefix_set_next.find(l_plus) == prefix_set_next.end()) {
                        probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0;
                    }
                    if (last_char == new_char) {
                        // repeated character: it can only extend the prefix
                        // across an intervening blank
                        probs_nb_cur[l_plus] += prob_c * probs_b_prev[l];
                        probs_nb_cur[l] += prob_c * probs_nb_prev[l];
                    } else if (new_char == " ") {
                        // word boundary: weight by the external (language model) scorer
                        double score = 1.0;
                        if (ext_scorer != NULL && l.size() > 1) {
                            score = ext_scorer->get_score(l.substr(1));
                        }
                        probs_nb_cur[l_plus] += score * prob_c *
                            (probs_b_prev[l] + probs_nb_prev[l]);
                    } else {
                        probs_nb_cur[l_plus] += prob_c *
                            (probs_b_prev[l] + probs_nb_prev[l]);
                    }
                    prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus];
                }
            }

            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l];
        }

        probs_b_prev = probs_b_cur;
        probs_nb_prev = probs_nb_cur;
        // keep only the beam_size most probable prefixes for the next time step
        std::vector<std::pair<std::string, double> >
            prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
        std::sort(prefix_vec_next.begin(), prefix_vec_next.end(),
                  pair_comp_second_rev<std::string, double>);
        int k = beam_size < (int)prefix_vec_next.size()
                    ? beam_size : (int)prefix_vec_next.size();
        prefix_set_prev = std::map<std::string, double>(
            prefix_vec_next.begin(), prefix_vec_next.begin() + k);
    }

    // post processing: apply the scorer to the last word, convert to log
    // probability and strip the leading "\t" sentinel
    std::vector<std::pair<double, std::string> > beam_result;
    for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
         it != prefix_set_prev.end(); it++) {
        if (it->second > 0.0 && it->first.size() > 1) {
            double prob = it->second;
            std::string sentence = it->first.substr(1);
            // scoring the last word
            if (ext_scorer != NULL && sentence[sentence.size() - 1] != ' ') {
                prob = prob * ext_scorer->get_score(sentence);
            }
            double log_prob = log(prob);
            beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
        }
    }
    // sort the result in descending order and return
    std::sort(beam_result.begin(), beam_result.end(),
              pair_comp_first_rev<double, std::string>);
    return beam_result;
}
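A minimal, hypothetical sketch of calling the decoder directly from C++ (the vocabulary, toy probabilities, beam size and blank index below are made-up illustration values, not part of this commit; a real caller passes per-time-step softmax outputs whose columns line up with the vocabulary):

// usage_sketch.cpp -- illustrative only, not part of this commit
#include <iostream>
#include "ctc_beam_search_decoder.h"

int main() {
    // the decoder requires " " in the vocabulary; the last entry here stands
    // in for the blank symbol and is addressed via blank_id
    std::vector<std::string> vocabulary;
    vocabulary.push_back("a");
    vocabulary.push_back("b");
    vocabulary.push_back(" ");
    vocabulary.push_back("_");  // blank placeholder, index 3

    // two time steps of toy per-character probabilities (one column per entry)
    std::vector<std::vector<double> > probs_seq(2, std::vector<double>(4, 0.0));
    probs_seq[0][0] = 0.6; probs_seq[0][1] = 0.2; probs_seq[0][2] = 0.1; probs_seq[0][3] = 0.1;
    probs_seq[1][0] = 0.1; probs_seq[1][1] = 0.6; probs_seq[1][2] = 0.2; probs_seq[1][3] = 0.1;

    std::vector<std::pair<double, std::string> > result =
        ctc_beam_search_decoder(probs_seq, 5, vocabulary,
                                3,     // blank_id
                                1.0,   // cutoff_prob (no pruning)
                                NULL,  // no external scorer
                                false);

    // results are (log probability, decoded string), best first
    for (size_t i = 0; i < result.size(); i++) {
        std::cout << result[i].first << "\t" << result[i].second << std::endl;
    }
    return 0;
}

A Scorer built from scorer.h (added below) could be passed in place of NULL to weight word boundaries with a language model.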
@@ -0,0 +1,19 @@
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_

#include <vector>
#include <string>
#include <utility>
#include "scorer.h"

/* CTC beam search decoder; returns up to beam_size (log probability, sentence)
 * pairs sorted in descending order of probability.
 */
std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
                        int beam_size,
                        std::vector<std::string> vocabulary,
                        int blank_id=0,
                        double cutoff_prob=1.0,
                        Scorer *ext_scorer=NULL,
                        bool nproc=false);

#endif  // CTC_BEAM_SEARCH_DECODER_H_
@@ -0,0 +1,22 @@
%module swig_ctc_beam_search_decoder
%{
#include "ctc_beam_search_decoder.h"
%}

%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"

namespace std {
    %template(DoubleVector) std::vector<double>;
    %template(IntVector) std::vector<int>;
    %template(StringVector) std::vector<std::string>;
    %template(VectorOfStructVector) std::vector<std::vector<double> >;
    %template(FloatVector) std::vector<float>;
    %template(Pair) std::pair<float, std::string>;
    %template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
    %template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
}

%import scorer.h
%include "ctc_beam_search_decoder.h"
@@ -0,0 +1,58 @@
from setuptools import setup, Extension
import glob
import platform
import os


def compile_test(header, library):
    # returns True if the given header and library are usable by g++ on this system
    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
    return os.system(command) == 0


FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
    'util/double-conversion/*.cc')
FILES = [
    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
]

LIBS = ['stdc++']
if platform.system() != 'Darwin':
    LIBS.append('rt')

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']

if compile_test('zlib.h', 'z'):
    ARGS.append('-DHAVE_ZLIB')
    LIBS.append('z')

if compile_test('bzlib.h', 'bz2'):
    ARGS.append('-DHAVE_BZLIB')
    LIBS.append('bz2')

if compile_test('lzma.h', 'lzma'):
    ARGS.append('-DHAVE_XZLIB')
    LIBS.append('lzma')

os.system('swig -python -c++ ./ctc_beam_search_decoder.i')

ctc_beam_search_decoder_module = [
    Extension(
        name='_swig_ctc_beam_search_decoder',
        sources=FILES + [
            'scorer.cpp', 'ctc_beam_search_decoder_wrap.cxx',
            'ctc_beam_search_decoder.cpp'
        ],
        language='c++',
        include_dirs=['.'],
        libraries=LIBS,
        extra_compile_args=ARGS)
]

setup(
    name='swig_ctc_beam_search_decoder',
    version='0.1',
    author='Yibing Liu',
    description="""CTC beam search decoder""",
    ext_modules=ctc_beam_search_decoder_module,
    py_modules=['swig_ctc_beam_search_decoder'], )
@ -0,0 +1,82 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "scorer.h"
|
||||
#include "lm/model.hh"
|
||||
#include "util/tokenize_piece.hh"
|
||||
#include "util/string_piece.hh"
|
||||
|
||||
using namespace lm::ngram;
|
||||
|
||||
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
|
||||
this->_alpha = alpha;
|
||||
this->_beta = beta;
|
||||
this->_language_model = new Model(lm_model_path.c_str());
|
||||
}
|
||||
|
||||
Scorer::~Scorer(){
|
||||
delete (Model *)this->_language_model;
|
||||
}
|
||||
|
||||
inline void strip(std::string &str, char ch=' ') {
|
||||
if (str.size() == 0) return;
|
||||
int start = 0;
|
||||
int end = str.size()-1;
|
||||
for (int i=0; i<str.size(); i++){
|
||||
if (str[i] == ch) {
|
||||
start ++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int i=str.size()-1; i>=0; i--) {
|
||||
if (str[i] == ch) {
|
||||
end --;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (start == 0 && end == str.size()-1) return;
|
||||
if (start > end) {
|
||||
std::string emp_str;
|
||||
str = emp_str;
|
||||
} else {
|
||||
str = str.substr(start, end-start+1);
|
||||
}
|
||||
}
|
||||
|
||||
// count words in a sentence, treating runs of spaces as a single separator
int Scorer::word_count(std::string sentence) {
    strip(sentence);
    if (sentence.empty()) return 0;
    int cnt = 1;
    for (size_t i = 1; i < sentence.size(); i++) {
        if (sentence[i] == ' ' && sentence[i - 1] != ' ') {
            cnt++;
        }
    }
    return cnt;
}

// conditional probability of the last word in the (space separated) sentence
// given its predecessors, queried from the KenLM model
double Scorer::language_model_score(std::string sentence) {
    Model *model = (Model *)this->_language_model;
    State state, out_state;
    lm::FullScoreReturn ret;
    state = model->BeginSentenceState();

    for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it) {
        lm::WordIndex vocab = model->GetVocabulary().Index(*it);
        ret = model->FullScore(state, vocab, out_state);
        state = out_state;
    }
    // ret.prob is a base-10 log probability; convert it back to a probability
    double score = ret.prob;

    return pow(10, score);
}

double Scorer::get_score(std::string sentence) {
    double lm_score = language_model_score(sentence);
    int word_cnt = word_count(sentence);

    double final_score = pow(lm_score, _alpha) * pow(word_cnt, _beta);
    return final_score;
}
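A minimal, hypothetical sketch of exercising the scorer on its own (the model path "my_lm.arpa" and the alpha/beta weights below are placeholders, not part of this commit; any KenLM language model file would be loaded the same way):

// scorer_sketch.cpp -- illustrative only, not part of this commit
#include <iostream>
#include "scorer.h"

int main() {
    // placeholder KenLM model path and made-up alpha/beta weights
    Scorer scorer(0.5f, 0.5f, "my_lm.arpa");

    std::string sentence = "hello world";
    std::cout << "words: " << scorer.word_count(sentence) << std::endl;    // 2
    std::cout << "lm score: " << scorer.language_model_score(sentence) << std::endl;
    std::cout << "combined: " << scorer.get_score(sentence) << std::endl;  // lm^alpha * words^beta
    return 0;
}

In the decoder above, the same object is passed as ext_scorer, so get_score is what weights each completed word during the beam search.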
@@ -0,0 +1,22 @@
#ifndef SCORER_H_
#define SCORER_H_

#include <string>

/* External scorer for the CTC beam search decoder; combines a KenLM language
 * model score and a word count:
 *     score = language_model_score^alpha * word_count^beta
 */
class Scorer {
private:
    float _alpha;
    float _beta;
    void *_language_model;

public:
    Scorer() : _alpha(0.0), _beta(0.0), _language_model(NULL) {}
    Scorer(float alpha, float beta, std::string lm_model_path);
    ~Scorer();
    int word_count(std::string);
    double language_model_score(std::string);
    double get_score(std::string);
};

#endif  // SCORER_H_
@@ -0,0 +1,8 @@
%module swig_scorer
%{
#include "scorer.h"
%}

%include "std_string.i"

%include "scorer.h"
@@ -0,0 +1,54 @@
from setuptools import setup, Extension
import glob
import platform
import os


def compile_test(header, library):
    # returns True if the given header and library are usable by g++ on this system
    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
    return os.system(command) == 0


FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
    'util/double-conversion/*.cc')
FILES = [
    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
]

LIBS = ['stdc++']
if platform.system() != 'Darwin':
    LIBS.append('rt')

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']

if compile_test('zlib.h', 'z'):
    ARGS.append('-DHAVE_ZLIB')
    LIBS.append('z')

if compile_test('bzlib.h', 'bz2'):
    ARGS.append('-DHAVE_BZLIB')
    LIBS.append('bz2')

if compile_test('lzma.h', 'lzma'):
    ARGS.append('-DHAVE_XZLIB')
    LIBS.append('lzma')

os.system('swig -python -c++ ./scorer.i')

ext_modules = [
    Extension(
        name='_swig_scorer',
        sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'],
        language='c++',
        include_dirs=['.'],
        libraries=LIBS,
        extra_compile_args=ARGS)
]

setup(
    name='swig_scorer',
    version='0.1',
    ext_modules=ext_modules,
    include_package_data=True,
    py_modules=['swig_scorer'], )