format variables' names & add more comments

pull/2/head
Yibing Liu 7 years ago
parent a24d0138d9
commit 3018dcb4d9

@@ -18,8 +18,8 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     const std::vector<std::vector<double>> &probs_seq,
+    const std::vector<std::string> &vocabulary,
     size_t beam_size,
-    std::vector<std::string> vocabulary,
     double cutoff_prob,
     size_t cutoff_top_n,
    Scorer *ext_scorer) {
@@ -36,8 +36,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   size_t blank_id = vocabulary.size();

   // assign space id
-  std::vector<std::string>::iterator it =
-      std::find(vocabulary.begin(), vocabulary.end(), " ");
+  auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
   int space_id = it - vocabulary.begin();
   // if no space in vocabulary
   if ((size_t)space_id >= vocabulary.size()) {
@@ -173,11 +172,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
 std::vector<std::vector<std::pair<double, std::string>>>
 ctc_beam_search_decoder_batch(
     const std::vector<std::vector<std::vector<double>>> &probs_split,
-    const size_t beam_size,
     const std::vector<std::string> &vocabulary,
-    const size_t num_processes,
-    const double cutoff_prob,
-    const size_t cutoff_top_n,
+    size_t beam_size,
+    size_t num_processes,
+    double cutoff_prob,
+    size_t cutoff_top_n,
     Scorer *ext_scorer) {
   VALID_CHECK_GT(num_processes, 0, "num_processes must be positive!");
   // thread pool
@@ -190,8 +189,8 @@ ctc_beam_search_decoder_batch(
   for (size_t i = 0; i < batch_size; ++i) {
     res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
                                   probs_split[i],
-                                  beam_size,
                                   vocabulary,
+                                  beam_size,
                                   cutoff_prob,
                                   cutoff_top_n,
                                   ext_scorer));
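
A minimal sketch (not part of the commit) of the new call order, assuming the function above is exposed through the project's decoder header (header name assumed); the probabilities and vocabulary are made up for illustration.

    #include <string>
    #include <vector>
    #include "ctc_beam_search_decoder.h"  // assumed header name

    int main() {
      // two time steps of probabilities over {"a", "b", blank}
      std::vector<std::vector<double>> probs_seq = {{0.6, 0.3, 0.1},
                                                    {0.2, 0.7, 0.1}};
      std::vector<std::string> vocabulary = {"a", "b"};
      // note the new order: vocabulary now precedes beam_size
      auto result = ctc_beam_search_decoder(probs_seq,
                                            vocabulary,
                                            /*beam_size=*/10,
                                            /*cutoff_prob=*/1.0,
                                            /*cutoff_top_n=*/40,
                                            /*ext_scorer=*/nullptr);
      // each entry pairs a score with a decoded transcript, best first
      return 0;
    }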

@@ -12,8 +12,8 @@
 * Parameters:
 *     probs_seq: 2-D vector in which each element is a vector of probabilities
 *                over the vocabulary at one time step.
- *     beam_size: The width of beam search.
 *     vocabulary: A vector of vocabulary tokens.
+ *     beam_size: The width of beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
@@ -25,8 +25,8 @@
 */
 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     const std::vector<std::vector<double>> &probs_seq,
+    const std::vector<std::string> &vocabulary,
     size_t beam_size,
-    std::vector<std::string> vocabulary,
     double cutoff_prob = 1.0,
     size_t cutoff_top_n = 40,
     Scorer *ext_scorer = nullptr);
@@ -36,9 +36,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
 * Parameters:
 *     probs_split: 3-D vector in which each element is a 2-D vector that can
 *                  be used by ctc_beam_search_decoder().
- *                .
- *     beam_size: The width of beam search.
 *     vocabulary: A vector of vocabulary tokens.
+ *     beam_size: The width of beam search.
 *     num_processes: Number of threads for beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
@@ -52,8 +51,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
 std::vector<std::vector<std::pair<double, std::string>>>
 ctc_beam_search_decoder_batch(
     const std::vector<std::vector<std::vector<double>>> &probs_split,
-    size_t beam_size,
     const std::vector<std::string> &vocabulary,
+    size_t beam_size,
     size_t num_processes,
     double cutoff_prob = 1.0,
     size_t cutoff_top_n = 40,
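
A companion sketch (not part of the commit) for the batch entry point, under the same assumed header; the thread count and all argument values are illustrative.

    #include <string>
    #include <vector>
    #include "ctc_beam_search_decoder.h"  // assumed header name

    void decode_batch(
        const std::vector<std::vector<std::vector<double>>> &probs_split,
        const std::vector<std::string> &vocabulary) {
      // one (score, transcript) list per utterance in probs_split;
      // 4 worker threads is a hypothetical choice
      auto results = ctc_beam_search_decoder_batch(probs_split,
                                                   vocabulary,
                                                   /*beam_size=*/10,
                                                   /*num_processes=*/4,
                                                   /*cutoff_prob=*/1.0,
                                                   /*cutoff_top_n=*/40,
                                                   /*ext_scorer=*/nullptr);
      (void)results;
    }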

@@ -15,32 +15,32 @@ PathTrie::PathTrie() {
   log_prob_nb_cur = -NUM_FLT_INF;
   score = -NUM_FLT_INF;

-  _ROOT = -1;
-  character = _ROOT;
-  _exists = true;
+  ROOT_ = -1;
+  character = ROOT_;
+  exists_ = true;
   parent = nullptr;
-  _dictionary = nullptr;
-  _dictionary_state = 0;
-  _has_dictionary = false;
-  _matcher = nullptr;
+  dictionary_ = nullptr;
+  dictionary_state_ = 0;
+  has_dictionary_ = false;
+  matcher_ = nullptr;
 }

 PathTrie::~PathTrie() {
-  for (auto child : _children) {
+  for (auto child : children_) {
     delete child.second;
   }
 }

 PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
-  auto child = _children.begin();
-  for (child = _children.begin(); child != _children.end(); ++child) {
+  auto child = children_.begin();
+  for (child = children_.begin(); child != children_.end(); ++child) {
     if (child->first == new_char) {
       break;
     }
   }
-  if (child != _children.end()) {
-    if (!child->second->_exists) {
-      child->second->_exists = true;
+  if (child != children_.end()) {
+    if (!child->second->exists_) {
+      child->second->exists_ = true;
       child->second->log_prob_b_prev = -NUM_FLT_INF;
       child->second->log_prob_nb_prev = -NUM_FLT_INF;
       child->second->log_prob_b_cur = -NUM_FLT_INF;
@@ -48,47 +48,47 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
     }
     return (child->second);
   } else {
-    if (_has_dictionary) {
-      _matcher->SetState(_dictionary_state);
-      bool found = _matcher->Find(new_char);
+    if (has_dictionary_) {
+      matcher_->SetState(dictionary_state_);
+      bool found = matcher_->Find(new_char);
       if (!found) {
         // Adding this character would lead to a word outside the dictionary
         auto FSTZERO = fst::TropicalWeight::Zero();
-        auto final_weight = _dictionary->Final(_dictionary_state);
+        auto final_weight = dictionary_->Final(dictionary_state_);
         bool is_final = (final_weight != FSTZERO);
         if (is_final && reset) {
-          _dictionary_state = _dictionary->Start();
+          dictionary_state_ = dictionary_->Start();
         }
         return nullptr;
       } else {
         PathTrie* new_path = new PathTrie;
         new_path->character = new_char;
         new_path->parent = this;
-        new_path->_dictionary = _dictionary;
-        new_path->_dictionary_state = _matcher->Value().nextstate;
-        new_path->_has_dictionary = true;
-        new_path->_matcher = _matcher;
-        _children.push_back(std::make_pair(new_char, new_path));
+        new_path->dictionary_ = dictionary_;
+        new_path->dictionary_state_ = matcher_->Value().nextstate;
+        new_path->has_dictionary_ = true;
+        new_path->matcher_ = matcher_;
+        children_.push_back(std::make_pair(new_char, new_path));
         return new_path;
       }
     } else {
       PathTrie* new_path = new PathTrie;
       new_path->character = new_char;
       new_path->parent = this;
-      _children.push_back(std::make_pair(new_char, new_path));
+      children_.push_back(std::make_pair(new_char, new_path));
       return new_path;
     }
   }
 }

 PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
-  return get_path_vec(output, _ROOT);
+  return get_path_vec(output, ROOT_);
 }

 PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                  int stop,
                                  size_t max_steps) {
-  if (character == stop || character == _ROOT || output.size() == max_steps) {
+  if (character == stop || character == ROOT_ || output.size() == max_steps) {
     std::reverse(output.begin(), output.end());
     return this;
   } else {
@@ -98,7 +98,7 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
 }

 void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
-  if (_exists) {
+  if (exists_) {
     log_prob_b_prev = log_prob_b_cur;
     log_prob_nb_prev = log_prob_nb_cur;
@@ -108,25 +108,25 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
     score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
     output.push_back(this);
   }

-  for (auto child : _children) {
+  for (auto child : children_) {
     child.second->iterate_to_vec(output);
   }
 }

 void PathTrie::remove() {
-  _exists = false;
+  exists_ = false;

-  if (_children.size() == 0) {
-    auto child = parent->_children.begin();
-    for (child = parent->_children.begin(); child != parent->_children.end();
+  if (children_.size() == 0) {
+    auto child = parent->children_.begin();
+    for (child = parent->children_.begin(); child != parent->children_.end();
          ++child) {
       if (child->first == character) {
-        parent->_children.erase(child);
+        parent->children_.erase(child);
         break;
       }
     }

-    if (parent->_children.size() == 0 && !parent->_exists) {
+    if (parent->children_.size() == 0 && !parent->exists_) {
       parent->remove();
     }
@@ -135,12 +135,12 @@ void PathTrie::remove() {
 }

 void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
-  _dictionary = dictionary;
-  _dictionary_state = dictionary->Start();
-  _has_dictionary = true;
+  dictionary_ = dictionary;
+  dictionary_state_ = dictionary->Start();
+  has_dictionary_ = true;
 }

 using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
 void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
-  _matcher = matcher;
+  matcher_ = matcher;
 }
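
A hypothetical sketch (not part of the commit) of how the decoder drives this trie, to make the renamed members concrete; the character ids are invented, and the reset flag of get_path_trie is passed explicitly since its default is not visible here.

    #include <vector>
    #include "path_trie.h"  // assumed header name

    void example() {
      PathTrie root;               // empty prefix: character == ROOT_ (-1)
      PathTrie* prefix = &root;
      for (int ch : {3, 7, 7}) {   // hypothetical character ids
        // with a dictionary set, nullptr means the extension was pruned;
        // without one, a child node is always created
        PathTrie* next = prefix->get_path_trie(ch, true);
        if (next != nullptr) prefix = next;
      }
      std::vector<int> ids;
      prefix->get_path_vec(ids);   // ids now holds {3, 7, 7}, root to leaf
    }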

@@ -36,7 +36,7 @@ public:
   void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);

-  bool is_empty() { return _ROOT == character; }
+  bool is_empty() { return ROOT_ == character; }

   // remove current path from root
   void remove();
@@ -51,17 +51,17 @@ public:
   PathTrie* parent;

 private:
-  int _ROOT;
-  bool _exists;
-  bool _has_dictionary;
+  int ROOT_;
+  bool exists_;
+  bool has_dictionary_;

-  std::vector<std::pair<int, PathTrie*>> _children;
+  std::vector<std::pair<int, PathTrie*>> children_;

   // pointer to dictionary of FST
-  fst::StdVectorFst* _dictionary;
-  fst::StdVectorFst::StateId _dictionary_state;
+  fst::StdVectorFst* dictionary_;
+  fst::StdVectorFst::StateId dictionary_state_;
   // matcher used for finding arcs in FST
-  std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> _matcher;
+  std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
 };

 #endif  // PATH_TRIE_H

@@ -19,19 +19,19 @@ Scorer::Scorer(double alpha,
                const std::vector<std::string>& vocab_list) {
   this->alpha = alpha;
   this->beta = beta;
-  _is_character_based = true;
-  _language_model = nullptr;
+  is_character_based_ = true;
+  language_model_ = nullptr;
   dictionary = nullptr;
-  _max_order = 0;
-  _dict_size = 0;
-  _SPACE_ID = -1;
+  max_order_ = 0;
+  dict_size_ = 0;
+  SPACE_ID_ = -1;

   setup(lm_path, vocab_list);
 }

 Scorer::~Scorer() {
-  if (_language_model != nullptr) {
-    delete static_cast<lm::base::Model*>(_language_model);
+  if (language_model_ != nullptr) {
+    delete static_cast<lm::base::Model*>(language_model_);
   }
   if (dictionary != nullptr) {
     delete static_cast<fst::StdVectorFst*>(dictionary);
@@ -57,20 +57,20 @@ void Scorer::load_lm(const std::string& lm_path) {
   RetriveStrEnumerateVocab enumerate;
   lm::ngram::Config config;
   config.enumerate_vocab = &enumerate;
-  _language_model = lm::ngram::LoadVirtual(filename, config);
-  _max_order = static_cast<lm::base::Model*>(_language_model)->Order();
-  _vocabulary = enumerate.vocabulary;
-  for (size_t i = 0; i < _vocabulary.size(); ++i) {
-    if (_is_character_based && _vocabulary[i] != UNK_TOKEN &&
-        _vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN &&
+  language_model_ = lm::ngram::LoadVirtual(filename, config);
+  max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
+  vocabulary_ = enumerate.vocabulary;
+  for (size_t i = 0; i < vocabulary_.size(); ++i) {
+    if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
+        vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
         get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
-      _is_character_based = false;
+      is_character_based_ = false;
     }
   }
 }

 double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
-  lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
+  lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
   double cond_prob;
   lm::ngram::State state, tmp_state, out_state;
   // avoid inserting <s> at the beginning
@@ -93,11 +93,11 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
 double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
   std::vector<std::string> sentence;
   if (words.size() == 0) {
-    for (size_t i = 0; i < _max_order; ++i) {
+    for (size_t i = 0; i < max_order_; ++i) {
       sentence.push_back(START_TOKEN);
     }
   } else {
-    for (size_t i = 0; i < _max_order - 1; ++i) {
+    for (size_t i = 0; i < max_order_ - 1; ++i) {
       sentence.push_back(START_TOKEN);
     }
     sentence.insert(sentence.end(), words.begin(), words.end());
@@ -107,11 +107,11 @@ double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
 }

 double Scorer::get_log_prob(const std::vector<std::string>& words) {
-  assert(words.size() > _max_order);
+  assert(words.size() > max_order_);
   double score = 0.0;
-  for (size_t i = 0; i < words.size() - _max_order + 1; ++i) {
+  for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
     std::vector<std::string> ngram(words.begin() + i,
-                                   words.begin() + i + _max_order);
+                                   words.begin() + i + max_order_);
     score += get_log_cond_prob(ngram);
   }
   return score;
@@ -125,7 +125,7 @@ void Scorer::reset_params(float alpha, float beta) {
 std::string Scorer::vec2str(const std::vector<int>& input) {
   std::string word;
   for (auto ind : input) {
-    word += _char_list[ind];
+    word += char_list_[ind];
   }
   return word;
 }
@@ -135,7 +135,7 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
   std::string s = vec2str(labels);
   std::vector<std::string> words;
-  if (_is_character_based) {
+  if (is_character_based_) {
     words = split_utf8_str(s);
   } else {
     words = split_str(s, " ");
@@ -144,15 +144,15 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
 }

 void Scorer::set_char_map(const std::vector<std::string>& char_list) {
-  _char_list = char_list;
-  _char_map.clear();
-  for (unsigned int i = 0; i < _char_list.size(); i++) {
-    if (_char_list[i] == " ") {
-      _SPACE_ID = i;
-      _char_map[' '] = i;
-    } else if (_char_list[i].size() == 1) {
-      _char_map[_char_list[i][0]] = i;
+  char_list_ = char_list;
+  char_map_.clear();
+  for (size_t i = 0; i < char_list_.size(); i++) {
+    if (char_list_[i] == " ") {
+      SPACE_ID_ = i;
+      char_map_[' '] = i;
+    } else if (char_list_[i].size() == 1) {
+      char_map_[char_list_[i][0]] = i;
     }
   }
 }
@@ -162,14 +162,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
   PathTrie* current_node = prefix;
   PathTrie* new_node = nullptr;

-  for (int order = 0; order < _max_order; order++) {
+  for (int order = 0; order < max_order_; order++) {
     std::vector<int> prefix_vec;

-    if (_is_character_based) {
-      new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1);
+    if (is_character_based_) {
+      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
       current_node = new_node;
     } else {
-      new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
+      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
       current_node = new_node->parent;  // Skipping spaces
     }
@@ -179,7 +179,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {

     if (new_node->character == -1) {
       // No more spaces, but still need order
-      for (int i = 0; i < _max_order - order - 1; i++) {
+      for (int i = 0; i < max_order_ - order - 1; i++) {
         ngram.push_back(START_TOKEN);
       }
       break;
@@ -193,19 +193,19 @@ void Scorer::fill_dictionary(bool add_space) {
   fst::StdVectorFst dictionary;
   // First reverse char_list so ints can be accessed by chars
   std::unordered_map<std::string, int> char_map;
-  for (unsigned int i = 0; i < _char_list.size(); i++) {
-    char_map[_char_list[i]] = i;
+  for (size_t i = 0; i < char_list_.size(); i++) {
+    char_map[char_list_[i]] = i;
   }

   // For each unigram convert to ints and put in trie
   int dict_size = 0;
-  for (const auto& word : _vocabulary) {
+  for (const auto& word : vocabulary_) {
     bool added = add_word_to_dictionary(
-        word, char_map, add_space, _SPACE_ID, &dictionary);
+        word, char_map, add_space, SPACE_ID_, &dictionary);
     dict_size += added ? 1 : 0;
   }
-  _dict_size = dict_size;
+  dict_size_ = dict_size;

   /* Simplify FST

@@ -18,7 +18,7 @@ const std::string START_TOKEN = "<s>";
 const std::string UNK_TOKEN = "<unk>";
 const std::string END_TOKEN = "</s>";

-// Implement a callback to retrieve string vocabulary.
+// Implement a callback to retrieve the dictionary of the language model.
 class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
 public:
   RetriveStrEnumerateVocab() {}
@@ -50,13 +50,14 @@ public:
   double get_sent_log_prob(const std::vector<std::string> &words);

-  size_t get_max_order() const { return _max_order; }
+  // return the max order
+  size_t get_max_order() const { return max_order_; }

-  size_t get_dict_size() const { return _dict_size; }
+  // return the dictionary size of the language model
+  size_t get_dict_size() const { return dict_size_; }

-  bool is_char_map_empty() const { return _char_map.size() == 0; }
+  bool is_char_map_empty() const { return char_map_.size() == 0; }

-  bool is_character_based() const { return _is_character_based; }
+  // return true if the language model is character based
+  bool is_character_based() const { return is_character_based_; }

   // reset params alpha & beta
   void reset_params(float alpha, float beta);
@@ -68,20 +69,23 @@ public:
   // the vector of characters (character based lm)
   std::vector<std::string> split_labels(const std::vector<int> &labels);

-  // expose to decoder
+  // language model weight
   double alpha;
+  // word insertion weight
   double beta;

-  // fst dictionary
+  // pointer to the dictionary of FST
   void *dictionary;

 protected:
+  // necessary setup: load language model, set char map, fill FST's dictionary
   void setup(const std::string &lm_path,
              const std::vector<std::string> &vocab_list);

+  // load language model from given path
   void load_lm(const std::string &lm_path);

-  // fill dictionary for fst
+  // fill dictionary for FST
   void fill_dictionary(bool add_space);

   // set char map
@@ -89,19 +93,20 @@ protected:
   double get_log_prob(const std::vector<std::string> &words);

   // translate a vector of indices to a string
   std::string vec2str(const std::vector<int> &input);

 private:
-  void *_language_model;
-  bool _is_character_based;
-  size_t _max_order;
-  size_t _dict_size;
+  void *language_model_;
+  bool is_character_based_;
+  size_t max_order_;
+  size_t dict_size_;

-  int _SPACE_ID;
-  std::vector<std::string> _char_list;
-  std::unordered_map<char, int> _char_map;
+  int SPACE_ID_;
+  std::vector<std::string> char_list_;
+  std::unordered_map<char, int> char_map_;

-  std::vector<std::string> _vocabulary;
+  std::vector<std::string> vocabulary_;
 };

 #endif  // SCORER_H_
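
A hypothetical sketch (not part of the commit) wiring a Scorer into the decoder, assuming the constructor takes (alpha, beta, lm_path, vocab_list) as the fragments above suggest; the LM path is a placeholder and the alpha/beta values are illustrative (their roles match the comments above: language model weight, word insertion weight).

    #include <string>
    #include <vector>
    #include "scorer.h"  // assumed header name

    void example(const std::vector<std::string> &vocab_list) {
      Scorer scorer(/*alpha=*/2.5, /*beta=*/0.3,
                    "path/to/lm.arpa",  // placeholder LM path
                    vocab_list);
      // log probability of a word sequence under the loaded LM
      double log_prob = scorer.get_sent_log_prob({"hello", "world"});
      (void)log_prob;
      // a pointer to such a Scorer is what ctc_beam_search_decoder
      // expects as its ext_scorer argument
    }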

@@ -39,8 +39,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):


 def ctc_beam_search_decoder(probs_seq,
-                            beam_size,
                             vocabulary,
+                            beam_size,
                             cutoff_prob=1.0,
                             cutoff_top_n=40,
                             ext_scoring_func=None):
@@ -50,10 +50,10 @@ def ctc_beam_search_decoder(probs_seq,
                       step, with each element being a list of normalized
                       probabilities over vocabulary and blank.
     :type probs_seq: 2-D list
-    :param beam_size: Width for beam search.
-    :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
+    :param beam_size: Width for beam search.
+    :type beam_size: int
     :param cutoff_prob: Cutoff probability in pruning,
                         default 1.0, no pruning.
     :type cutoff_prob: float
@@ -69,14 +69,14 @@ def ctc_beam_search_decoder(probs_seq,
                      results, in descending order of the probability.
     :rtype: list
     """
-    return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size,
-                                                 vocabulary, cutoff_prob,
+    return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), vocabulary,
+                                                 beam_size, cutoff_prob,
                                                  cutoff_top_n, ext_scoring_func)


 def ctc_beam_search_decoder_batch(probs_split,
-                                  beam_size,
                                   vocabulary,
+                                  beam_size,
                                   num_processes,
                                   cutoff_prob=1.0,
                                   cutoff_top_n=40,
@@ -86,10 +86,10 @@ def ctc_beam_search_decoder_batch(probs_split,
     :param probs_split: 3-D list with each element as an instance of 2-D list
                         of probabilities used by ctc_beam_search_decoder().
     :type probs_split: 3-D list
-    :param beam_size: Width for beam search.
-    :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
+    :param beam_size: Width for beam search.
+    :type beam_size: int
     :param num_processes: Number of parallel processes.
     :type num_processes: int
     :param cutoff_prob: Cutoff probability in vocabulary pruning,
@@ -112,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split,
     probs_split = [probs_seq.tolist() for probs_seq in probs_split]

     return swig_decoders.ctc_beam_search_decoder_batch(
-        probs_split, beam_size, vocabulary, num_processes, cutoff_prob,
+        probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
         cutoff_top_n, ext_scoring_func)
