|
|
|
@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
|
|
|
|
|
def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
beam_size,
|
|
|
|
|
vocabulary,
|
|
|
|
|
blank_id,
|
|
|
|
|
cutoff_prob=1.0,
|
|
|
|
|
cutoff_top_n=40,
|
|
|
|
|
ext_scoring_func=None,
|
|
|
|
|
nproc=False):
|
|
|
|
|
"""CTC Beam search decoder.
|
|
|
|
@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
:type beam_size: int
|
|
|
|
|
:param vocabulary: Vocabulary list.
|
|
|
|
|
:type vocabulary: list
|
|
|
|
|
:param blank_id: ID of blank.
|
|
|
|
|
:type blank_id: int
|
|
|
|
|
:param cutoff_prob: Cutoff probability in pruning,
|
|
|
|
|
default 1.0, no pruning.
|
|
|
|
|
:type cutoff_prob: float
|
|
|
|
@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
raise ValueError("The shape of prob_seq does not match with the "
|
|
|
|
|
"shape of the vocabulary.")
|
|
|
|
|
|
|
|
|
|
# blank_id check
|
|
|
|
|
if not blank_id < len(probs_seq[0]):
|
|
|
|
|
raise ValueError("blank_id shouldn't be greater than probs dimension")
|
|
|
|
|
# blank_id assign
|
|
|
|
|
blank_id = len(vocabulary)
|
|
|
|
|
|
|
|
|
|
# If the decoder is called from multiple processes, then use the global scorer
|
|
|
|
|
# instantiated in ctc_beam_search_decoder_batch().
|
|
|
|
@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
prob_idx = list(enumerate(probs_seq[time_step]))
|
|
|
|
|
cutoff_len = len(prob_idx)
|
|
|
|
|
# If pruning is enabled
|
|
|
|
|
if cutoff_prob < 1.0:
|
|
|
|
|
if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
|
|
|
|
|
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
|
|
|
|
|
cutoff_len, cum_prob = 0, 0.0
|
|
|
|
|
for i in xrange(len(prob_idx)):
|
|
|
|
@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
cutoff_len += 1
|
|
|
|
|
if cum_prob >= cutoff_prob:
|
|
|
|
|
break
|
|
|
|
|
# Keep the smaller of the cumulative-probability cutoff length and the
# top-n cap, so that both pruning criteria are honoured.
cutoff_len = min(cutoff_len, cutoff_top_n)
|
|
|
|
|
prob_idx = prob_idx[0:cutoff_len]
|
|
|
|
|
|
|
|
|
|
for l in prefix_set_prev:
|
|
|
|
@ -191,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
def ctc_beam_search_decoder_batch(probs_split,
|
|
|
|
|
beam_size,
|
|
|
|
|
vocabulary,
|
|
|
|
|
blank_id,
|
|
|
|
|
num_processes,
|
|
|
|
|
cutoff_prob=1.0,
|
|
|
|
|
cutoff_top_n=40,
|
|
|
|
|
ext_scoring_func=None):
|
|
|
|
|
"""CTC beam search decoder using multiple processes.
|
|
|
|
|
|
|
|
|
@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
|
|
|
|
|
:type beam_size: int
|
|
|
|
|
:param vocabulary: Vocabulary list.
|
|
|
|
|
:type vocabulary: list
|
|
|
|
|
:param blank_id: ID of blank.
|
|
|
|
|
:type blank_id: int
|
|
|
|
|
:param num_processes: Number of parallel processes.
|
|
|
|
|
:type num_processes: int
|
|
|
|
|
:param cutoff_prob: Cutoff probability in pruning,
|
|
|
|
@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
|
|
|
|
|
pool = multiprocessing.Pool(processes=num_processes)
|
|
|
|
|
results = []
|
|
|
|
|
for i, probs_list in enumerate(probs_split):
|
|
|
|
|
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
|
|
|
|
|
nproc)
|
|
|
|
|
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
|
|
|
|
|
cutoff_top_n, None, nproc)
|
|
|
|
|
results.append(pool.apply_async(ctc_beam_search_decoder, args))
|
|
|
|
|
|
|
|
|
|
pool.close()
|
|
|
|
|