|
|
|
@ -23,10 +23,26 @@ def ids_id2token(ids_list):
|
|
|
|
|
return ids_str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def language_model(ids_list, vocabulary):
|
|
|
|
|
# lookup ptb vocabulary
|
|
|
|
|
ptb_vocab_path = "./data/ptb_vocab.txt"
|
|
|
|
|
sentence = ''.join([vocabulary[ids] for ids in ids_list])
|
|
|
|
|
words = sentence.split(' ')
|
|
|
|
|
last_word = words[-1]
|
|
|
|
|
with open(ptb_vocab_path, 'r') as ptb_vocab:
|
|
|
|
|
f = ptb_vocab.readline()
|
|
|
|
|
while f:
|
|
|
|
|
if f == last_word:
|
|
|
|
|
return 1.0
|
|
|
|
|
f = ptb_vocab.readline()
|
|
|
|
|
return 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ctc_beam_search_decoder(input_probs_matrix,
|
|
|
|
|
beam_size,
|
|
|
|
|
vocabulary,
|
|
|
|
|
max_time_steps=None,
|
|
|
|
|
lang_model=None,
|
|
|
|
|
lang_model=language_model,
|
|
|
|
|
alpha=1.0,
|
|
|
|
|
beta=1.0,
|
|
|
|
|
blank_id=0,
|
|
|
|
@ -120,7 +136,7 @@ def ctc_beam_search_decoder(input_probs_matrix,
|
|
|
|
|
probs_nb_cur[l] += prob[c] * probs_nb[l]
|
|
|
|
|
elif c == space_id:
|
|
|
|
|
lm = 1.0 if lang_model is None \
|
|
|
|
|
else np.power(lang_model(ids_list), alpha)
|
|
|
|
|
else np.power(lang_model(ids_list, vocabulary), alpha)
|
|
|
|
|
probs_nb_cur[l_plus] += lm * prob[c] * (
|
|
|
|
|
probs_b[l] + probs_nb[l])
|
|
|
|
|
else:
|
|
|
|
@ -145,9 +161,10 @@ def ctc_beam_search_decoder(input_probs_matrix,
|
|
|
|
|
beam_result = []
|
|
|
|
|
for (seq, prob) in prefix_set_prev.items():
|
|
|
|
|
if prob > 0.0:
|
|
|
|
|
ids_list = ids_str2list(seq)
|
|
|
|
|
ids_list = ids_str2list(seq)[1:]
|
|
|
|
|
result = ''.join([vocabulary[ids] for ids in ids_list])
|
|
|
|
|
log_prob = np.log(prob)
|
|
|
|
|
beam_result.append([log_prob, ids_list[1:]])
|
|
|
|
|
beam_result.append([log_prob, result])
|
|
|
|
|
|
|
|
|
|
## output top beam_size decoding results
|
|
|
|
|
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
|
|
|
|
@ -156,11 +173,6 @@ def ctc_beam_search_decoder(input_probs_matrix,
|
|
|
|
|
return beam_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def language_model(input):
|
|
|
|
|
# TODO
|
|
|
|
|
return random.uniform(0, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def simple_test():
|
|
|
|
|
|
|
|
|
|
input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]]
|
|
|
|
|