commit
2f15a78707
@ -0,0 +1,143 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <cmath>
|
||||||
|
#include "ctc_beam_search_decoder.h"
|
||||||
|
|
||||||
|
template <typename T1, typename T2>
|
||||||
|
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) {
|
||||||
|
return a.first > b.first;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T1, typename T2>
|
||||||
|
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) {
|
||||||
|
return a.second > b.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* CTC beam search decoder in C++, the interface is consistent with the original
|
||||||
|
decoder in Python version.
|
||||||
|
*/
|
||||||
|
std::vector<std::pair<double, std::string> >
|
||||||
|
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
||||||
|
int beam_size,
|
||||||
|
std::vector<std::string> vocabulary,
|
||||||
|
int blank_id,
|
||||||
|
double cutoff_prob,
|
||||||
|
Scorer *ext_scorer,
|
||||||
|
bool nproc
|
||||||
|
)
|
||||||
|
{
|
||||||
|
int num_time_steps = probs_seq.size();
|
||||||
|
|
||||||
|
// assign space ID
|
||||||
|
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " ");
|
||||||
|
int space_id = it-vocabulary.begin();
|
||||||
|
if(space_id >= vocabulary.size()) {
|
||||||
|
std::cout<<"The character space is not in the vocabulary!";
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize
|
||||||
|
// two sets containing selected and candidate prefixes respectively
|
||||||
|
std::map<std::string, double> prefix_set_prev, prefix_set_next;
|
||||||
|
// probability of prefixes ending with blank and non-blank
|
||||||
|
std::map<std::string, double> probs_b_prev, probs_nb_prev;
|
||||||
|
std::map<std::string, double> probs_b_cur, probs_nb_cur;
|
||||||
|
prefix_set_prev["\t"] = 1.0;
|
||||||
|
probs_b_prev["\t"] = 1.0;
|
||||||
|
probs_nb_prev["\t"] = 0.0;
|
||||||
|
|
||||||
|
for (int time_step=0; time_step<num_time_steps; time_step++) {
|
||||||
|
prefix_set_next.clear();
|
||||||
|
probs_b_cur.clear();
|
||||||
|
probs_nb_cur.clear();
|
||||||
|
std::vector<double> prob = probs_seq[time_step];
|
||||||
|
|
||||||
|
std::vector<std::pair<int, double> > prob_idx;
|
||||||
|
for (int i=0; i<prob.size(); i++) {
|
||||||
|
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
|
||||||
|
}
|
||||||
|
// pruning of vacobulary
|
||||||
|
if (cutoff_prob < 1.0) {
|
||||||
|
std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
|
||||||
|
float cum_prob = 0.0;
|
||||||
|
int cutoff_len = 0;
|
||||||
|
for (int i=0; i<prob_idx.size(); i++) {
|
||||||
|
cum_prob += prob_idx[i].second;
|
||||||
|
cutoff_len += 1;
|
||||||
|
if (cum_prob >= cutoff_prob) break;
|
||||||
|
}
|
||||||
|
prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len);
|
||||||
|
}
|
||||||
|
// extend prefix
|
||||||
|
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
|
||||||
|
it != prefix_set_prev.end(); it++) {
|
||||||
|
std::string l = it->first;
|
||||||
|
if( prefix_set_next.find(l) == prefix_set_next.end()) {
|
||||||
|
probs_b_cur[l] = probs_nb_cur[l] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int index=0; index<prob_idx.size(); index++) {
|
||||||
|
int c = prob_idx[index].first;
|
||||||
|
double prob_c = prob_idx[index].second;
|
||||||
|
if (c == blank_id) {
|
||||||
|
probs_b_cur[l] += prob_c*(probs_b_prev[l]+probs_nb_prev[l]);
|
||||||
|
} else {
|
||||||
|
std::string last_char = l.substr(l.size()-1, 1);
|
||||||
|
std::string new_char = vocabulary[c];
|
||||||
|
std::string l_plus = l+new_char;
|
||||||
|
|
||||||
|
if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
|
||||||
|
probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0;
|
||||||
|
}
|
||||||
|
if (last_char == new_char) {
|
||||||
|
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l];
|
||||||
|
probs_nb_cur[l] += prob_c * probs_nb_prev[l];
|
||||||
|
} else if (new_char == " ") {
|
||||||
|
double score = 1.0;
|
||||||
|
if (ext_scorer != NULL && l.size() > 1) {
|
||||||
|
score = ext_scorer->get_score(l.substr(1));
|
||||||
|
}
|
||||||
|
probs_nb_cur[l_plus] += score * prob_c * (
|
||||||
|
probs_b_prev[l] + probs_nb_prev[l]);
|
||||||
|
} else {
|
||||||
|
probs_nb_cur[l_plus] += prob_c * (
|
||||||
|
probs_b_prev[l] + probs_nb_prev[l]);
|
||||||
|
}
|
||||||
|
prefix_set_next[l_plus] = probs_nb_cur[l_plus]+probs_b_cur[l_plus];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l];
|
||||||
|
}
|
||||||
|
|
||||||
|
probs_b_prev = probs_b_cur;
|
||||||
|
probs_nb_prev = probs_nb_cur;
|
||||||
|
std::vector<std::pair<std::string, double> >
|
||||||
|
prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
|
||||||
|
std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>);
|
||||||
|
int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size();
|
||||||
|
prefix_set_prev = std::map<std::string, double>
|
||||||
|
(prefix_vec_next.begin(), prefix_vec_next.begin()+k);
|
||||||
|
}
|
||||||
|
|
||||||
|
// post processing
|
||||||
|
std::vector<std::pair<double, std::string> > beam_result;
|
||||||
|
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
|
||||||
|
it != prefix_set_prev.end(); it++) {
|
||||||
|
if (it->second > 0.0 && it->first.size() > 1) {
|
||||||
|
double prob = it->second;
|
||||||
|
std::string sentence = it->first.substr(1);
|
||||||
|
// scoring the last word
|
||||||
|
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
|
||||||
|
prob = prob * ext_scorer->get_score(sentence);
|
||||||
|
}
|
||||||
|
double log_prob = log(it->second);
|
||||||
|
beam_result.push_back(std::pair<double, std::string>(log_prob, it->first));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// sort the result and return
|
||||||
|
std::sort(beam_result.begin(), beam_result.end(), pair_comp_first_rev<double, std::string>);
|
||||||
|
return beam_result;
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
#ifndef CTC_BEAM_SEARCH_DECODER_H_
|
||||||
|
#define CTC_BEAM_SEARCH_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include "scorer.h"
|
||||||
|
|
||||||
|
std::vector<std::pair<double, std::string> >
|
||||||
|
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
||||||
|
int beam_size,
|
||||||
|
std::vector<std::string> vocabulary,
|
||||||
|
int blank_id=0,
|
||||||
|
double cutoff_prob=1.0,
|
||||||
|
Scorer *ext_scorer=NULL,
|
||||||
|
bool nproc=false
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
@ -0,0 +1,22 @@
|
|||||||
|
%module swig_ctc_beam_search_decoder
|
||||||
|
%{
|
||||||
|
#include "ctc_beam_search_decoder.h"
|
||||||
|
%}
|
||||||
|
|
||||||
|
%include "std_vector.i"
|
||||||
|
%include "std_pair.i"
|
||||||
|
%include "std_string.i"
|
||||||
|
|
||||||
|
namespace std{
|
||||||
|
%template(DoubleVector) std::vector<double>;
|
||||||
|
%template(IntVector) std::vector<int>;
|
||||||
|
%template(StringVector) std::vector<std::string>;
|
||||||
|
%template(VectorOfStructVector) std::vector<std::vector<double> >;
|
||||||
|
%template(FloatVector) std::vector<float>;
|
||||||
|
%template(Pair) std::pair<float, std::string>;
|
||||||
|
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
|
||||||
|
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
|
||||||
|
}
|
||||||
|
|
||||||
|
%import scorer.h
|
||||||
|
%include "ctc_beam_search_decoder.h"
|
@ -0,0 +1,58 @@
|
|||||||
|
from setuptools import setup, Extension
|
||||||
|
import glob
|
||||||
|
import platform
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def compile_test(header, library):
|
||||||
|
dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
|
||||||
|
command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
|
||||||
|
return os.system(command) == 0
|
||||||
|
|
||||||
|
|
||||||
|
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
|
||||||
|
'util/double-conversion/*.cc')
|
||||||
|
FILES = [
|
||||||
|
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
|
||||||
|
]
|
||||||
|
|
||||||
|
LIBS = ['stdc++']
|
||||||
|
if platform.system() != 'Darwin':
|
||||||
|
LIBS.append('rt')
|
||||||
|
|
||||||
|
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']
|
||||||
|
|
||||||
|
if compile_test('zlib.h', 'z'):
|
||||||
|
ARGS.append('-DHAVE_ZLIB')
|
||||||
|
LIBS.append('z')
|
||||||
|
|
||||||
|
if compile_test('bzlib.h', 'bz2'):
|
||||||
|
ARGS.append('-DHAVE_BZLIB')
|
||||||
|
LIBS.append('bz2')
|
||||||
|
|
||||||
|
if compile_test('lzma.h', 'lzma'):
|
||||||
|
ARGS.append('-DHAVE_XZLIB')
|
||||||
|
LIBS.append('lzma')
|
||||||
|
|
||||||
|
os.system('swig -python -c++ ./ctc_beam_search_decoder.i')
|
||||||
|
|
||||||
|
ctc_beam_search_decoder_module = [
|
||||||
|
Extension(
|
||||||
|
name='_swig_ctc_beam_search_decoder',
|
||||||
|
sources=FILES + [
|
||||||
|
'scorer.cpp', 'ctc_beam_search_decoder_wrap.cxx',
|
||||||
|
'ctc_beam_search_decoder.cpp'
|
||||||
|
],
|
||||||
|
language='C++',
|
||||||
|
include_dirs=['.'],
|
||||||
|
libraries=LIBS,
|
||||||
|
extra_compile_args=ARGS)
|
||||||
|
]
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='swig_ctc_beam_search_decoder',
|
||||||
|
version='0.1',
|
||||||
|
author='Yibing Liu',
|
||||||
|
description="""CTC beam search decoder""",
|
||||||
|
ext_modules=ctc_beam_search_decoder_module,
|
||||||
|
py_modules=['swig_ctc_beam_search_decoder'], )
|
@ -0,0 +1,82 @@
|
|||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "scorer.h"
|
||||||
|
#include "lm/model.hh"
|
||||||
|
#include "util/tokenize_piece.hh"
|
||||||
|
#include "util/string_piece.hh"
|
||||||
|
|
||||||
|
using namespace lm::ngram;
|
||||||
|
|
||||||
|
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
|
||||||
|
this->_alpha = alpha;
|
||||||
|
this->_beta = beta;
|
||||||
|
this->_language_model = new Model(lm_model_path.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Scorer::~Scorer(){
|
||||||
|
delete (Model *)this->_language_model;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void strip(std::string &str, char ch=' ') {
|
||||||
|
if (str.size() == 0) return;
|
||||||
|
int start = 0;
|
||||||
|
int end = str.size()-1;
|
||||||
|
for (int i=0; i<str.size(); i++){
|
||||||
|
if (str[i] == ch) {
|
||||||
|
start ++;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i=str.size()-1; i>=0; i--) {
|
||||||
|
if (str[i] == ch) {
|
||||||
|
end --;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (start == 0 && end == str.size()-1) return;
|
||||||
|
if (start > end) {
|
||||||
|
std::string emp_str;
|
||||||
|
str = emp_str;
|
||||||
|
} else {
|
||||||
|
str = str.substr(start, end-start+1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int Scorer::word_count(std::string sentence) {
|
||||||
|
strip(sentence);
|
||||||
|
int cnt = 0;
|
||||||
|
for (int i=0; i<sentence.size(); i++) {
|
||||||
|
if (sentence[i] == ' ' && sentence[i-1] != ' ') {
|
||||||
|
cnt ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cnt > 0) cnt ++;
|
||||||
|
return cnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
double Scorer::language_model_score(std::string sentence) {
|
||||||
|
Model *model = (Model *)this->_language_model;
|
||||||
|
State state, out_state;
|
||||||
|
lm::FullScoreReturn ret;
|
||||||
|
state = model->BeginSentenceState();
|
||||||
|
|
||||||
|
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
|
||||||
|
lm::WordIndex vocab = model->GetVocabulary().Index(*it);
|
||||||
|
ret = model->FullScore(state, vocab, out_state);
|
||||||
|
state = out_state;
|
||||||
|
}
|
||||||
|
double score = ret.prob;
|
||||||
|
|
||||||
|
return pow(10, score);
|
||||||
|
}
|
||||||
|
|
||||||
|
double Scorer::get_score(std::string sentence) {
|
||||||
|
double lm_score = language_model_score(sentence);
|
||||||
|
int word_cnt = word_count(sentence);
|
||||||
|
|
||||||
|
double final_score = pow(lm_score, _alpha) * pow(word_cnt, _beta);
|
||||||
|
return final_score;
|
||||||
|
}
|
@ -0,0 +1,22 @@
|
|||||||
|
#ifndef SCORER_H_
|
||||||
|
#define SCORER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
|
||||||
|
class Scorer{
|
||||||
|
private:
|
||||||
|
float _alpha;
|
||||||
|
float _beta;
|
||||||
|
void *_language_model;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Scorer(){}
|
||||||
|
Scorer(float alpha, float beta, std::string lm_model_path);
|
||||||
|
~Scorer();
|
||||||
|
int word_count(std::string);
|
||||||
|
double language_model_score(std::string);
|
||||||
|
double get_score(std::string);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
@ -0,0 +1,8 @@
|
|||||||
|
%module swig_scorer
|
||||||
|
%{
|
||||||
|
#include "scorer.h"
|
||||||
|
%}
|
||||||
|
|
||||||
|
%include "std_string.i"
|
||||||
|
|
||||||
|
%include "scorer.h"
|
@ -0,0 +1,54 @@
|
|||||||
|
from setuptools import setup, Extension
|
||||||
|
import glob
|
||||||
|
import platform
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def compile_test(header, library):
|
||||||
|
dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
|
||||||
|
command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
|
||||||
|
return os.system(command) == 0
|
||||||
|
|
||||||
|
|
||||||
|
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob(
|
||||||
|
'util/double-conversion/*.cc')
|
||||||
|
FILES = [
|
||||||
|
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
|
||||||
|
]
|
||||||
|
|
||||||
|
LIBS = ['stdc++']
|
||||||
|
if platform.system() != 'Darwin':
|
||||||
|
LIBS.append('rt')
|
||||||
|
|
||||||
|
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']
|
||||||
|
|
||||||
|
if compile_test('zlib.h', 'z'):
|
||||||
|
ARGS.append('-DHAVE_ZLIB')
|
||||||
|
LIBS.append('z')
|
||||||
|
|
||||||
|
if compile_test('bzlib.h', 'bz2'):
|
||||||
|
ARGS.append('-DHAVE_BZLIB')
|
||||||
|
LIBS.append('bz2')
|
||||||
|
|
||||||
|
if compile_test('lzma.h', 'lzma'):
|
||||||
|
ARGS.append('-DHAVE_XZLIB')
|
||||||
|
LIBS.append('lzma')
|
||||||
|
|
||||||
|
os.system('swig -python -c++ ./scorer.i')
|
||||||
|
|
||||||
|
ext_modules = [
|
||||||
|
Extension(
|
||||||
|
name='_swig_scorer',
|
||||||
|
sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'],
|
||||||
|
language='C++',
|
||||||
|
include_dirs=['.'],
|
||||||
|
libraries=LIBS,
|
||||||
|
extra_compile_args=ARGS)
|
||||||
|
]
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='swig_scorer',
|
||||||
|
version='0.1',
|
||||||
|
ext_modules=ext_modules,
|
||||||
|
include_package_data=True,
|
||||||
|
py_modules=['swig_scorer'], )
|
Loading…
Reference in new issue