|
|
|
@ -121,25 +121,10 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
if not blank_id < probs_dim:
|
|
|
|
|
raise ValueError("blank_id shouldn't be greater than probs dimension")
|
|
|
|
|
|
|
|
|
|
# assign space_id
|
|
|
|
|
if ' ' not in vocabulary:
|
|
|
|
|
raise ValueError("space doesn't exist in vocabulary")
|
|
|
|
|
space_id = vocabulary.index(' ')
|
|
|
|
|
|
|
|
|
|
# function to convert ids in string to list
|
|
|
|
|
def ids_str2list(ids_str):
    """Convert a whitespace-separated string of integer ids into a list.

    :param ids_str: Integer ids joined by spaces, e.g. ``'-1 3 5'``.
    :type ids_str: str
    :return: The ids as integers, in their original order. A blank
             string yields an empty list.
    :rtype: list
    :raises ValueError: If a token is not a valid integer literal.
    """
    # split() without an argument discards empty tokens, so blank input
    # or repeated/trailing spaces never reach int('') and raise.
    return [int(token) for token in ids_str.split()]
|
|
|
|
|
|
|
|
|
|
# function to convert ids list to sentence
|
|
|
|
|
def ids2sentence(ids_list, vocab):
    """Build a string by concatenating the vocabulary entries for each id.

    :param ids_list: Indices into ``vocab``, in output order.
    :type ids_list: list
    :param vocab: Indexable collection of strings.
    :type vocab: list
    :return: The looked-up entries joined with no separator.
    :rtype: str
    """
    pieces = []
    for token_id in ids_list:
        pieces.append(vocab[token_id])
    return ''.join(pieces)
|
|
|
|
|
|
|
|
|
|
## initialize
|
|
|
|
|
# the set containing selected prefixes
|
|
|
|
|
prefix_set_prev = {'-1': 1.0}
|
|
|
|
|
probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0}
|
|
|
|
|
prefix_set_prev = {'\t': 1.0}
|
|
|
|
|
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
|
|
|
|
|
|
|
|
|
|
## extend prefix in loop
|
|
|
|
|
for time_step in range(num_time_steps):
|
|
|
|
@ -148,10 +133,6 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
probs_b_cur, probs_nb_cur = {}, {}
|
|
|
|
|
for l in prefix_set_prev:
|
|
|
|
|
prob = probs_seq[time_step]
|
|
|
|
|
|
|
|
|
|
# convert ids in string to list
|
|
|
|
|
ids_list = ids_str2list(l)
|
|
|
|
|
end_id = ids_list[-1]
|
|
|
|
|
if not prefix_set_next.has_key(l):
|
|
|
|
|
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
|
|
|
|
|
|
|
|
|
@ -161,18 +142,20 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
probs_b_cur[l] += prob[c] * (
|
|
|
|
|
probs_b_prev[l] + probs_nb_prev[l])
|
|
|
|
|
else:
|
|
|
|
|
l_plus = l + ' ' + str(c)
|
|
|
|
|
last_char = l[-1]
|
|
|
|
|
new_char = vocabulary[c]
|
|
|
|
|
l_plus = l + new_char
|
|
|
|
|
if not prefix_set_next.has_key(l_plus):
|
|
|
|
|
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
|
|
|
|
|
|
|
|
|
|
if c == end_id:
|
|
|
|
|
if new_char == last_char:
|
|
|
|
|
probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
|
|
|
|
|
probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
|
|
|
|
|
elif c == space_id:
|
|
|
|
|
if ext_scoring_func is None:
|
|
|
|
|
elif new_char == ' ':
|
|
|
|
|
if (ext_scoring_func is None) or (len(l) == 1):
|
|
|
|
|
score = 1.0
|
|
|
|
|
else:
|
|
|
|
|
prefix = ids2sentence(ids_list, vocabulary)
|
|
|
|
|
prefix = l[1:]
|
|
|
|
|
score = ext_scoring_func(prefix)
|
|
|
|
|
probs_nb_cur[l_plus] += score * prob[c] * (
|
|
|
|
|
probs_b_prev[l] + probs_nb_prev[l])
|
|
|
|
@ -185,8 +168,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
# add l into prefix_set_next
|
|
|
|
|
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
|
|
|
|
|
# update probs
|
|
|
|
|
probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy(
|
|
|
|
|
probs_nb_cur)
|
|
|
|
|
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur
|
|
|
|
|
|
|
|
|
|
## store top beam_size prefixes
|
|
|
|
|
prefix_set_prev = sorted(
|
|
|
|
@ -198,8 +180,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|
|
|
|
beam_result = []
|
|
|
|
|
for (seq, prob) in prefix_set_prev.items():
|
|
|
|
|
if prob > 0.0:
|
|
|
|
|
ids_list = ids_str2list(seq)[1:]
|
|
|
|
|
result = ids2sentence(ids_list, vocabulary)
|
|
|
|
|
result = seq[1:]
|
|
|
|
|
log_prob = np.log(prob)
|
|
|
|
|
beam_result.append([log_prob, result])
|
|
|
|
|
|
|
|
|
|