// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PATH_TRIE_H
#define PATH_TRIE_H

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "fst/fstlib.h"

/* Trie tree for prefix storing and manipulating, with a dictionary in
 * finite-state transducer for spelling correction.
 */
class PathTrie {
  public:
    PathTrie();
    ~PathTrie();

    // Get the node for the prefix obtained by appending `new_char` to the
    // prefix represented by this node.
    // NOTE(review): the meaning of `reset` is defined in the implementation
    // file — confirm there before relying on it.
    PathTrie* get_path_trie(int new_char, bool reset = true);

    // Get the prefix, as character indices, from the root to the current
    // node; the indices are written into `output`.
    // NOTE(review): the meaning of the returned pointer is defined in the
    // implementation file — confirm there.
    PathTrie* get_path_vec(std::vector<int>& output);

    // Get the prefix, as character indices, from some stop node to the
    // current node; the indices are written into `output`.
    // `stop` presumably is the character value at which the upward walk
    // ends, and `max_steps` bounds the number of nodes visited (unbounded by
    // default) — TODO confirm against the implementation.
    PathTrie* get_path_vec(
        std::vector<int>& output,
        int stop,
        size_t max_steps = std::numeric_limits<size_t>::max());

    // Update log probs and collect nodes of this subtree into `output`.
    // NOTE(review): exact traversal order / which nodes are collected is
    // defined in the implementation file.
    void iterate_to_vec(std::vector<PathTrie*>& output);

    // Set the dictionary FST used for spelling correction. The trie stores
    // the raw pointer only (see `dictionary_`), so the FST must outlive it.
    void set_dictionary(fst::StdVectorFst* dictionary);

    // Set the sorted matcher used to look up arcs in the dictionary FST.
    void set_matcher(
        std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher);

    // True when this node's character equals the ROOT_ sentinel value,
    // presumably meaning this node is the (empty-prefix) root.
    // Declared const: it only reads members.
    bool is_empty() const { return ROOT_ == character; }

    // remove current path from root
    void remove();

    // Per-prefix scores. The b / nb suffixes presumably denote blank and
    // non-blank ending log probabilities of CTC prefix beam search, and
    // prev / cur the previous and current time step — TODO confirm.
    float log_prob_b_prev;
    float log_prob_nb_prev;
    float log_prob_b_cur;
    float log_prob_nb_cur;
    float score;
    float approx_ctc;
    // Character index appended at this node.
    int character;
    // Node for the prefix one character shorter (non-owning as far as this
    // header shows).
    PathTrie* parent;

  private:
    // Sentinel character value compared against `character` in is_empty().
    int ROOT_;
    bool exists_;
    bool has_dictionary_;

    // (character, child node) pairs held as raw pointers.
    // NOTE(review): a destructor is user-declared but copy/move operations
    // are not, so the implicit copy operations copy these pointers shallowly.
    // If ~PathTrie() frees the children, copying a PathTrie would lead to a
    // double free — consider `= delete`-ing the copy operations once callers
    // are confirmed not to copy.
    std::vector<std::pair<int, PathTrie*>> children_;

    // pointer to dictionary of FST (not owned by this header's declarations)
    fst::StdVectorFst* dictionary_;
    // Current state in the dictionary FST for this prefix.
    fst::StdVectorFst::StateId dictionary_state_;
    // Sorted matcher used to find arcs in the dictionary FST.
    std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};

#endif  // PATH_TRIE_H