You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
249 lines
10 KiB
249 lines
10 KiB
4 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
8 years ago
|
"""Contains various CTC decoders."""
|
||
4 years ago
|
import multiprocessing
|
||
8 years ago
|
from itertools import groupby
|
||
7 years ago
|
from math import log
|
||
4 years ago
|
|
||
|
import numpy as np
|
||
8 years ago
|
|
||
|
|
||
7 years ago
|
def ctc_greedy_decoder(probs_seq, vocabulary):
    """CTC greedy (best path) decoder.

    The most probable token is taken at every time step; the resulting path
    is then post-processed to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of
                      ``len(vocabulary) + 1`` float probabilities for one
                      character; the last entry is the blank.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string; empty when probs_seq is empty.
    :rtype: str
    :raises ValueError: If a row of probs_seq does not have
                        ``len(vocabulary) + 1`` entries.
    """
    # The blank token is the extra last column of every row.
    blank_index = len(vocabulary)

    # dimension verification
    for probs in probs_seq:
        if not len(probs) == blank_index + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")

    # np.array([]) is 1-D, so argmax(axis=1) below would raise on empty input.
    if len(probs_seq) == 0:
        return ''

    # argmax to get the best index for each time step
    max_index_list = np.array(probs_seq).argmax(axis=1)
    # remove consecutive duplicate indexes
    index_list = [index for index, _ in groupby(max_index_list)]
    # drop blank indexes and convert the rest to characters
    return ''.join(vocabulary[index] for index in index_list
                   if index != blank_index)
