#include #include #include #include #include #include "path_trie.h" #include "decoder_utils.h" PathTrie::PathTrie() { float lowest = -1.0*std::numeric_limits::max(); _log_prob_b_prev = lowest; _log_prob_nb_prev = lowest; _log_prob_b_cur = lowest; _log_prob_nb_cur = lowest; _score = lowest; _ROOT = -1; _character = _ROOT; _exists = true; _parent = nullptr; _dictionary = nullptr; _dictionary_state = 0; _has_dictionary = false; _matcher = nullptr; // finds arcs in FST } PathTrie::~PathTrie() { for (auto child : _children) { delete child.second; } } PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { auto child = _children.begin(); for (child = _children.begin(); child != _children.end(); ++child) { if (child->first == new_char) { break; } } if ( child != _children.end() ) { if (!child->second->_exists) { child->second->_exists = true; float lowest = -1.0*std::numeric_limits::max(); child->second->_log_prob_b_prev = lowest; child->second->_log_prob_nb_prev = lowest; child->second->_log_prob_b_cur = lowest; child->second->_log_prob_nb_cur = lowest; } return (child->second); } else { if (_has_dictionary) { _matcher->SetState(_dictionary_state); bool found = _matcher->Find(new_char); if (!found) { // Adding this character causes word outside dictionary auto FSTZERO = fst::TropicalWeight::Zero(); auto final_weight = _dictionary->Final(_dictionary_state); bool is_final = (final_weight != FSTZERO); if (is_final && reset) { _dictionary_state = _dictionary->Start(); } return nullptr; } else { PathTrie* new_path = new PathTrie; new_path->_character = new_char; new_path->_parent = this; new_path->_dictionary = _dictionary; new_path->_dictionary_state = _matcher->Value().nextstate; new_path->_has_dictionary = true; new_path->_matcher = _matcher; _children.push_back(std::make_pair(new_char, new_path)); return new_path; } } else { PathTrie* new_path = new PathTrie; new_path->_character = new_char; new_path->_parent = this; _children.push_back(std::make_pair(new_char, new_path)); return new_path; } } } PathTrie* PathTrie::get_path_vec(std::vector& output) { return get_path_vec(output, _ROOT); } PathTrie* PathTrie::get_path_vec(std::vector& output, int stop, size_t max_steps /*= std::numeric_limits::max() */) { if (_character == stop || _character == _ROOT || output.size() == max_steps) { std::reverse(output.begin(), output.end()); return this; } else { output.push_back(_character); return _parent->get_path_vec(output, stop, max_steps); } } void PathTrie::iterate_to_vec( std::vector& output) { if (_exists) { _log_prob_b_prev = _log_prob_b_cur; _log_prob_nb_prev = _log_prob_nb_cur; _log_prob_b_cur = -1.0 * std::numeric_limits::max(); _log_prob_nb_cur = -1.0 * std::numeric_limits::max(); _score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev); output.push_back(this); } for (auto child : _children) { child.second->iterate_to_vec(output); } } //------------------------------------------------------- // Effectively removes node //------------------------------------------------------- void PathTrie::remove() { _exists = false; if (_children.size() == 0) { auto child = _parent->_children.begin(); for (child = _parent->_children.begin(); child != _parent->_children.end(); ++child) { if (child->first == _character) { _parent->_children.erase(child); break; } } if ( _parent->_children.size() == 0 && !_parent->_exists ) { _parent->remove(); } delete this; } } void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { _dictionary = dictionary; _dictionary_state = dictionary->Start(); _has_dictionary = true; } using FSTMATCH = fst::SortedMatcher; void PathTrie::set_matcher(std::shared_ptr matcher) { _matcher = matcher; }