#ifndef SCORER_H_ #define SCORER_H_ #include #include #include #include #include "lm/enumerate_vocab.hh" #include "lm/word_index.hh" #include "lm/virtual_interface.hh" #include "util/string_piece.hh" #include "path_trie.h" const double OOV_SCOER = -1000.0; const std::string START_TOKEN = ""; const std::string UNK_TOKEN = ""; const std::string END_TOKEN = ""; // Implement a callback to retrive string vocabulary. class RetriveStrEnumerateVocab : public lm::EnumerateVocab { public: RetriveStrEnumerateVocab() {} void Add(lm::WordIndex index, const StringPiece& str) { vocabulary.push_back(std::string(str.data(), str.length())); } std::vector vocabulary; }; // External scorer to query languange score for n-gram or sentence. // Example: // Scorer scorer(alpha, beta, "path_of_language_model"); // scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); // scorer.get_log_cond_prob("this a sentence"); // scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); class Scorer{ public: Scorer(double alpha, double beta, const std::string& lm_path); ~Scorer(); double get_log_cond_prob(const std::vector& words); double get_sent_log_prob(const std::vector& words); size_t get_max_order() { return _max_order; } bool is_character_based() { return _is_character_based; } std::vector get_vocab() { return _vocabulary; } // word insertion term int word_count(std::string); // get the log cond prob of the last word double get_log_cond_prob(std::string); // reset params alpha & beta void reset_params(float alpha, float beta); // get the final score double get_score(std::string, bool log=false); // make ngram std::vector make_ngram(PathTrie* prefix); // fill dictionary for fst void fill_dictionary(bool add_space); // set char map void set_char_map(std::vector char_list); // expose to decoder double alpha; double beta; // fst dictionary void* _dictionary; protected: void load_LM(const char* filename); double get_log_prob(const std::vector& words); std::string vec2str(const std::vector &input); std::vector split_labels(const std::vector &labels); std::vector split_str(const std::string &s, const std::string &delim); private: void _init_char_list(); void _init_char_map(); void* _language_model; bool _is_character_based; size_t _max_order; unsigned int _SPACE; std::vector _char_list; std::unordered_map _char_map; std::vector _vocabulary; }; #endif // SCORER_H_