refactor raw ctc decoder into ctcdecoder

pull/882/head
Hui Zhang 3 years ago
parent 4f99dca88e
commit 69bd17dcb2

@ -1,3 +0,0 @@
# Reference
* [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
* [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/pdf/1408.2873.pdf)

@ -1,13 +1 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ctcdecoder import swig_wrapper

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,243 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ctc_beam_search_decoder.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <utility>
#include "ThreadPool.h"
#include "fst/fstlib.h"
#include "decoder_utils.h"
#include "path_trie.h"
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
// vocabulary.size() + 1,
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
}
} // end of loop over time
// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
}
}
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
return get_beam_search_result(prefixes, vocabulary, beam_size);
}
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(num_processes);
// number of samples
size_t batch_size = probs_split.size();
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i],
vocabulary,
beam_size,
cutoff_prob,
cutoff_top_n,
ext_scorer,
blank_id));
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}

@ -0,0 +1,110 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <string>
#include <utility>
#include "fst/log.h"
#include "path_trie.h"
const std::string kSPACE = "<space>";
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// inline function for validation check
inline void check(
bool x, const char *expr, const char *file, int line, const char *err) {
if (!x) {
std::cout << "[" << file << ":" << line << "] ";
LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
}
}
#define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.first > b.first;
}
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.second > b.second;
}
// Return the sum of two probabilities in log scale
template <typename T>
T log_sum_exp(const T &x, const T &y) {
static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}
// Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
double cutoff_prob,
size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size);
// Functor for prefix comparsion
bool prefix_compare(const PathTrie *x, const PathTrie *y);
/* Get length of utf8 encoding string
* See: http://stackoverflow.com/a/4063229
*/
size_t get_utf8_str_len(const std::string &str);
/* Split a string into a list of strings on a given string
* delimiter. NB: delimiters on beginning / end of string are
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
*/
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
/* Splits string into vector of strings representing
* UTF-8 characters (not same as chars)
*/
std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary);
// Add a word in string to dictionary
bool add_word_to_dictionary(
const std::string &word,
const std::unordered_map<std::string, int> &char_map,
bool add_space,
int SPACE_ID,
fst::StdVectorFst *dictionary);
#endif // DECODER_UTILS_H

@ -0,0 +1,244 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "scorer.h"
#include <unistd.h>
#include <iostream>
#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using namespace lm::ngram;
Scorer::Scorer(double alpha,
double beta,
const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
this->alpha = alpha;
this->beta = beta;
dictionary = nullptr;
is_character_based_ = true;
language_model_ = nullptr;
max_order_ = 0;
dict_size_ = 0;
SPACE_ID_ = -1;
setup(lm_path, vocab_list);
}
Scorer::~Scorer() {
if (language_model_ != nullptr) {
delete static_cast<lm::base::Model*>(language_model_);
}
if (dictionary != nullptr) {
delete static_cast<fst::StdVectorFst*>(dictionary);
}
}
void Scorer::setup(const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
// load language model
load_lm(lm_path);
// set char map for scorer
set_char_map(vocab_list);
// fill the dictionary for FST
if (!is_character_based()) {
fill_dictionary(true);
}
}
void Scorer::load_lm(const std::string& lm_path) {
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config;
config.enumerate_vocab = &enumerate;
language_model_ = lm::ngram::LoadVirtual(filename, config);
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
vocabulary_ = enumerate.vocabulary;
for (size_t i = 0; i < vocabulary_.size(); ++i) {
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
is_character_based_ = false;
}
}
}
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
double cond_prob;
lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin
model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV
if (word_index == 0) {
return OOV_SCORE;
}
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
}
// return log10 prob
return cond_prob;
}
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence;
if (words.size() == 0) {
for (size_t i = 0; i < max_order_; ++i) {
sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < max_order_ - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
}
double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > max_order_);
double score = 0.0;
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + max_order_);
score += get_log_cond_prob(ngram);
}
return score;
}
void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha;
this->beta = beta;
}
std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word;
for (auto ind : input) {
word += char_list_[ind];
}
return word;
}
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
if (labels.empty()) return {};
std::string s = vec2str(labels);
std::vector<std::string> words;
if (is_character_based_) {
words = split_utf8_str(s);
} else {
words = split_str(s, " ");
}
return words;
}
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
char_list_ = char_list;
char_map_.clear();
// Set the char map for the FST for spelling correction
for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == kSPACE) {
SPACE_ID_ = i;
}
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_[char_list_[i]] = i + 1;
}
}
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
for (int order = 0; order < max_order_; order++) {
std::vector<int> prefix_vec;
if (is_character_based_) {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
current_node = new_node;
} else {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
current_node = new_node->parent; // Skipping spaces
}
// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
if (new_node->character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order_ - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
return ngram;
}
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary;
// For each unigram convert to ints and put in trie
int dict_size = 0;
for (const auto& word : vocabulary_) {
bool added = add_word_to_dictionary(
word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
dict_size += added ? 1 : 0;
}
dict_size_ = dict_size;
/* Simplify FST
* This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST determinisitc, but
* can greatly increase the size of the FST
*/
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state)
*/
fst::Determinize(dictionary, new_dict);
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
*/
fst::Minimize(new_dict);
this->dictionary = new_dict;
}

@ -0,0 +1,24 @@
#!/usr/bin/env bash
if [ ! -d kenlm ]; then
git clone https://github.com/kpu/kenlm.git
cd kenlm/
git checkout df2d717e95183f79a90b2fa6e4307083a351ca6a
cd ..
echo -e "\n"
fi
if [ ! -d openfst-1.6.3 ]; then
echo "Download and extract openfst ..."
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
echo -e "\n"
fi
if [ ! -d ThreadPool ]; then
git clone https://github.com/progschj/ThreadPool.git
echo -e "\n"
fi
echo "Install decoders ..."
python3 setup.py install --num_processes 4

@ -32,7 +32,7 @@ except Exception as e:
__all__ = ['CTCDecoder']
class CTCDecoder(nn.Layer):
class CTCDecoderBase(nn.Layer):
def __init__(self,
odim,
enc_n_units,
@ -65,9 +65,6 @@ class CTCDecoder(nn.Layer):
batch_average=batch_average,
grad_norm_type=grad_norm_type)
# CTCDecoder LM Score handle
self._ext_scorer = None
def forward(self, hs_pad, hlens, ys_pad, ys_lens):
"""Calculate CTC loss.
@ -126,6 +123,13 @@ class CTCDecoder(nn.Layer):
"""
return ctc_utils.forced_align(ctc_probs, y, blank_id)
class CTCDecoder(CTCDecoderBase):
def __init__(self,*args, **kwargs):
super().__init__(*args, **kwargs)
# CTCDecoder LM Score handle
self._ext_scorer = None
def _decode_batch_greedy(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists

@ -69,8 +69,8 @@ fi
# install decoders
python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"
if [ $? != 0 ]; then
cd deepspeech/decoders/swig > /dev/null
sh setup.sh
cd deepspeech/decoders/ctcdecoder/swig > /dev/null
bash setup.sh
cd - > /dev/null
fi
python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"

Loading…
Cancel
Save