|
|
|
@ -92,7 +92,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
Search(https://arxiv.org/abs/1408.2873), and the unclear part is
|
|
|
|
|
redesigned, need to be verified.
|
|
|
|
|
|
|
|
|
|
:param probs_seq: 2-D list with length max_time_steps, each element
|
|
|
|
|
:param probs_seq: 2-D list with length num_time_steps, each element
|
|
|
|
|
is a list of normalized probabilities over vocabulary
|
|
|
|
|
and blank for one time step.
|
|
|
|
|
:type probs_seq: 2-D list
|
|
|
|
@ -114,7 +114,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
for prob_list in probs_seq:
|
|
|
|
|
if not len(prob_list) == len(vocabulary) + 1:
|
|
|
|
|
raise ValueError("probs dimension mismatchedd with vocabulary")
|
|
|
|
|
max_time_steps = len(probs_seq)
|
|
|
|
|
num_time_steps = len(probs_seq)
|
|
|
|
|
|
|
|
|
|
# blank_id check
|
|
|
|
|
probs_dim = len(probs_seq[0])
|
|
|
|
@ -139,10 +139,10 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
## initialize
|
|
|
|
|
# the set containing selected prefixes
|
|
|
|
|
prefix_set_prev = {'-1': 1.0}
|
|
|
|
|
probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0}
|
|
|
|
|
probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0}
|
|
|
|
|
|
|
|
|
|
## extend prefix in loop
|
|
|
|
|
for time_step in range(max_time_steps):
|
|
|
|
|
for time_step in range(num_time_steps):
|
|
|
|
|
# the set containing candidate prefixes
|
|
|
|
|
prefix_set_next = {}
|
|
|
|
|
probs_b_cur, probs_nb_cur = {}, {}
|
|
|
|
@ -158,33 +158,34 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
# extend prefix by travering vocabulary
|
|
|
|
|
for c in range(0, probs_dim):
|
|
|
|
|
if c == blank_id:
|
|
|
|
|
probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
|
|
|
|
|
probs_b_cur[l] += prob[c] * (
|
|
|
|
|
probs_b_prev[l] + probs_nb_prev[l])
|
|
|
|
|
else:
|
|
|
|
|
l_plus = l + ' ' + str(c)
|
|
|
|
|
if not prefix_set_next.has_key(l_plus):
|
|
|
|
|
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
|
|
|
|
|
|
|
|
|
|
if c == end_id:
|
|
|
|
|
probs_nb_cur[l_plus] += prob[c] * probs_b[l]
|
|
|
|
|
probs_nb_cur[l] += prob[c] * probs_nb[l]
|
|
|
|
|
probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
|
|
|
|
|
probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
|
|
|
|
|
elif c == space_id:
|
|
|
|
|
if ext_scoring_func is None:
|
|
|
|
|
score = 1.0
|
|
|
|
|
else:
|
|
|
|
|
prefix_sent = ids2sentence(ids_list, vocabulary)
|
|
|
|
|
score = ext_scoring_func(prefix_sent)
|
|
|
|
|
prefix = ids2sentence(ids_list, vocabulary)
|
|
|
|
|
score = ext_scoring_func(prefix)
|
|
|
|
|
probs_nb_cur[l_plus] += score * prob[c] * (
|
|
|
|
|
probs_b[l] + probs_nb[l])
|
|
|
|
|
probs_b_prev[l] + probs_nb_prev[l])
|
|
|
|
|
else:
|
|
|
|
|
probs_nb_cur[l_plus] += prob[c] * (
|
|
|
|
|
probs_b[l] + probs_nb[l])
|
|
|
|
|
probs_b_prev[l] + probs_nb_prev[l])
|
|
|
|
|
# add l_plus into prefix_set_next
|
|
|
|
|
prefix_set_next[l_plus] = probs_nb_cur[
|
|
|
|
|
l_plus] + probs_b_cur[l_plus]
|
|
|
|
|
# add l into prefix_set_next
|
|
|
|
|
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
|
|
|
|
|
# update probs
|
|
|
|
|
probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
|
|
|
|
|
probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy(
|
|
|
|
|
probs_nb_cur)
|
|
|
|
|
|
|
|
|
|
## store top beam_size prefixes
|
|
|
|
|