format variables' names & add more comments

pull/2/head
Yibing Liu 7 years ago
parent a24d0138d9
commit 3018dcb4d9

@ -18,8 +18,8 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size, size_t beam_size,
std::vector<std::string> vocabulary,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer) {
@ -36,8 +36,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t blank_id = vocabulary.size(); size_t blank_id = vocabulary.size();
// assign space id // assign space id
std::vector<std::string>::iterator it = auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
// if no space in vocabulary // if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) { if ((size_t)space_id >= vocabulary.size()) {
@ -173,11 +172,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
const size_t beam_size,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
const size_t num_processes, size_t beam_size,
const double cutoff_prob, size_t num_processes,
const size_t cutoff_top_n, double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool // thread pool
@ -190,8 +189,8 @@ ctc_beam_search_decoder_batch(
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i], probs_split[i],
beam_size,
vocabulary, vocabulary,
beam_size,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
ext_scorer)); ext_scorer));

@ -12,8 +12,8 @@
* Parameters: * Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities * probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step. * over vocabulary of one time step.
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary. * vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning. * cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning. * cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of * ext_scorer: External scorer to evaluate a prefix, which consists of
@ -25,8 +25,8 @@
*/ */
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size, size_t beam_size,
std::vector<std::string> vocabulary,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr); Scorer *ext_scorer = nullptr);
@ -36,9 +36,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* Parameters: * Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used * probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder(). * by ctc_beam_search_decoder().
* .
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary. * vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search. * num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning. * cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning. * cutoff_top_n: Cutoff number for pruning.
@ -52,8 +51,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
size_t beam_size,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
size_t beam_size,
size_t num_processes, size_t num_processes,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,

@ -15,32 +15,32 @@ PathTrie::PathTrie() {
log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
score = -NUM_FLT_INF; score = -NUM_FLT_INF;
_ROOT = -1; ROOT_ = -1;
character = _ROOT; character = ROOT_;
_exists = true; exists_ = true;
parent = nullptr; parent = nullptr;
_dictionary = nullptr; dictionary_ = nullptr;
_dictionary_state = 0; dictionary_state_ = 0;
_has_dictionary = false; has_dictionary_ = false;
_matcher = nullptr; matcher_ = nullptr;
} }
PathTrie::~PathTrie() { PathTrie::~PathTrie() {
for (auto child : _children) { for (auto child : children_) {
delete child.second; delete child.second;
} }
} }
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
auto child = _children.begin(); auto child = children_.begin();
for (child = _children.begin(); child != _children.end(); ++child) { for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) { if (child->first == new_char) {
break; break;
} }
} }
if (child != _children.end()) { if (child != children_.end()) {
if (!child->second->_exists) { if (!child->second->exists_) {
child->second->_exists = true; child->second->exists_ = true;
child->second->log_prob_b_prev = -NUM_FLT_INF; child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->log_prob_nb_prev = -NUM_FLT_INF; child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->log_prob_b_cur = -NUM_FLT_INF; child->second->log_prob_b_cur = -NUM_FLT_INF;
@ -48,47 +48,47 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
} }
return (child->second); return (child->second);
} else { } else {
if (_has_dictionary) { if (has_dictionary_) {
_matcher->SetState(_dictionary_state); matcher_->SetState(dictionary_state_);
bool found = _matcher->Find(new_char); bool found = matcher_->Find(new_char);
if (!found) { if (!found) {
// Adding this character causes word outside dictionary // Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero(); auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = _dictionary->Final(_dictionary_state); auto final_weight = dictionary_->Final(dictionary_state_);
bool is_final = (final_weight != FSTZERO); bool is_final = (final_weight != FSTZERO);
if (is_final && reset) { if (is_final && reset) {
_dictionary_state = _dictionary->Start(); dictionary_state_ = dictionary_->Start();
} }
return nullptr; return nullptr;
} else { } else {
PathTrie* new_path = new PathTrie; PathTrie* new_path = new PathTrie;
new_path->character = new_char; new_path->character = new_char;
new_path->parent = this; new_path->parent = this;
new_path->_dictionary = _dictionary; new_path->dictionary_ = dictionary_;
new_path->_dictionary_state = _matcher->Value().nextstate; new_path->dictionary_state_ = matcher_->Value().nextstate;
new_path->_has_dictionary = true; new_path->has_dictionary_ = true;
new_path->_matcher = _matcher; new_path->matcher_ = matcher_;
_children.push_back(std::make_pair(new_char, new_path)); children_.push_back(std::make_pair(new_char, new_path));
return new_path; return new_path;
} }
} else { } else {
PathTrie* new_path = new PathTrie; PathTrie* new_path = new PathTrie;
new_path->character = new_char; new_path->character = new_char;
new_path->parent = this; new_path->parent = this;
_children.push_back(std::make_pair(new_char, new_path)); children_.push_back(std::make_pair(new_char, new_path));
return new_path; return new_path;
} }
} }
} }
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) { PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
return get_path_vec(output, _ROOT); return get_path_vec(output, ROOT_);
} }
PathTrie* PathTrie::get_path_vec(std::vector<int>& output, PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop, int stop,
size_t max_steps) { size_t max_steps) {
if (character == stop || character == _ROOT || output.size() == max_steps) { if (character == stop || character == ROOT_ || output.size() == max_steps) {
std::reverse(output.begin(), output.end()); std::reverse(output.begin(), output.end());
return this; return this;
} else { } else {
@ -98,7 +98,7 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
} }
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) { void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
if (_exists) { if (exists_) {
log_prob_b_prev = log_prob_b_cur; log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur; log_prob_nb_prev = log_prob_nb_cur;
@ -108,25 +108,25 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this); output.push_back(this);
} }
for (auto child : _children) { for (auto child : children_) {
child.second->iterate_to_vec(output); child.second->iterate_to_vec(output);
} }
} }
void PathTrie::remove() { void PathTrie::remove() {
_exists = false; exists_ = false;
if (_children.size() == 0) { if (children_.size() == 0) {
auto child = parent->_children.begin(); auto child = parent->children_.begin();
for (child = parent->_children.begin(); child != parent->_children.end(); for (child = parent->children_.begin(); child != parent->children_.end();
++child) { ++child) {
if (child->first == character) { if (child->first == character) {
parent->_children.erase(child); parent->children_.erase(child);
break; break;
} }
} }
if (parent->_children.size() == 0 && !parent->_exists) { if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove(); parent->remove();
} }
@ -135,12 +135,12 @@ void PathTrie::remove() {
} }
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
_dictionary = dictionary; dictionary_ = dictionary;
_dictionary_state = dictionary->Start(); dictionary_state_ = dictionary->Start();
_has_dictionary = true; has_dictionary_ = true;
} }
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) { void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
_matcher = matcher; matcher_ = matcher;
} }

