Merge pull request #882 from PaddlePaddle/decoder
refactor raw ctc decoder into ctcdecoder, new join ctc/att decoderpull/888/head
commit
60e9790610
@ -1,3 +0,0 @@
|
|||||||
# Reference
|
|
||||||
* [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
|
|
||||||
* [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/pdf/1408.2873.pdf)
|
|
@ -0,0 +1,13 @@
|
|||||||
|
# Decoders
|
||||||
|
|
||||||
|
## Reference
|
||||||
|
### CTC Prefix Beam Search
|
||||||
|
* [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
|
||||||
|
* [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/pdf/1408.2873.pdf)
|
||||||
|
|
||||||
|
### CTC Prefix Score & Joint CTC/ATT One-pass Decoding
|
||||||
|
* [Hybrid CTC/Attention Architecture for End-to-End Speech Recognition](http://www.ifp.illinois.edu/speech/speech_web_lg/slides/2019/watanabe_hybridCTCAttention_2017.pdf)
|
||||||
|
* [Vectorized Beam Search for CTC-Attention-based Speech Recognition](https://www.isca-speech.org/archive/pdfs/interspeech_2019/seki19b_interspeech.pdf)
|
||||||
|
|
||||||
|
### Streaming Joint CTC/ATT Beam Search
|
||||||
|
* [STREAMING TRANSFORMER ASR WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH](https://arxiv.org/abs/2006.14941)
|
@ -1,13 +1 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
from .ctcdecoder import swig_wrapper
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,243 @@
|
|||||||
|
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "ctc_beam_search_decoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <iostream>
|
||||||
|
#include <limits>
|
||||||
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "ThreadPool.h"
|
||||||
|
#include "fst/fstlib.h"
|
||||||
|
|
||||||
|
#include "decoder_utils.h"
|
||||||
|
#include "path_trie.h"
|
||||||
|
|
||||||
|
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||||
|
|
||||||
|
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
||||||
|
const std::vector<std::vector<double>> &probs_seq,
|
||||||
|
const std::vector<std::string> &vocabulary,
|
||||||
|
size_t beam_size,
|
||||||
|
double cutoff_prob,
|
||||||
|
size_t cutoff_top_n,
|
||||||
|
Scorer *ext_scorer,
|
||||||
|
size_t blank_id) {
|
||||||
|
// dimension check
|
||||||
|
size_t num_time_steps = probs_seq.size();
|
||||||
|
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||||
|
VALID_CHECK_EQ(probs_seq[i].size(),
|
||||||
|
// vocabulary.size() + 1,
|
||||||
|
vocabulary.size(),
|
||||||
|
"The shape of probs_seq does not match with "
|
||||||
|
"the shape of the vocabulary");
|
||||||
|
}
|
||||||
|
// assign space id
|
||||||
|
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
|
||||||
|
int space_id = it - vocabulary.begin();
|
||||||
|
// if no space in vocabulary
|
||||||
|
if ((size_t)space_id >= vocabulary.size()) {
|
||||||
|
space_id = -2;
|
||||||
|
}
|
||||||
|
// init prefixes' root
|
||||||
|
PathTrie root;
|
||||||
|
root.score = root.log_prob_b_prev = 0.0;
|
||||||
|
std::vector<PathTrie *> prefixes;
|
||||||
|
prefixes.push_back(&root);
|
||||||
|
|
||||||
|
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||||
|
auto fst_dict =
|
||||||
|
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
|
||||||
|
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
|
||||||
|
root.set_dictionary(dict_ptr);
|
||||||
|
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
||||||
|
root.set_matcher(matcher);
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefix search over time
|
||||||
|
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
|
||||||
|
auto &prob = probs_seq[time_step];
|
||||||
|
|
||||||
|
float min_cutoff = -NUM_FLT_INF;
|
||||||
|
bool full_beam = false;
|
||||||
|
if (ext_scorer != nullptr) {
|
||||||
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
||||||
|
std::sort(prefixes.begin(),
|
||||||
|
prefixes.begin() + num_prefixes,
|
||||||
|
prefix_compare);
|
||||||
|
min_cutoff = prefixes[num_prefixes - 1]->score +
|
||||||
|
std::log(prob[blank_id]) -
|
||||||
|
std::max(0.0, ext_scorer->beta);
|
||||||
|
full_beam = (num_prefixes == beam_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<size_t, float>> log_prob_idx =
|
||||||
|
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
|
||||||
|
// loop over chars
|
||||||
|
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
||||||
|
auto c = log_prob_idx[index].first;
|
||||||
|
auto log_prob_c = log_prob_idx[index].second;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
|
||||||
|
auto prefix = prefixes[i];
|
||||||
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// blank
|
||||||
|
if (c == blank_id) {
|
||||||
|
prefix->log_prob_b_cur = log_sum_exp(
|
||||||
|
prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// repeated character
|
||||||
|
if (c == prefix->character) {
|
||||||
|
prefix->log_prob_nb_cur =
|
||||||
|
log_sum_exp(prefix->log_prob_nb_cur,
|
||||||
|
log_prob_c + prefix->log_prob_nb_prev);
|
||||||
|
}
|
||||||
|
// get new prefix
|
||||||
|
auto prefix_new = prefix->get_path_trie(c);
|
||||||
|
|
||||||
|
if (prefix_new != nullptr) {
|
||||||
|
float log_p = -NUM_FLT_INF;
|
||||||
|
|
||||||
|
if (c == prefix->character &&
|
||||||
|
prefix->log_prob_b_prev > -NUM_FLT_INF) {
|
||||||
|
log_p = log_prob_c + prefix->log_prob_b_prev;
|
||||||
|
} else if (c != prefix->character) {
|
||||||
|
log_p = log_prob_c + prefix->score;
|
||||||
|
}
|
||||||
|
|
||||||
|
// language model scoring
|
||||||
|
if (ext_scorer != nullptr &&
|
||||||
|
(c == space_id || ext_scorer->is_character_based())) {
|
||||||
|
PathTrie *prefix_to_score = nullptr;
|
||||||
|
// skip scoring the space
|
||||||
|
if (ext_scorer->is_character_based()) {
|
||||||
|
prefix_to_score = prefix_new;
|
||||||
|
} else {
|
||||||
|
prefix_to_score = prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
float score = 0.0;
|
||||||
|
std::vector<std::string> ngram;
|
||||||
|
ngram = ext_scorer->make_ngram(prefix_to_score);
|
||||||
|
score = ext_scorer->get_log_cond_prob(ngram) *
|
||||||
|
ext_scorer->alpha;
|
||||||
|
log_p += score;
|
||||||
|
log_p += ext_scorer->beta;
|
||||||
|
}
|
||||||
|
prefix_new->log_prob_nb_cur =
|
||||||
|
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||||
|
}
|
||||||
|
} // end of loop over prefix
|
||||||
|
} // end of loop over vocabulary
|
||||||
|
|
||||||
|
|
||||||
|
prefixes.clear();
|
||||||
|
// update log probs
|
||||||
|
root.iterate_to_vec(prefixes);
|
||||||
|
|
||||||
|
// only preserve top beam_size prefixes
|
||||||
|
if (prefixes.size() >= beam_size) {
|
||||||
|
std::nth_element(prefixes.begin(),
|
||||||
|
prefixes.begin() + beam_size,
|
||||||
|
prefixes.end(),
|
||||||
|
prefix_compare);
|
||||||
|
for (size_t i = beam_size; i < prefixes.size(); ++i) {
|
||||||
|
prefixes[i]->remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // end of loop over time
|
||||||
|
|
||||||
|
// score the last word of each prefix that doesn't end with space
|
||||||
|
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||||
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||||
|
auto prefix = prefixes[i];
|
||||||
|
if (!prefix->is_empty() && prefix->character != space_id) {
|
||||||
|
float score = 0.0;
|
||||||
|
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
|
||||||
|
score =
|
||||||
|
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
||||||
|
score += ext_scorer->beta;
|
||||||
|
prefix->score += score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
||||||
|
std::sort(
|
||||||
|
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
||||||
|
|
||||||
|
// compute aproximate ctc score as the return score, without affecting the
|
||||||
|
// return order of decoding result. To delete when decoder gets stable.
|
||||||
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||||
|
double approx_ctc = prefixes[i]->score;
|
||||||
|
if (ext_scorer != nullptr) {
|
||||||
|
std::vector<int> output;
|
||||||
|
prefixes[i]->get_path_vec(output);
|
||||||
|
auto prefix_length = output.size();
|
||||||
|
auto words = ext_scorer->split_labels(output);
|
||||||
|
// remove word insert
|
||||||
|
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
|
||||||
|
// remove language model weight:
|
||||||
|
approx_ctc -=
|
||||||
|
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
|
||||||
|
}
|
||||||
|
prefixes[i]->approx_ctc = approx_ctc;
|
||||||
|
}
|
||||||
|
|
||||||
|
return get_beam_search_result(prefixes, vocabulary, beam_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<std::vector<std::pair<double, std::string>>>
|
||||||
|
ctc_beam_search_decoder_batch(
|
||||||
|
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||||
|
const std::vector<std::string> &vocabulary,
|
||||||
|
size_t beam_size,
|
||||||
|
size_t num_processes,
|
||||||
|
double cutoff_prob,
|
||||||
|
size_t cutoff_top_n,
|
||||||
|
Scorer *ext_scorer,
|
||||||
|
size_t blank_id) {
|
||||||
|
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||||
|
// thread pool
|
||||||
|
ThreadPool pool(num_processes);
|
||||||
|
// number of samples
|
||||||
|
size_t batch_size = probs_split.size();
|
||||||
|
|
||||||
|
// enqueue the tasks of decoding
|
||||||
|
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
||||||
|
for (size_t i = 0; i < batch_size; ++i) {
|
||||||
|
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
|
||||||
|
probs_split[i],
|
||||||
|
vocabulary,
|
||||||
|
beam_size,
|
||||||
|
cutoff_prob,
|
||||||
|
cutoff_top_n,
|
||||||
|
ext_scorer,
|
||||||
|
blank_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
// get decoding results
|
||||||
|
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
|
||||||
|
for (size_t i = 0; i < batch_size; ++i) {
|
||||||
|
batch_results.emplace_back(res[i].get());
|
||||||
|
}
|
||||||
|
return batch_results;
|
||||||
|
}
|
@ -0,0 +1,110 @@
|
|||||||
|
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef DECODER_UTILS_H_
|
||||||
|
#define DECODER_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include "fst/log.h"
|
||||||
|
#include "path_trie.h"
|
||||||
|
|
||||||
|
// Vocabulary token that stands for the word separator.
const std::string kSPACE("<space>");
// Largest finite float value.
constexpr float NUM_FLT_INF = std::numeric_limits<float>::max();
// Smallest positive normalized float value (not the most negative float).
constexpr float NUM_FLT_MIN = std::numeric_limits<float>::min();
|
||||||
|
|
||||||
|
// inline function for validation check
|
||||||
|
inline void check(
|
||||||
|
bool x, const char *expr, const char *file, int line, const char *err) {
|
||||||
|
if (!x) {
|
||||||
|
std::cout << "[" << file << ":" << line << "] ";
|
||||||
|
LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define VALID_CHECK(x, info) \
|
||||||
|
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
|
||||||
|
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
|
||||||
|
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
|
||||||
|
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
|
||||||
|
|
||||||
|
|
||||||
|
// Comparator ordering pairs by their first member, largest first
// (returns true when a.first is greater than b.first).
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
                         const std::pair<T1, T2> &b) {
    const T1 &lhs_key = a.first;
    const T1 &rhs_key = b.first;
    return lhs_key > rhs_key;
}
|
||||||
|
|
||||||
|
// Comparator ordering pairs by their second member, largest first
// (returns true when a.second is greater than b.second).
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
                          const std::pair<T1, T2> &b) {
    const T2 &lhs_value = a.second;
    const T2 &rhs_value = b.second;
    return lhs_value > rhs_value;
}
|
||||||
|
|
||||||
|
// Return log(exp(x) + exp(y)), i.e. the sum of two probabilities that are
// given in log scale. An argument at or below -numeric_limits<T>::max() is
// treated as log(0) and the other argument is returned unchanged.
template <typename T>
T log_sum_exp(const T &x, const T &y) {
    const T log_zero = -std::numeric_limits<T>::max();
    if (x <= log_zero) {
        return y;
    }
    if (y <= log_zero) {
        return x;
    }
    // Shift by the larger operand before exponentiating.
    const T pivot = std::max(x, y);
    return std::log(std::exp(x - pivot) + std::exp(y - pivot)) + pivot;
}
|
||||||
|
|
||||||
|
// Get pruned probability vector for each time step's beam search
|
||||||
|
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||||
|
const std::vector<double> &prob_step,
|
||||||
|
double cutoff_prob,
|
||||||
|
size_t cutoff_top_n);
|
||||||
|
|
||||||
|
// Get beam search result from prefixes in trie tree
|
||||||
|
std::vector<std::pair<double, std::string>> get_beam_search_result(
|
||||||
|
const std::vector<PathTrie *> &prefixes,
|
||||||
|
const std::vector<std::string> &vocabulary,
|
||||||
|
size_t beam_size);
|
||||||
|
|
||||||
|
// Functor for prefix comparison
|
||||||
|
bool prefix_compare(const PathTrie *x, const PathTrie *y);
|
||||||
|
|
||||||
|
/* Get length of utf8 encoding string
|
||||||
|
* See: http://stackoverflow.com/a/4063229
|
||||||
|
*/
|
||||||
|
size_t get_utf8_str_len(const std::string &str);
|
||||||
|
|
||||||
|
/* Split a string into a list of strings on a given string
|
||||||
|
* delimiter. NB: delimiters on beginning / end of string are
|
||||||
|
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
|
||||||
|
*/
|
||||||
|
std::vector<std::string> split_str(const std::string &s,
|
||||||
|
const std::string &delim);
|
||||||
|
|
||||||
|
/* Splits string into vector of strings representing
|
||||||
|
* UTF-8 characters (not same as chars)
|
||||||
|
*/
|
||||||
|
std::vector<std::string> split_utf8_str(const std::string &str);
|
||||||
|
|
||||||
|
// Add a word in index to the dictionary of fst
|
||||||
|
void add_word_to_fst(const std::vector<int> &word,
|
||||||
|
fst::StdVectorFst *dictionary);
|
||||||
|
|
||||||
|
// Add a word in string to dictionary
|
||||||
|
bool add_word_to_dictionary(
|
||||||
|
const std::string &word,
|
||||||
|
const std::unordered_map<std::string, int> &char_map,
|
||||||
|
bool add_space,
|
||||||
|
int SPACE_ID,
|
||||||
|
fst::StdVectorFst *dictionary);
|
||||||
|
#endif // DECODER_UTILS_H_
|
@ -0,0 +1,244 @@
|
|||||||
|
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "scorer.h"
|
||||||
|
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "lm/config.hh"
|
||||||
|
#include "lm/model.hh"
|
||||||
|
#include "lm/state.hh"
|
||||||
|
#include "util/string_piece.hh"
|
||||||
|
#include "util/tokenize_piece.hh"
|
||||||
|
|
||||||
|
#include "decoder_utils.h"
|
||||||
|
|
||||||
|
using namespace lm::ngram;
|
||||||
|
|
||||||
|
// Build a scorer: store the weights alpha and beta, reset all derived state
// to its empty defaults, then load the language model at lm_path and
// register vocab_list through setup().
Scorer::Scorer(double alpha,
               double beta,
               const std::string& lm_path,
               const std::vector<std::string>& vocab_list) {
    // defaults for everything setup() is going to fill in
    language_model_ = nullptr;
    dictionary = nullptr;
    is_character_based_ = true;
    max_order_ = 0;
    dict_size_ = 0;
    SPACE_ID_ = -1;

    // weights supplied by the caller
    this->alpha = alpha;
    this->beta = beta;

    setup(lm_path, vocab_list);
}
|
||||||
|
|
||||||
|
// Release the language model and the dictionary FST. Both members are cast
// back to their concrete types before deletion; deleting a null pointer is a
// no-op, so no explicit null checks are needed.
Scorer::~Scorer() {
    delete static_cast<lm::base::Model*>(language_model_);
    delete static_cast<fst::StdVectorFst*>(dictionary);
}
|
||||||
|
|
||||||
|
void Scorer::setup(const std::string& lm_path,
|
||||||
|
const std::vector<std::string>& vocab_list) {
|
||||||
|
// load language model
|
||||||
|
load_lm(lm_path);
|
||||||
|
// set char map for scorer
|
||||||
|
set_char_map(vocab_list);
|
||||||
|
// fill the dictionary for FST
|
||||||
|
if (!is_character_based()) {
|
||||||
|
fill_dictionary(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scorer::load_lm(const std::string& lm_path) {
|
||||||
|
const char* filename = lm_path.c_str();
|
||||||
|
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
|
||||||
|
|
||||||
|
RetriveStrEnumerateVocab enumerate;
|
||||||
|
lm::ngram::Config config;
|
||||||
|
config.enumerate_vocab = &enumerate;
|
||||||
|
language_model_ = lm::ngram::LoadVirtual(filename, config);
|
||||||
|
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
|
||||||
|
vocabulary_ = enumerate.vocabulary;
|
||||||
|
for (size_t i = 0; i < vocabulary_.size(); ++i) {
|
||||||
|
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
|
||||||
|
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
|
||||||
|
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
|
||||||
|
is_character_based_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
|
||||||
|
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
|
||||||
|
double cond_prob;
|
||||||
|
lm::ngram::State state, tmp_state, out_state;
|
||||||
|
// avoid to inserting <s> in begin
|
||||||
|
model->NullContextWrite(&state);
|
||||||
|
for (size_t i = 0; i < words.size(); ++i) {
|
||||||
|
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
|
||||||
|
// encounter OOV
|
||||||
|
if (word_index == 0) {
|
||||||
|
return OOV_SCORE;
|
||||||
|
}
|
||||||
|
cond_prob = model->BaseScore(&state, word_index, &out_state);
|
||||||
|
tmp_state = state;
|
||||||
|
state = out_state;
|
||||||
|
out_state = tmp_state;
|
||||||
|
}
|
||||||
|
// return log10 prob
|
||||||
|
return cond_prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
|
||||||
|
std::vector<std::string> sentence;
|
||||||
|
if (words.size() == 0) {
|
||||||
|
for (size_t i = 0; i < max_order_; ++i) {
|
||||||
|
sentence.push_back(START_TOKEN);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < max_order_ - 1; ++i) {
|
||||||
|
sentence.push_back(START_TOKEN);
|
||||||
|
}
|
||||||
|
sentence.insert(sentence.end(), words.begin(), words.end());
|
||||||
|
}
|
||||||
|
sentence.push_back(END_TOKEN);
|
||||||
|
return get_log_prob(sentence);
|
||||||
|
}
|
||||||
|
|
||||||
|
double Scorer::get_log_prob(const std::vector<std::string>& words) {
|
||||||
|
assert(words.size() > max_order_);
|
||||||
|
double score = 0.0;
|
||||||
|
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
|
||||||
|
std::vector<std::string> ngram(words.begin() + i,
|
||||||
|
words.begin() + i + max_order_);
|
||||||
|
score += get_log_cond_prob(ngram);
|
||||||
|
}
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace the two scorer weights alpha and beta with new values.
void Scorer::reset_params(float alpha, float beta) {
    this->beta = beta;
    this->alpha = alpha;
}
|
||||||
|
|
||||||
|
std::string Scorer::vec2str(const std::vector<int>& input) {
|
||||||
|
std::string word;
|
||||||
|
for (auto ind : input) {
|
||||||
|
word += char_list_[ind];
|
||||||
|
}
|
||||||
|
return word;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
|
||||||
|
if (labels.empty()) return {};
|
||||||
|
|
||||||
|
std::string s = vec2str(labels);
|
||||||
|
std::vector<std::string> words;
|
||||||
|
if (is_character_based_) {
|
||||||
|
words = split_utf8_str(s);
|
||||||
|
} else {
|
||||||
|
words = split_str(s, " ");
|
||||||
|
}
|
||||||
|
return words;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
|
||||||
|
char_list_ = char_list;
|
||||||
|
char_map_.clear();
|
||||||
|
|
||||||
|
// Set the char map for the FST for spelling correction
|
||||||
|
for (size_t i = 0; i < char_list_.size(); i++) {
|
||||||
|
if (char_list_[i] == kSPACE) {
|
||||||
|
SPACE_ID_ = i;
|
||||||
|
}
|
||||||
|
// The initial state of FST is state 0, hence the index of chars in
|
||||||
|
// the FST should start from 1 to avoid the conflict with the initial
|
||||||
|
// state, otherwise wrong decoding results would be given.
|
||||||
|
char_map_[char_list_[i]] = i + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the n-gram (at most max_order_ units, oldest first) that ends at
// `prefix` by walking the trie backwards. When the walk reaches a node whose
// character is -1 before max_order_ units were collected, the remaining
// positions are filled with START_TOKEN.
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
    std::vector<std::string> ngram;
    PathTrie* cursor = prefix;

    for (int order = 0; order < max_order_; order++) {
        std::vector<int> unit_labels;
        PathTrie* stop_node = nullptr;

        if (is_character_based_) {
            stop_node = cursor->get_path_vec(unit_labels, SPACE_ID_, 1);
            cursor = stop_node;
        } else {
            stop_node = cursor->get_path_vec(unit_labels, SPACE_ID_);
            cursor = stop_node->parent;  // Skipping spaces
        }

        // reconstruct word
        ngram.push_back(vec2str(unit_labels));

        if (stop_node->character == -1) {
            // No more spaces, but still need order
            ngram.insert(ngram.end(), max_order_ - order - 1, START_TOKEN);
            break;
        }
    }

    // units were collected newest first
    std::reverse(ngram.begin(), ngram.end());
    return ngram;
}
|
||||||
|
|
||||||
|
void Scorer::fill_dictionary(bool add_space) {
|
||||||
|
fst::StdVectorFst dictionary;
|
||||||
|
// For each unigram convert to ints and put in trie
|
||||||
|
int dict_size = 0;
|
||||||
|
for (const auto& word : vocabulary_) {
|
||||||
|
bool added = add_word_to_dictionary(
|
||||||
|
word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
|
||||||
|
dict_size += added ? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
dict_size_ = dict_size;
|
||||||
|
|
||||||
|
/* Simplify FST
|
||||||
|
|
||||||
|
* This gets rid of "epsilon" transitions in the FST.
|
||||||
|
* These are transitions that don't require a string input to be taken.
|
||||||
|
* Getting rid of them is necessary to make the FST determinisitc, but
|
||||||
|
* can greatly increase the size of the FST
|
||||||
|
*/
|
||||||
|
fst::RmEpsilon(&dictionary);
|
||||||
|
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
|
||||||
|
|
||||||
|
/* This makes the FST deterministic, meaning for any string input there's
|
||||||
|
* only one possible state the FST could be in. It is assumed our
|
||||||
|
* dictionary is deterministic when using it.
|
||||||
|
* (lest we'd have to check for multiple transitions at each state)
|
||||||
|
*/
|
||||||
|
fst::Determinize(dictionary, new_dict);
|
||||||
|
|
||||||
|
/* Finds the simplest equivalent fst. This is unnecessary but decreases
|
||||||
|
* memory usage of the dictionary
|
||||||
|
*/
|
||||||
|
fst::Minimize(new_dict);
|
||||||
|
this->dictionary = new_dict;
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
#!/usr/bin/env bash

# Fetch the third-party sources needed by the decoders (KenLM, OpenFst,
# ThreadPool) into the current directory, skipping any that are already
# present, then build and install the Python extension.

if [ ! -d kenlm ]; then
    git clone https://github.com/kpu/kenlm.git
    # Pin KenLM to a fixed commit. The subshell with `&&` makes sure the
    # checkout only ever runs inside kenlm/: if the clone failed, a bare
    # `cd kenlm/; git checkout ...; cd ..` would check out the commit in the
    # enclosing repository and then leave the working directory.
    (cd kenlm/ && git checkout df2d717e95183f79a90b2fa6e4307083a351ca6a)
    echo -e "\n"
fi

if [ ! -d openfst-1.6.3 ]; then
    echo "Download and extract openfst ..."
    # Only unpack when the download succeeded.
    wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz &&
        tar -xzvf openfst-1.6.3.tar.gz
    echo -e "\n"
fi

if [ ! -d ThreadPool ]; then
    git clone https://github.com/progschj/ThreadPool.git
    echo -e "\n"
fi

echo "Install decoders ..."
python3 setup.py install --num_processes 4
|
@ -0,0 +1,158 @@
|
|||||||
|
"""ScorerInterface implementation for CTC."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from .ctc_prefix_score import CTCPrefixScore
|
||||||
|
from .ctc_prefix_score import CTCPrefixScorePD
|
||||||
|
from .scorer_interface import BatchPartialScorerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class CTCPrefixScorer(BatchPartialScorerInterface):
|
||||||
|
"""Decoder interface wrapper for CTCPrefixScore."""
|
||||||
|
|
||||||
|
def __init__(self, ctc: paddle.nn.Layer, eos: int):
|
||||||
|
"""Initialize class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctc (paddle.nn.Layer): The CTC implementation.
|
||||||
|
For example, :class:`deepspeech.modules.ctc.CTC`
|
||||||
|
eos (int): The end-of-sequence id.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.ctc = ctc
|
||||||
|
self.eos = eos
|
||||||
|
self.impl = None
|
||||||
|
|
||||||
|
def init_state(self, x: paddle.Tensor):
|
||||||
|
"""Get an initial state for decoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): The encoded feature tensor
|
||||||
|
|
||||||
|
Returns: initial state
|
||||||
|
|
||||||
|
"""
|
||||||
|
logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy()
|
||||||
|
# TODO(karita): use CTCPrefixScorePD
|
||||||
|
self.impl = CTCPrefixScore(logp, 0, self.eos, np)
|
||||||
|
return 0, self.impl.initial_state()
|
||||||
|
|
||||||
|
def select_state(self, state, i, new_id=None):
|
||||||
|
"""Select state with relative ids in the main beam search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Decoder state for prefix tokens
|
||||||
|
i (int): Index to select a state in the main beam search
|
||||||
|
new_id (int): New label id to select a state if necessary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
state: pruned state
|
||||||
|
|
||||||
|
"""
|
||||||
|
if type(state) == tuple:
|
||||||
|
if len(state) == 2: # for CTCPrefixScore
|
||||||
|
sc, st = state
|
||||||
|
return sc[i], st[i]
|
||||||
|
else: # for CTCPrefixScorePD (need new_id > 0)
|
||||||
|
r, log_psi, f_min, f_max, scoring_idmap = state
|
||||||
|
s = log_psi[i, new_id].expand(log_psi.size(1))
|
||||||
|
if scoring_idmap is not None:
|
||||||
|
return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
|
||||||
|
else:
|
||||||
|
return r[:, :, i, new_id], s, f_min, f_max
|
||||||
|
return None if state is None else state[i]
|
||||||
|
|
||||||
|
def score_partial(self, y, ids, state, x):
|
||||||
|
"""Score new token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (paddle.Tensor): 1D prefix token
|
||||||
|
next_tokens (paddle.Tensor): paddle.int64 next token to score
|
||||||
|
state: decoder state for prefix tokens
|
||||||
|
x (paddle.Tensor): 2D encoder feature that generates ys
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[paddle.Tensor, Any]:
|
||||||
|
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
||||||
|
and next state for ys
|
||||||
|
|
||||||
|
"""
|
||||||
|
prev_score, state = state
|
||||||
|
presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
|
||||||
|
tscore = paddle.to_tensor(
|
||||||
|
presub_score - prev_score, place=x.place, dtype=x.dtype
|
||||||
|
)
|
||||||
|
return tscore, (presub_score, new_st)
|
||||||
|
|
||||||
|
def batch_init_state(self, x: paddle.Tensor):
|
||||||
|
"""Get an initial state for decoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): The encoded feature tensor
|
||||||
|
|
||||||
|
Returns: initial state
|
||||||
|
|
||||||
|
"""
|
||||||
|
logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
|
||||||
|
xlen = paddle.to_tensor([logp.size(1)])
|
||||||
|
self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def batch_score_partial(self, y, ids, state, x):
|
||||||
|
"""Score new token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (paddle.Tensor): 1D prefix token
|
||||||
|
ids (paddle.Tensor): paddle.int64 next token to score
|
||||||
|
state: decoder state for prefix tokens
|
||||||
|
x (paddle.Tensor): 2D encoder feature that generates ys
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[paddle.Tensor, Any]:
|
||||||
|
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
||||||
|
and next state for ys
|
||||||
|
|
||||||
|
"""
|
||||||
|
batch_state = (
|
||||||
|
(
|
||||||
|
paddle.stack([s[0] for s in state], axis=2),
|
||||||
|
paddle.stack([s[1] for s in state]),
|
||||||
|
state[0][2],
|
||||||
|
state[0][3],
|
||||||
|
)
|
||||||
|
if state[0] is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return self.impl(y, batch_state, ids)
|
||||||
|
|
||||||
|
def extend_prob(self, x: paddle.Tensor):
    """Feed newly available encoder output to the prefix scorer.

    Used for streaming decoding,
    as in Eq (14) in https://arxiv.org/abs/2006.14941

    Args:
        x (paddle.Tensor): The encoded feature tensor

    """
    # Add a leading axis of size 1 before taking the CTC log-softmax.
    batched = x.unsqueeze(0)
    self.impl.extend_prob(self.ctc.log_softmax(batched))
|
||||||
|
|
||||||
|
def extend_state(self, state):
    """Extend every hypothesis state for decoding.

    Used for streaming decoding,
    as in Eq (14) in https://arxiv.org/abs/2006.14941

    Args:
        state: The states of hyps

    Returns: list with one extended state per input state, in the same order

    """
    return [self.impl.extend_state(hyp_state) for hyp_state in state]
|
@ -0,0 +1,356 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
|
||||||
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
|
class CTCPrefixScorePD():
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        `margin` is M in eq.(22,23)

        Note that `x` is modified in place: frames at or beyond each
        utterance length are overwritten with padding values.

        :param paddle.Tensor x: input label posterior sequences (B, T, O)
        :param paddle.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype

        # Pad the rest of posteriors in the batch
        # (padded frames get logzero everywhere except 0 at the blank label)
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)  # (T,B,O)
        # self.x[0] holds the label posteriors, self.x[1] the blank posterior
        # repeated over the output axis.
        self.x = paddle.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = paddle.to_tensor(xlens) - 1  # (B,)

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            # frame_ids only exists when windowing is enabled
            self.frame_ids = paddle.arange(self.input_length, dtype=self.dtype)
        # Base indices for index conversion
        # B idx, hyp idx. shape (B*W, 1); created lazily in __call__
        self.idx_bh = None
        # B idx. shape (B,)
        self.idx_b = paddle.arange(self.batch)
        # B idx, O idx. shape (B, 1)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        Side effect: `self.scoring_num` (and possibly `self.idx_bh`) is updated.

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
            `(r_prev, s_prev, f_min_prev, f_max_prev)`, or None on the first call
        :param paddle.Tensor scoring_ids: selected next ids to score (BW, O'), O' <= O
        :param paddle.Tensor att_w: attention weights to decide CTC window
        :return ctc_local_scores (BW, O), new_state
            where new_state is `(r, log_psi, f_min, f_max, scoring_idmap)`
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            r_prev = paddle.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
            )  # (T, 2, B, W)
            # blank-ending path: cumulative sum of blank log-posteriors over time
            r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)  # (T, 2, BW)
            s_prev = 0.0  # score
            f_min_prev = 0  # eq. 22-23
            f_max_prev = 1  # eq. 22-23
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            # (BW, O); maps a label id to its column among the scored ids, -1 if not scored
            scoring_idmap = paddle.full((n_bh, self.odim), -1, dtype=paddle.long)
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = paddle.arange(n_bh).view(-1, 1)  # (BW, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = paddle.arange(snum)
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)  # (BW,1)
            ).view(-1)  # (BWO)
            # x_ shape (2, T, B*W, O)
            x_ = paddle.index_select(
                self.x.view(2, -1, self.batch * self.odim), scoring_idx, 2
            ).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            # x_ shape (2, T, B*W, O)
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = paddle.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = paddle.logsumexp(r_prev, 1)  # (T,BW)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)  # (T, BW, O)
        # For the label equal to the last prefix label, only the blank-ending
        # path of the previous state is used.
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = paddle.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            # if one frame one out, the output_length is the eating frame num now.
            start = max(output_length, 1)
            end = self.input_length

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]  # (2 x BW x O')
            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum
            )  # (2,2,BW,O')
            r[t] = paddle.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = paddle.concat((log_phi[0].unsqueeze(0), log_phi[:-1]), axis=0) + x_[0]
        if scoring_ids is not None:
            # labels that were not selected for scoring keep logzero
            log_psi = paddle.full((n_bh, self.odim), self.logzero, dtype=self.dtype)
            log_psi_ = paddle.logsumexp(
                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
                axis=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = paddle.logsumexp(
                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
                axis=0,
            )

        # the eos entry is read from r_sum at each utterance's last frame
        for si in range(n_bh):
            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        # scores are returned relative to the previous prefix score
        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state : CTC state `(r, s, f_min, f_max, scoring_idmap)`
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state `(r_new, s_new, f_min, f_max)`
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = paddle.index_select(s.view(-1), vidx, 0)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            # uses self.scoring_num as set by the most recent __call__
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                -1
            )
            label_ids = paddle.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # labels that were not scored (-1) fall back to column 0
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = paddle.index_select(r.view(-1, 2, n_bh * snum), vidx, 2).view(
            -1, 2, n_bh
        )
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        Nothing happens unless `x` has more frames than the stored
        posteriors. Otherwise the stored posteriors are rebuilt from `x`,
        the frames already stored are copied back over the beginning, and
        `input_length` / `end_frames` are updated.

        :param paddle.Tensor x: input label posterior sequences (B, T, O)
        """

        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = paddle.stack([xn, xb])  # (2, T, B, O)
            # keep the previously stored frames unchanged
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = paddle.to_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.

        Grows the forward-probability part of a state to the current
        `input_length`; the other state entries are passed through.

        :param state : CTC state `(r_prev, s_prev, f_min_prev, f_max_prev)` or None
        :return ctc_state
        """

        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = paddle.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
            )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            # new frames: only column 1 is extended, by adding the blank
            # log-posterior of each frame; column 0 stays at logzero
            for t in range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

            return (r_prev_new, s_prev, f_min_prev, f_max_prev)
|
||||||
|
|
||||||
|
|
||||||
|
class CTCPrefixScore():
    """Compute CTC label sequence scores

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        """Store the posteriors and the array backend.

        :param x : label log-posteriors, indexed as x[frame, label] (T, O)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param xp : array module providing full / logaddexp / where / rollaxis
            (e.g. numpy)
        """
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x  # (T, O)

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        # r shape (T, 2)
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        r[0, 1] = self.x[0, self.blank]
        for i in range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y : prefix label sequence
        :param cs : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        # r shape (T, 2, n_labels)
        # Every entry starts at logzero so that frames the recursion below
        # never writes are well defined in the returned state (the next call
        # runs logaddexp over all frames of that state).
        r = self.xp.full(
            (self.input_length, 2, len(cs)), self.logzero, dtype=np.float32
        )
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(
            r_prev[:, 0], r_prev[:, 1]
        )  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            # a repeated label can only follow a blank-ending path
            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
            for i in range(len(cs)):
                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        # log_psi shape (n_labels,), state shape (n_labels, T, 2)
        return log_psi, self.xp.rollaxis(r, 2)
|
@ -0,0 +1,61 @@
|
|||||||
|
"""Length bonus module."""
|
||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from .scorer_interface import BatchScorerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class LengthBonus(BatchScorerInterface):
    """Length bonus in beam search.

    Every token gets the same constant score of 1.0 and no state is kept.
    """

    def __init__(self, n_vocab: int):
        """Initialize class.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search

        """
        self.n = n_vocab

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (paddle.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[paddle.Tensor, Any]: Tuple of
                scores for next token (n_vocab), with the dtype and place
                of ``x``, and None

        """
        one = paddle.to_tensor([1.0], place=x.place, dtype=x.dtype)
        # ``expand`` takes the whole target shape as a single list argument.
        return one.expand([self.n]), None

    def batch_score(
        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
    ) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (paddle.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[paddle.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and None (this scorer keeps no state).

        """
        one = paddle.to_tensor([1.0], place=xs.place, dtype=xs.dtype)
        # Both dimensions go into one shape list; a second positional
        # argument would be taken as ``name`` rather than as a dimension.
        return one.expand([ys.shape[0], self.n]), None
|
@ -0,0 +1,102 @@
|
|||||||
|
"""Ngram lm implement."""
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
|
||||||
|
import kenlm
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from .scorer_interface import BatchScorerInterface
|
||||||
|
from .scorer_interface import PartialScorerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class Ngrambase(ABC):
    """Shared n-gram (kenlm) scoring logic for the full and partial scorers."""

    def __init__(self, ngram_model, token_list):
        """Load the language model and prepare the token table.

        Args:
            ngram_model: ngram model path
            token_list: token list from dict or model.json

        """
        # The "<eos>" token is renamed to "</s>" before being given to kenlm.
        self.chardict = ["</s>" if tok == "<eos>" else tok for tok in token_list]
        self.charlen = len(self.chardict)
        self.lm = kenlm.LanguageModel(ngram_model)
        # Scratch state reused as the (discarded) output of candidate scoring.
        self.tmpkenlmstate = kenlm.State()

    def init_state(self, x):
        """Return a fresh kenlm state initialised with the null context."""
        fresh = kenlm.State()
        self.lm.NullContextWrite(fresh)
        return fresh

    def score_partial_(self, y, next_token, state, x):
        """Score interface for both full and partial scorer.

        Args:
            y: previous tokens; only the last one is used
            next_token: next tokens that need to be scored
            state: previous state
            x: encoded feature (only its dtype is used)

        Returns:
            tuple: scores shaped like ``next_token`` with the dtype of ``x``,
                and the kenlm state reached after consuming the last token
                of ``y``.

        """
        out_state = kenlm.State()
        # With a prefix of length 1 (or less) the history word is "<s>".
        if y.shape[0] > 1:
            prev_word = self.chardict[y[-1]]
        else:
            prev_word = "<s>"
        # Advance the LM by the previous word; the returned score is ignored.
        self.lm.BaseScore(state, prev_word, out_state)
        scores = paddle.empty_like(next_token, dtype=x.dtype)
        for pos, token_id in enumerate(next_token):
            candidate = self.chardict[token_id]
            scores[pos] = self.lm.BaseScore(out_state, candidate, self.tmpkenlmstate)
        return scores, out_state
|
||||||
|
|
||||||
|
|
||||||
|
class NgramFullScorer(Ngrambase, BatchScorerInterface):
    """Fullscorer for ngram."""

    def score(self, y, state, x):
        """Score every token of the vocabulary.

        Args:
            y: previous char
            state: previous state
            x: encoded feature

        Returns:
            tuple: the result of ``score_partial_`` called with all token
                ids ``0 .. charlen - 1`` as candidates.

        """
        every_token = paddle.to_tensor(range(self.charlen))
        return self.score_partial_(y, every_token, state, x)
|
||||||
|
|
||||||
|
|
||||||
|
class NgramPartScorer(Ngrambase, PartialScorerInterface):
    """Partialscorer for ngram."""

    def score_partial(self, y, next_token, state, x):
        """Score the given candidate tokens.

        Args:
            y: previous char
            next_token: next token need to be score
            state: previous state
            x: encoded feature

        Returns:
            tuple: the result of ``score_partial_`` for these arguments
                (scores for ``next_token`` and the next state).

        """
        return self.score_partial_(y, next_token, state, x)

    def select_state(self, state, i, new_id=None):
        """Empty select state for scorer interface.

        ``i`` and ``new_id`` are accepted so the signature matches
        ``ScorerInterface.select_state``; both are ignored and the state is
        returned unchanged.
        """
        return state
|
@ -0,0 +1,188 @@
|
|||||||
|
"""Scorer interface module."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
class ScorerInterface:
    """Base interface of a beam-search scorer.

    A scorer assigns a score to every token in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`transformer.decoder.Decoder`
            * :class:`rnn.decoders.Decoder`
        * Neural language models
            * :class:`lm.transformer.TransformerLM`
            * :class:`lm.default.DefaultRNNLM`
            * :class:`lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x: paddle.Tensor) -> Any:
        """Return the state used before any token is decoded (optional).

        Args:
            x (paddle.Tensor): The encoded feature tensor

        Returns: initial state; this default implementation returns None

        """
        return None

    def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
        """Pick the state of one hypothesis in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary
                (unused by this default implementation)

        Returns:
            state: ``state[i]``, or None when ``state`` is None

        """
        if state is None:
            return None
        return state[i]

    def score(
        self, y: paddle.Tensor, state: Any, x: paddle.Tensor
    ) -> Tuple[paddle.Tensor, Any]:
        """Score new token (required, must be overridden).

        Args:
            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (paddle.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[paddle.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score; this default implementation returns 0.0

        """
        return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class BatchScorerInterface(ScorerInterface):
    """Batch scorer interface."""

    def batch_init_state(self, x: paddle.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        This default implementation delegates to ``init_state``.

        Args:
            x (paddle.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return self.init_state(x)

    def batch_score(
        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
    ) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch (required).

        This fallback calls ``score`` once per batch entry and emits a
        warning each time it is used, because it is not parallelized.

        Args:
            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (paddle.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[paddle.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        warnings.warn(
            "{} batch score is implemented through for loop not parallelized".format(
                self.__class__.__name__
            )
        )
        scores = list()
        outstates = list()
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            outstates.append(outstate)
            scores.append(score)
        # Join the per-entry scores and restore the batch axis.
        scores = paddle.concat(scores, 0).view(ys.shape[0], -1)
        return scores, outstates
|
||||||
|
|
||||||
|
|
||||||
|
class PartialScorerInterface(ScorerInterface):
    """Partial scorer interface for beam search.

    The partial scorer performs scoring when non-partial scorer finished scoring,
    and receives pre-pruned next tokens to score because it is too heavy to score
    all the tokens.

    Examples:
        * Prefix search for connectionist-temporal-classification models
            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`

    """

    def score_partial(
        self, y: paddle.Tensor, next_tokens: paddle.Tensor, state: Any, x: paddle.Tensor
    ) -> Tuple[paddle.Tensor, Any]:
        """Score new token (required, must be overridden).

        Args:
            y (paddle.Tensor): 1D prefix token
            next_tokens (paddle.Tensor): paddle.int64 next token to score
            state: decoder state for prefix tokens
            x (paddle.Tensor): The encoder feature that generates ys

        Returns:
            tuple[paddle.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
                and next state for ys

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
    """Batch partial scorer interface for beam search."""

    def batch_score_partial(
        self,
        ys: paddle.Tensor,
        next_tokens: paddle.Tensor,
        states: List[Any],
        xs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, Any]:
        """Score new token (required, must be overridden).

        Args:
            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
            next_tokens (paddle.Tensor): paddle.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (paddle.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[paddle.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError
|
@ -0,0 +1,34 @@
|
|||||||
|
|
||||||
|
__all__ = ["end_detect"]
|
||||||
|
|
||||||
|
def end_detect(ended_hyps, i, M=3, D_end=-10.0):
    """End detection.

    described in Eq. (50) of S. Watanabe et al
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    Decoding may stop when, for each of the lengths ``i, i-1, ..., i-M+1``,
    at least one ended hypothesis of that length exists and the best of
    them scores more than ``|D_end|`` below the overall best ended
    hypothesis.

    :param ended_hyps: list of ended hypotheses, each a dict with the keys
        "score" and "yseq"
    :param i: int, current output length
    :param M: int, number of most recent lengths to inspect
    :param D_end: float, log-domain threshold; the default -10.0 is
        log(1 * exp(-10)) written as a literal so that the module does not
        need numpy at import time
    :return: bool
    """
    if len(ended_hyps) == 0:
        return False

    def by_score(hyp):
        return hyp["score"]

    best_score = max(ended_hyps, key=by_score)["score"]
    count = 0
    for m in range(M):
        # get ended_hyps with their length is i - m
        hyp_length = i - m
        hyps_same_length = [h for h in ended_hyps if len(h["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_same_length = max(hyps_same_length, key=by_score)
            if best_same_length["score"] - best_score < D_end:
                count += 1

    # True only if every inspected length had a clearly worse best hypothesis.
    return count == M
|
Loading…
Reference in new issue