|
||
|
|
||
|
|
||
8 years ago
|
def _prune_candidates(step_probs, cutoff_prob, cutoff_top_n):
    """Select the (token index, probability) pairs to expand at one step.

    Without pruning (``cutoff_prob >= 1.0`` and ``cutoff_top_n`` not smaller
    than the number of tokens) every token is returned in index order.
    Otherwise tokens are sorted by descending probability, the shortest head
    whose cumulative probability reaches ``cutoff_prob`` is kept, and that
    head is further capped at ``cutoff_top_n`` entries.
    """
    candidates = list(enumerate(step_probs))
    if cutoff_prob < 1.0 or cutoff_top_n < len(candidates):
        candidates = sorted(candidates, key=lambda pair: pair[1], reverse=True)
        keep, cum_prob = 0, 0.0
        for _, prob in candidates:
            cum_prob += prob
            keep += 1
            if cum_prob >= cutoff_prob:
                break
        # A non-positive top-n keeps nothing.
        candidates = candidates[:max(0, min(keep, cutoff_top_n))]
    return candidates


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            nproc=False):
    """CTC Beam search decoder.

    It uses beam search to approximately select the best decoding labels
    and returns the results in descending order of probability.
    The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), and the unclear part is
    redesigned. Two important modifications: 1) in the iterative computation
    of probabilities, the assignment operation is changed to accumulation
    because one prefix may come from different paths; 2) the if condition
    "if l^+ not in A_prev then" after probabilities' computation is dropped
    because it is hard to understand and seems unnecessary.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank (the blank is
                      the last entry).
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Maximum number of highest-probability tokens
                         expanded per time step when pruning, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model. It is called with the decoded
                             text when a space is appended and with each final
                             result not ending in a space; its return value
                             multiplies the path probability.
    :type ext_scoring_func: callable
    :param nproc: If True, ``ext_scoring_func`` is replaced by the
                  module-level ``ext_nproc_scorer`` set by
                  ctc_beam_search_decoder_batch().
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability. Prefixes with
             zero probability or no decoded characters are reported as
             ``(-inf, '')``.
    :rtype: list
    :raises ValueError: If a row of probs_seq does not have
                        ``len(vocabulary) + 1`` entries.
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # The blank token is the extra last column of every distribution.
    blank_id = len(vocabulary)

    # If the decoder called in the multiprocesses, then use the global scorer
    # instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    # Every prefix carries a leading '\t' sentinel so that prefix[-1] is
    # always defined; it is stripped from the returned sentences.
    # prefix_set_prev: selected prefixes -> total probability
    # probs_b_prev / probs_nb_prev: probability of each prefix ending with
    # blank / non-blank at the previous step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    # extend prefixes one time step at a time
    for step_probs in probs_seq:
        # candidate prefixes and their blank / non-blank ending probabilities
        # for the current step
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}
        candidates = _prune_candidates(step_probs, cutoff_prob, cutoff_top_n)

        for prefix in prefix_set_prev:
            # prefix may already have been produced in this step as the
            # extension of a shorter prefix; keep that accumulated mass.
            if prefix not in prefix_set_next:
                probs_b_cur[prefix], probs_nb_cur[prefix] = 0.0, 0.0

            prev_total = probs_b_prev[prefix] + probs_nb_prev[prefix]
            last_char = prefix[-1]
            for c, prob_c in candidates:
                if c == blank_id:
                    probs_b_cur[prefix] += prob_c * prev_total
                    continue

                new_char = vocabulary[c]
                l_plus = prefix + new_char
                if l_plus not in prefix_set_next:
                    probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                if new_char == last_char:
                    # A repeated character only starts a new symbol after a
                    # blank; otherwise it collapses into the same prefix.
                    probs_nb_cur[l_plus] += prob_c * probs_b_prev[prefix]
                    probs_nb_cur[prefix] += prob_c * probs_nb_prev[prefix]
                elif new_char == ' ':
                    # A space ends a word: rescore the text decoded so far.
                    if ext_scoring_func is None or len(prefix) == 1:
                        score = 1.0
                    else:
                        score = ext_scoring_func(prefix[1:])
                    probs_nb_cur[l_plus] += score * prob_c * prev_total
                else:
                    probs_nb_cur[l_plus] += prob_c * prev_total
                # add l_plus into prefix_set_next
                prefix_set_next[l_plus] = (probs_nb_cur[l_plus] +
                                           probs_b_cur[l_plus])
            # add prefix into prefix_set_next
            prefix_set_next[prefix] = (probs_b_cur[prefix] +
                                       probs_nb_cur[prefix])
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        # store top beam_size prefixes
        ranked = sorted(
            prefix_set_next.items(), key=lambda item: item[1], reverse=True)
        if beam_size < len(ranked):
            ranked = ranked[:beam_size]
        prefix_set_prev = dict(ranked)

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            log_prob = log(prob)
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                score = ext_scoring_func(result)
                # Add in log space: prob * score can underflow to 0.0 (and
                # log(0.0) raises) even when both factors are positive.
                log_prob = (log_prob + log(score)
                            if score > 0.0 else float('-inf'))
            beam_result.append((log_prob, result))
        else:
            beam_result.append((float('-inf'), ''))

    # output decoding results, best first
    beam_result = sorted(beam_result, key=lambda item: item[0], reverse=True)
    return beam_result
|
||
8 years ago
|
|
||
|
|
||
8 years ago
|
def _set_nproc_scorer(scorer):
    """Install ``scorer`` as the module-level ``ext_nproc_scorer``.

    Used in the parent process and as the ``multiprocessing.Pool``
    initializer, so that workers started with the "spawn"/"forkserver" start
    methods (which do not inherit the parent's globals) also see it.
    """
    global ext_nproc_scorer
    ext_nproc_scorer = scorer


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Maximum number of highest-probability tokens
                         expanded per time step when pruning, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: One ctc_beam_search_decoder() result per element of
             probs_split, in input order; each is a list of tuples of log
             probability and sentence in descending order of probability.
    :rtype: list
    :raises ValueError: If num_processes is not positive.
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # The scorer reaches the workers through a module-level global instead of
    # the task arguments (presumably because scorer objects may not be
    # picklable -- TODO confirm). Forked workers inherit the parent's value;
    # the pool initializer sets it in workers that start from a fresh import.
    _set_nproc_scorer(ext_scoring_func)

    with multiprocessing.Pool(
            processes=num_processes,
            initializer=_set_nproc_scorer,
            initargs=(ext_scoring_func, )) as pool:
        async_results = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_list, beam_size, vocabulary, cutoff_prob,
                              cutoff_top_n, None, True))
            for probs_list in probs_split
        ]
        pool.close()
        pool.join()
        # get() re-raises any exception raised inside a worker.
        beam_search_results = [result.get() for result in async_results]
    return beam_search_results
|