diff --git a/decoders/swig/path_trie.cpp b/decoders/swig/path_trie.cpp index 40d90970..152efa82 100644 --- a/decoders/swig/path_trie.cpp +++ b/decoders/swig/path_trie.cpp @@ -52,7 +52,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { } else { if (has_dictionary_) { matcher_->SetState(dictionary_state_); - bool found = matcher_->Find(new_char); + bool found = matcher_->Find(new_char + 1); if (!found) { // Adding this character causes word outside dictionary auto FSTZERO = fst::TropicalWeight::Zero(); diff --git a/decoders/swig/scorer.cpp b/decoders/swig/scorer.cpp index 686c67c7..27b61cd0 100644 --- a/decoders/swig/scorer.cpp +++ b/decoders/swig/scorer.cpp @@ -149,13 +149,15 @@ void Scorer::set_char_map(const std::vector& char_list) { char_list_ = char_list; char_map_.clear(); + // Set the char map for the FST for spelling correction for (size_t i = 0; i < char_list_.size(); i++) { if (char_list_[i] == " ") { SPACE_ID_ = i; - char_map_[' '] = i; - } else if (char_list_[i].size() == 1) { - char_map_[char_list_[i][0]] = i; } + // The initial state of FST is state 0, hence the index of chars in + // the FST should start from 1 to avoid the conflict with the initial + // state, otherwise wrong decoding results would be given. + char_map_[char_list_[i]] = i + 1; } } @@ -193,17 +195,11 @@ std::vector Scorer::make_ngram(PathTrie* prefix) { void Scorer::fill_dictionary(bool add_space) { fst::StdVectorFst dictionary; - // First reverse char_list so ints can be accessed by chars - std::unordered_map char_map; - for (size_t i = 0; i < char_list_.size(); i++) { - char_map[char_list_[i]] = i; - } - // For each unigram convert to ints and put in trie int dict_size = 0; for (const auto& word : vocabulary_) { bool added = add_word_to_dictionary( - word, char_map, add_space, SPACE_ID_, &dictionary); + word, char_map_, add_space, SPACE_ID_ + 1, &dictionary); dict_size += added ? 1 : 0; } diff --git a/decoders/swig/scorer.h b/decoders/swig/scorer.h index 61836463..5ebc719c 100644 --- a/decoders/swig/scorer.h +++ b/decoders/swig/scorer.h @@ -104,7 +104,7 @@ private: int SPACE_ID_; std::vector char_list_; - std::unordered_map char_map_; + std::unordered_map char_map_; std::vector vocabulary_; }; diff --git a/decoders/swig/setup.py b/decoders/swig/setup.py index b6bc0ca0..a4bb2e9d 100644 --- a/decoders/swig/setup.py +++ b/decoders/swig/setup.py @@ -113,7 +113,7 @@ decoders_module = [ setup( name='swig_decoders', - version='1.0', + version='1.1', description="""CTC decoders""", ext_modules=decoders_module, py_modules=['swig_decoders'], ) diff --git a/setup.sh b/setup.sh index 7c40415d..ec5e47ec 100644 --- a/setup.sh +++ b/setup.sh @@ -27,7 +27,7 @@ if [ $? != 0 ]; then fi # install decoders -python -c "import swig_decoders" +python -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")" if [ $? != 0 ]; then cd decoders/swig > /dev/null sh setup.sh