|
|
|
@ -8,8 +8,8 @@ import numpy as np
|
|
|
|
|
import multiprocessing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ctc_best_path_decode(probs_seq, vocabulary):
|
|
|
|
|
"""Best path decoding, also called argmax decoding or greedy decoding.
|
|
|
|
|
def ctc_best_path_decoder(probs_seq, vocabulary):
|
|
|
|
|
"""Best path decoder, also called argmax decoder or greedy decoder.
|
|
|
|
|
Path consisting of the most probable tokens are further post-processed to
|
|
|
|
|
remove consecutive repetitions and all blanks.
|
|
|
|
|
|
|
|
|
@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary):
|
|
|
|
|
def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
beam_size,
|
|
|
|
|
vocabulary,
|
|
|
|
|
blank_id=0,
|
|
|
|
|
blank_id,
|
|
|
|
|
cutoff_prob=1.0,
|
|
|
|
|
ext_scoring_func=None,
|
|
|
|
|
nproc=False):
|
|
|
|
|
'''Beam search decoder for CTC-trained network, using beam search with width
|
|
|
|
|
beam_size to find many paths to one label, return beam_size labels in
|
|
|
|
|
the descending order of probabilities. The implementation is based on Prefix
|
|
|
|
|
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
|
|
|
|
|
redesigned.
|
|
|
|
|
|
|
|
|
|
:param probs_seq: 2-D list with length num_time_steps, each element
|
|
|
|
|
is a list of normalized probabilities over vocabulary
|
|
|
|
|
and blank for one time step.
|
|
|
|
|
"""Beam search decoder for CTC-trained network. It utilizes beam search
|
|
|
|
|
to approximately select top best decoding labels and returning results
|
|
|
|
|
in the descending order. The implementation is based on Prefix
|
|
|
|
|
Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
|
|
|
|
|
redesigned. Two important modifications: 1) in the iterative computation
|
|
|
|
|
of probabilities, the assignment operation is changed to accumulation for
|
|
|
|
|
one prefix may comes from different paths; 2) the if condition "if l^+ not
|
|
|
|
|
in A_prev then" after probabilities' computation is deprecated for it is
|
|
|
|
|
hard to understand and seems unnecessary.
|
|
|
|
|
|
|
|
|
|
:param probs_seq: 2-D list of probability distributions over each time
|
|
|
|
|
step, with each element being a list of normalized
|
|
|
|
|
probabilities over vocabulary and blank.
|
|
|
|
|
:type probs_seq: 2-D list
|
|
|
|
|
:param beam_size: Width for beam search.
|
|
|
|
|
:type beam_size: int
|
|
|
|
|
:param vocabulary: Vocabulary list.
|
|
|
|
|
:type vocabulary: list
|
|
|
|
|
:param blank_id: ID of blank, default 0.
|
|
|
|
|
:param blank_id: ID of blank.
|
|
|
|
|
:type blank_id: int
|
|
|
|
|
:param cutoff_prob: Cutoff probability in pruning,
|
|
|
|
|
default 1.0, no pruning.
|
|
|
|
|
:type cutoff_prob: float
|
|
|
|
|
:param ext_scoring_func: External defined scoring function for
|
|
|
|
|
:param ext_scoring_func: External scoring function for
|
|
|
|
|
partially decoded sentence, e.g. word count
|
|
|
|
|
and language model.
|
|
|
|
|
:type external_scoring_function: function
|
|
|
|
|
or language model.
|
|
|
|
|
:type external_scoring_func: callable
|
|
|
|
|
:param nproc: Whether the decoder used in multiprocesses.
|
|
|
|
|
:type nproc: bool
|
|
|
|
|
:return: Decoding log probabilities and result sentences in descending order.
|
|
|
|
|
:return: List of tuples of log probability and sentence as decoding
|
|
|
|
|
results, in descending order of the probability.
|
|
|
|
|
:rtype: list
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
# dimension check
|
|
|
|
|
for prob_list in probs_seq:
|
|
|
|
|
if not len(prob_list) == len(vocabulary) + 1:
|
|
|
|
|
raise ValueError("probs dimension mismatched with vocabulary")
|
|
|
|
|
num_time_steps = len(probs_seq)
|
|
|
|
|
raise ValueError("The shape of prob_seq does not match with the "
|
|
|
|
|
"shape of the vocabulary.")
|
|
|
|
|
|
|
|
|
|
# blank_id check
|
|
|
|
|
probs_dim = len(probs_seq[0])
|
|
|
|
|
if not blank_id < probs_dim:
|
|
|
|
|
if not blank_id < len(probs_seq[0]):
|
|
|
|
|
raise ValueError("blank_id shouldn't be greater than probs dimension")
|
|
|
|
|
|
|
|
|
|
# If the decoder called in the multiprocesses, then use the global scorer
|
|
|
|
|
# instantiated in ctc_beam_search_decoder_nproc().
|
|
|
|
|
# instantiated in ctc_beam_search_decoder_batch().
|
|
|
|
|
if nproc is True:
|
|
|
|
|
global ext_nproc_scorer
|
|
|
|
|
ext_scoring_func = ext_nproc_scorer
|
|
|
|
|
|
|
|
|
|
## initialize
|
|
|
|
|
# the set containing selected prefixes
|
|
|
|
|
prefix_set_prev = {'\t': 1.0}
|
|
|
|
|
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
|
|
|
|
|
# prefix_set_prev: the set containing selected prefixes
|
|
|
|
|
# probs_b_prev: prefixes' probability ending with blank in previous step
|
|
|
|
|
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
|
|
|
|
|
prefix_set_prev, probs_b_prev, probs_nb_prev = {
|
|
|
|
|
'\t': 1.0
|
|
|
|
|
}, {
|
|
|
|
|
'\t': 1.0
|
|
|
|
|
}, {
|
|
|
|
|
'\t': 0.0
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
## extend prefix in loop
|
|
|
|
|
for time_step in xrange(num_time_steps):
|
|
|
|
|
# the set containing candidate prefixes
|
|
|
|
|
prefix_set_next = {}
|
|
|
|
|
probs_b_cur, probs_nb_cur = {}, {}
|
|
|
|
|
prob = probs_seq[time_step]
|
|
|
|
|
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
|
|
|
|
|
for time_step in xrange(len(probs_seq)):
|
|
|
|
|
# prefix_set_next: the set containing candidate prefixes
|
|
|
|
|
# probs_b_cur: prefixes' probability ending with blank in current step
|
|
|
|
|
# probs_nb_cur: prefixes' probability ending with non-blank in current step
|
|
|
|
|
prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}
|
|
|
|
|
|
|
|
|
|
prob_idx = list(enumerate(probs_seq[time_step]))
|
|
|
|
|
cutoff_len = len(prob_idx)
|
|
|
|
|
#If pruning is enabled
|
|
|
|
|
if (cutoff_prob < 1.0):
|
|
|
|
|
if cutoff_prob < 1.0:
|
|
|
|
|
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
|
|
|
|
|
cutoff_len = 0
|
|
|
|
|
cum_prob = 0.0
|
|
|
|
|
cutoff_len, cum_prob = 0, 0.0
|
|
|
|
|
for i in xrange(len(prob_idx)):
|
|
|
|
|
cum_prob += prob_idx[i][1]
|
|
|
|
|
cutoff_len += 1
|
|
|
|
@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
prefix_set_prev = dict(prefix_set_prev)
|
|
|
|
|
|
|
|
|
|
beam_result = []
|
|
|
|
|
for (seq, prob) in prefix_set_prev.items():
|
|
|
|
|
for seq, prob in prefix_set_prev.items():
|
|
|
|
|
if prob > 0.0 and len(seq) > 1:
|
|
|
|
|
result = seq[1:]
|
|
|
|
|
# score last word by external scorer
|
|
|
|
|
if (ext_scoring_func is not None) and (result[-1] != ' '):
|
|
|
|
|
prob = prob * ext_scoring_func(result)
|
|
|
|
|
log_prob = np.log(prob)
|
|
|
|
|
beam_result.append([log_prob, result])
|
|
|
|
|
beam_result.append((log_prob, result))
|
|
|
|
|
|
|
|
|
|
## output top beam_size decoding results
|
|
|
|
|
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
|
|
|
|
|
return beam_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ctc_beam_search_decoder_nproc(probs_split,
|
|
|
|
|
def ctc_beam_search_decoder_batch(probs_split,
|
|
|
|
|
beam_size,
|
|
|
|
|
vocabulary,
|
|
|
|
|
blank_id=0,
|
|
|
|
|
blank_id,
|
|
|
|
|
num_processes,
|
|
|
|
|
cutoff_prob=1.0,
|
|
|
|
|
ext_scoring_func=None,
|
|
|
|
|
num_processes=None):
|
|
|
|
|
'''Beam search decoder using multiple processes.
|
|
|
|
|
ext_scoring_func=None):
|
|
|
|
|
"""CTC beam search decoder using multiple processes.
|
|
|
|
|
|
|
|
|
|
:param probs_seq: 3-D list with length batch_size, each element
|
|
|
|
|
is a 2-D list of probabilities can be used by
|
|
|
|
|
ctc_beam_search_decoder.
|
|
|
|
|
:param probs_seq: 3-D list with each element as an instance of 2-D list
|
|
|
|
|
of probabilities used by ctc_beam_search_decoder().
|
|
|
|
|
:type probs_seq: 3-D list
|
|
|
|
|
:param beam_size: Width for beam search.
|
|
|
|
|
:type beam_size: int
|
|
|
|
|
:param vocabulary: Vocabulary list.
|
|
|
|
|
:type vocabulary: list
|
|
|
|
|
:param blank_id: ID of blank, default 0.
|
|
|
|
|
:param blank_id: ID of blank.
|
|
|
|
|
:type blank_id: int
|
|
|
|
|
:param num_processes: Number of parallel processes.
|
|
|
|
|
:type num_processes: int
|
|
|
|
|
:param cutoff_prob: Cutoff probability in pruning,
|
|
|
|
|
default 0, no pruning.
|
|
|
|
|
default 1.0, no pruning.
|
|
|
|
|
:param num_processes: Number of parallel processes.
|
|
|
|
|
:type num_processes: int
|
|
|
|
|
:type cutoff_prob: float
|
|
|
|
|
:param ext_scoring_func: External defined scoring function for
|
|
|
|
|
:param ext_scoring_func: External scoring function for
|
|
|
|
|
partially decoded sentence, e.g. word count
|
|
|
|
|
and language model.
|
|
|
|
|
:type external_scoring_function: function
|
|
|
|
|
:param num_processes: Number of processes, default None, equal to the
|
|
|
|
|
number of CPUs.
|
|
|
|
|
:type num_processes: int
|
|
|
|
|
:return: Decoding log probabilities and result sentences in descending order.
|
|
|
|
|
or language model.
|
|
|
|
|
:type external_scoring_function: callable
|
|
|
|
|
:return: List of tuples of log probability and sentence as decoding
|
|
|
|
|
results, in descending order of the probability.
|
|
|
|
|
:rtype: list
|
|
|
|
|
'''
|
|
|
|
|
if num_processes is None:
|
|
|
|
|
num_processes = multiprocessing.cpu_count()
|
|
|
|
|
"""
|
|
|
|
|
if not num_processes > 0:
|
|
|
|
|
raise ValueError("Number of processes must be positive!")
|
|
|
|
|
|
|
|
|
@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split,
|
|
|
|
|
|
|
|
|
|
pool.close()
|
|
|
|
|
pool.join()
|
|
|
|
|
beam_search_results = []
|
|
|
|
|
for result in results:
|
|
|
|
|
beam_search_results.append(result.get())
|
|
|
|
|
beam_search_results = [result.get() for result in results]
|
|
|
|
|
return beam_search_results
|
|
|
|
|