@ -36,7 +36,7 @@ public:
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>); void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
bool is_empty() { return _ROOT == character; } bool is_empty() { return ROOT_ == character; }
// remove current path from root // remove current path from root
void remove(); void remove();
@ -51,17 +51,17 @@ public:
PathTrie* parent; PathTrie* parent;
private: private:
int _ROOT; int ROOT_;
bool _exists; bool exists_;
bool _has_dictionary; bool has_dictionary_;
std::vector<std::pair<int, PathTrie*>> _children; std::vector<std::pair<int, PathTrie*>> children_;
// pointer to dictionary of FST // pointer to dictionary of FST
fst::StdVectorFst* _dictionary; fst::StdVectorFst* dictionary_;
fst::StdVectorFst::StateId _dictionary_state; fst::StdVectorFst::StateId dictionary_state_;
// true if finding arcs in FST // true if finding arcs in FST
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> _matcher; std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
}; };
#endif // PATH_TRIE_H #endif // PATH_TRIE_H

@ -19,19 +19,19 @@ Scorer::Scorer(double alpha,
const std::vector<std::string>& vocab_list) { const std::vector<std::string>& vocab_list) {
this->alpha = alpha; this->alpha = alpha;
this->beta = beta; this->beta = beta;
_is_character_based = true; is_character_based_ = true;
_language_model = nullptr; language_model_ = nullptr;
dictionary = nullptr; dictionary = nullptr;
_max_order = 0; max_order_ = 0;
_dict_size = 0; dict_size_ = 0;
_SPACE_ID = -1; SPACE_ID_ = -1;
setup(lm_path, vocab_list); setup(lm_path, vocab_list);
} }
Scorer::~Scorer() { Scorer::~Scorer() {
if (_language_model != nullptr) { if (language_model_ != nullptr) {
delete static_cast<lm::base::Model*>(_language_model); delete static_cast<lm::base::Model*>(language_model_);
} }
if (dictionary != nullptr) { if (dictionary != nullptr) {
delete static_cast<fst::StdVectorFst*>(dictionary); delete static_cast<fst::StdVectorFst*>(dictionary);
@ -57,20 +57,20 @@ void Scorer::load_lm(const std::string& lm_path) {
RetriveStrEnumerateVocab enumerate; RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config; lm::ngram::Config config;
config.enumerate_vocab = &enumerate; config.enumerate_vocab = &enumerate;
_language_model = lm::ngram::LoadVirtual(filename, config); language_model_ = lm::ngram::LoadVirtual(filename, config);
_max_order = static_cast<lm::base::Model*>(_language_model)->Order(); max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
_vocabulary = enumerate.vocabulary; vocabulary_ = enumerate.vocabulary;
for (size_t i = 0; i < _vocabulary.size(); ++i) { for (size_t i = 0; i < vocabulary_.size(); ++i) {
if (_is_character_based && _vocabulary[i] != UNK_TOKEN && if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
_vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN && vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
get_utf8_str_len(enumerate.vocabulary[i]) > 1) { get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
_is_character_based = false; is_character_based_ = false;
} }
} }
} }
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) { double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(_language_model); lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
double cond_prob; double cond_prob;
lm::ngram::State state, tmp_state, out_state; lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin // avoid to inserting <s> in begin
@ -93,11 +93,11 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) { double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence; std::vector<std::string> sentence;
if (words.size() == 0) { if (words.size() == 0) {
for (size_t i = 0; i < _max_order; ++i) { for (size_t i = 0; i < max_order_; ++i) {
sentence.push_back(START_TOKEN); sentence.push_back(START_TOKEN);
} }
} else { } else {
for (size_t i = 0; i < _max_order - 1; ++i) { for (size_t i = 0; i < max_order_ - 1; ++i) {
sentence.push_back(START_TOKEN); sentence.push_back(START_TOKEN);
} }
sentence.insert(sentence.end(), words.begin(), words.end()); sentence.insert(sentence.end(), words.begin(), words.end());
@ -107,11 +107,11 @@ double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
} }
double Scorer::get_log_prob(const std::vector<std::string>& words) { double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > _max_order); assert(words.size() > max_order_);
double score = 0.0; double score = 0.0;
for (size_t i = 0; i < words.size() - _max_order + 1; ++i) { for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i, std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + _max_order); words.begin() + i + max_order_);
score += get_log_cond_prob(ngram); score += get_log_cond_prob(ngram);
} }
return score; return score;
@ -125,7 +125,7 @@ void Scorer::reset_params(float alpha, float beta) {
std::string Scorer::vec2str(const std::vector<int>& input) { std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word; std::string word;
for (auto ind : input) { for (auto ind : input) {
word += _char_list[ind]; word += char_list_[ind];
} }
return word; return word;
} }
@ -135,7 +135,7 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
std::string s = vec2str(labels); std::string s = vec2str(labels);
std::vector<std::string> words; std::vector<std::string> words;
if (_is_character_based) { if (is_character_based_) {
words = split_utf8_str(s); words = split_utf8_str(s);
} else { } else {
words = split_str(s, " "); words = split_str(s, " ");
@ -144,15 +144,15 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
} }
void Scorer::set_char_map(const std::vector<std::string>& char_list) { void Scorer::set_char_map(const std::vector<std::string>& char_list) {
_char_list = char_list; char_list_ = char_list;
_char_map.clear(); char_map_.clear();
for (unsigned int i = 0; i < _char_list.size(); i++) { for (size_t i = 0; i < char_list_.size(); i++) {
if (_char_list[i] == " ") { if (char_list_[i] == " ") {
_SPACE_ID = i; SPACE_ID_ = i;
_char_map[' '] = i; char_map_[' '] = i;
} else if (_char_list[i].size() == 1) { } else if (char_list_[i].size() == 1) {
_char_map[_char_list[i][0]] = i; char_map_[char_list_[i][0]] = i;
} }
} }
} }
@ -162,14 +162,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
PathTrie* current_node = prefix; PathTrie* current_node = prefix;
PathTrie* new_node = nullptr; PathTrie* new_node = nullptr;
for (int order = 0; order < _max_order; order++) { for (int order = 0; order < max_order_; order++) {
std::vector<int> prefix_vec; std::vector<int> prefix_vec;
if (_is_character_based) { if (is_character_based_) {
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1); new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
current_node = new_node; current_node = new_node;
} else { } else {
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID); new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
current_node = new_node->parent; // Skipping spaces current_node = new_node->parent; // Skipping spaces
} }
@ -179,7 +179,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if (new_node->character == -1) { if (new_node->character == -1) {
// No more spaces, but still need order // No more spaces, but still need order
for (int i = 0; i < _max_order - order - 1; i++) { for (int i = 0; i < max_order_ - order - 1; i++) {
ngram.push_back(START_TOKEN); ngram.push_back(START_TOKEN);
} }
break; break;
@ -193,19 +193,19 @@ void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary; fst::StdVectorFst dictionary;
// First reverse char_list so ints can be accessed by chars // First reverse char_list so ints can be accessed by chars
std::unordered_map<std::string, int> char_map; std::unordered_map<std::string, int> char_map;
for (unsigned int i = 0; i < _char_list.size(); i++) { for (size_t i = 0; i < char_list_.size(); i++) {
char_map[_char_list[i]] = i; char_map[char_list_[i]] = i;
} }
// For each unigram convert to ints and put in trie // For each unigram convert to ints and put in trie
int dict_size = 0; int dict_size = 0;
for (const auto& word : _vocabulary) { for (const auto& word : vocabulary_) {
bool added = add_word_to_dictionary( bool added = add_word_to_dictionary(
word, char_map, add_space, _SPACE_ID, &dictionary); word, char_map, add_space, SPACE_ID_, &dictionary);
dict_size += added ? 1 : 0; dict_size += added ? 1 : 0;
} }
_dict_size = dict_size; dict_size_ = dict_size;
/* Simplify FST /* Simplify FST

@ -18,7 +18,7 @@ const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>"; const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>"; const std::string END_TOKEN = "</s>";
// Implement a callback to retrieve string vocabulary. // Implement a callback to retrieve the dictionary of language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab { class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public: public:
RetriveStrEnumerateVocab() {} RetriveStrEnumerateVocab() {}
@ -50,13 +50,14 @@ public:
double get_sent_log_prob(const std::vector<std::string> &words); double get_sent_log_prob(const std::vector<std::string> &words);
size_t get_max_order() const { return _max_order; } // return the max order
size_t get_max_order() const { return max_order_; }
size_t get_dict_size() const { return _dict_size; } // return the dictionary size of language model
size_t get_dict_size() const { return dict_size_; }
bool is_char_map_empty() const { return _char_map.size() == 0; } // return true if the language model is character based
bool is_character_based() const { return is_character_based_; }
bool is_character_based() const { return _is_character_based; }
// reset params alpha & beta // reset params alpha & beta
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
@ -68,20 +69,23 @@ public:
// the vector of characters (character based lm) // the vector of characters (character based lm)
std::vector<std::string> split_labels(const std::vector<int> &labels); std::vector<std::string> split_labels(const std::vector<int> &labels);
// expose to decoder // language model weight
double alpha; double alpha;
// word insertion weight
double beta; double beta;
// fst dictionary // pointer to the dictionary of FST
void *dictionary; void *dictionary;
protected: protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void setup(const std::string &lm_path, void setup(const std::string &lm_path,
const std::vector<std::string> &vocab_list); const std::vector<std::string> &vocab_list);
// load language model from given path
void load_lm(const std::string &lm_path); void load_lm(const std::string &lm_path);
// fill dictionary for fst // fill dictionary for FST
void fill_dictionary(bool add_space); void fill_dictionary(bool add_space);
// set char map // set char map
@ -89,19 +93,20 @@ protected:
double get_log_prob(const std::vector<std::string> &words); double get_log_prob(const std::vector<std::string> &words);
// translate the vector in index to string
std::string vec2str(const std::vector<int> &input); std::string vec2str(const std::vector<int> &input);
private: private:
void *_language_model; void *language_model_;
bool _is_character_based; bool is_character_based_;
size_t _max_order; size_t max_order_;
size_t _dict_size; size_t dict_size_;
int _SPACE_ID; int SPACE_ID_;
std::vector<std::string> _char_list; std::vector<std::string> char_list_;
std::unordered_map<char, int> _char_map; std::unordered_map<char, int> char_map_;
std::vector<std::string> _vocabulary; std::vector<std::string> vocabulary_;
}; };
#endif // SCORER_H_ #endif // SCORER_H_

@ -39,8 +39,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary, vocabulary,
beam_size,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None): ext_scoring_func=None):
@ -50,10 +50,10 @@ def ctc_beam_search_decoder(probs_seq,
step, with each element being a list of normalized step, with each element being a list of normalized
probabilities over vocabulary and blank. probabilities over vocabulary and blank.
:type probs_seq: 2-D list :type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning, :param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning. default 1.0, no pruning.
:type cutoff_prob: float :type cutoff_prob: float
@ -69,14 +69,14 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability. results, in descending order of the probability.
:rtype: list :rtype: list
""" """
return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size, return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), vocabulary,
vocabulary, cutoff_prob, beam_size, cutoff_prob,
cutoff_top_n, ext_scoring_func) cutoff_top_n, ext_scoring_func)
def ctc_beam_search_decoder_batch(probs_split, def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary, vocabulary,
beam_size,
num_processes, num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
@ -86,10 +86,10 @@ def ctc_beam_search_decoder_batch(probs_split,
:param probs_seq: 3-D list with each element as an instance of 2-D list :param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder(). of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list :type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param num_processes: Number of parallel processes. :param num_processes: Number of parallel processes.
:type num_processes: int :type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning, :param cutoff_prob: Cutoff probability in vocabulary pruning,
@ -112,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split,
probs_split = [probs_seq.tolist() for probs_seq in probs_split] probs_split = [probs_seq.tolist() for probs_seq in probs_split]
return swig_decoders.ctc_beam_search_decoder_batch( return swig_decoders.ctc_beam_search_decoder_batch(
probs_split, beam_size, vocabulary, num_processes, cutoff_prob, probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func) cutoff_top_n, ext_scoring_func)

Loading…
Cancel
Save