|
|
|
@ -120,7 +120,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
prob_idx = prob_idx[0:cutoff_len]
|
|
|
|
|
|
|
|
|
|
for l in prefix_set_prev:
|
|
|
|
|
if not prefix_set_next.has_key(l):
|
|
|
|
|
if l not in prefix_set_next:
|
|
|
|
|
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
|
|
|
|
|
|
|
|
|
|
# extend prefix by travering prob_idx
|
|
|
|
@ -134,7 +134,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
last_char = l[-1]
|
|
|
|
|
new_char = vocabulary[c]
|
|
|
|
|
l_plus = l + new_char
|
|
|
|
|
if not prefix_set_next.has_key(l_plus):
|
|
|
|
|
if l_plus not in prefix_set_next:
|
|
|
|
|
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
|
|
|
|
|
|
|
|
|
|
if new_char == last_char:
|
|
|
|
@ -161,7 +161,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
|
|
|
|
|
## store top beam_size prefixes
|
|
|
|
|
prefix_set_prev = sorted(
|
|
|
|
|
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
|
|
|
|
|
prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
|
|
|
|
|
if beam_size < len(prefix_set_prev):
|
|
|
|
|
prefix_set_prev = prefix_set_prev[:beam_size]
|
|
|
|
|
prefix_set_prev = dict(prefix_set_prev)
|
|
|
|
|