parent dedbfb2654
commit 51f35a5372

@@ -1,192 +0,0 @@
## This is a prototype of a CTC beam search decoder

import copy
import random

import numpy as np

# vocab = blank + space + English characters
# vocab = ['-', ' '] + [chr(i) for i in range(97, 123)]

vocab = ['-', '_', 'a']
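# index 0 is the blank, index 1 ('_') stands in for the space, and
# index 2 is the only real character in this toy vocabulary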


def ids_list2str(ids_list):
    # render a list of ids as a space-separated string
    ids_str = [str(elem) for elem in ids_list]
    ids_str = ' '.join(ids_str)
    return ids_str


def ids_id2token(ids_list):
    # map a list of ids to their tokens in the global vocab
    ids_str = ''
    for ids in ids_list:
        ids_str += vocab[ids]
    return ids_str


def language_model(ids_list, vocabulary):
    # score the last word of the prefix by looking it up in the PTB
    # vocabulary: 1.0 if present, 0.0 otherwise
    ptb_vocab_path = "./data/ptb_vocab.txt"
    sentence = ''.join([vocabulary[ids] for ids in ids_list])
    words = sentence.split(' ')
    last_word = words[-1]
    with open(ptb_vocab_path, 'r') as ptb_vocab:
        line = ptb_vocab.readline()
        while line:
            # strip the trailing newline before comparing
            if line.strip() == last_word:
                return 1.0
            line = ptb_vocab.readline()
    return 0.0
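

# A hypothetical cached variant of language_model (an illustrative sketch,
# assuming the same one-word-per-line ./data/ptb_vocab.txt file): it loads
# the word list once into a set instead of re-reading the file on every call.
_ptb_words = None  # lazily filled cache used by cached_language_model


def cached_language_model(ids_list, vocabulary):
    global _ptb_words
    if _ptb_words is None:
        with open("./data/ptb_vocab.txt", 'r') as ptb_vocab:
            _ptb_words = set(line.strip() for line in ptb_vocab)
    sentence = ''.join([vocabulary[ids] for ids in ids_list])
    last_word = sentence.split(' ')[-1]
    return 1.0 if last_word in _ptb_words else 0.0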


def ctc_beam_search_decoder(input_probs_matrix,
                            beam_size,
                            vocabulary,
                            max_time_steps=None,
                            lang_model=language_model,
                            alpha=1.0,
                            beta=1.0,
                            blank_id=0,
                            space_id=1,
                            num_results_per_sample=None):
    '''
    Beam search decoder for a CTC-trained network, adapted from Algorithm 1
    in https://arxiv.org/abs/1408.2873.

    :param input_probs_matrix: probability matrix for the input sequence,
                               row major
    :type input_probs_matrix: 2D matrix
    :param beam_size: width for beam search
    :type beam_size: int
    :param vocabulary: vocabulary list mapping token ids to tokens
    :type vocabulary: list
    :param max_time_steps: maximum number of time steps to decode,
                           <= len(input_probs_matrix)
    :type max_time_steps: int
    :param lang_model: language model for scoring
    :type lang_model: function
    :param alpha: parameter associated with the language model
    :type alpha: float
    :param beta: parameter associated with the word count
    :type beta: float
    :param blank_id: id of blank, default 0
    :type blank_id: int
    :param space_id: id of space, default 1
    :type space_id: int
    :param num_results_per_sample: number of decoding results to output
                                   per given sample, <= beam_size
    :type num_results_per_sample: int
    '''

    # convert a prefix encoded as a string of ids back to a list
    def ids_str2list(ids_str):
        ids_str = ids_str.split(' ')
        ids_list = [int(elem) for elem in ids_str]
        return ids_list

    # count words in a character list
    # (note: word_count and beta are defined but not yet applied to the
    # scores in this prototype)
    def word_count(ids_list):
        cnt = 0
        for elem in ids_list:
            if elem == space_id:
                cnt += 1
        return cnt

    if num_results_per_sample is None:
        num_results_per_sample = beam_size
    assert num_results_per_sample <= beam_size

    if max_time_steps is None:
        max_time_steps = len(input_probs_matrix)
    else:
        max_time_steps = min(max_time_steps, len(input_probs_matrix))
    assert max_time_steps > 0

    vocab_dim = len(input_probs_matrix[0])
    assert blank_id < vocab_dim
    assert space_id < vocab_dim

    ## initialize
    start_id = -1
    # the set containing selected prefixes
    prefix_set_prev = {str(start_id): 1.0}
    probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0}
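
    # The loop below implements the prefix probability recursions from
    # Hannun et al. (2014). With prob[c] the network output for token c at
    # the current step, and probs_b/probs_nb holding the previous step's
    # blank/non-blank prefix probabilities:
    #   probs_b_cur[l]    += prob[blank] * (probs_b[l] + probs_nb[l])
    #   probs_nb_cur[l+c] += prob[c] * probs_b[l]                  if c ends l
    #   probs_nb_cur[l]   += prob[c] * probs_nb[l]                 if c ends l
    #   probs_nb_cur[l+c] += prob[c] * (probs_b[l] + probs_nb[l])  otherwise
    # with an extra lang_model(l, vocabulary)**alpha factor when c is space.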

    ## extend prefixes in a loop over time steps
    for time_step in range(max_time_steps):
        # the set containing candidate prefixes
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for l in prefix_set_prev:
            prob = input_probs_matrix[time_step]

            # convert the prefix from an id string back to a list
            ids_list = ids_str2list(l)
            end_id = ids_list[-1]
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend the prefix by traversing the vocabulary
            for c in range(0, vocab_dim):
                if c == blank_id:
                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
                else:
                    l_plus = l + ' ' + str(c)
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if c == end_id:
                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
                        probs_nb_cur[l] += prob[c] * probs_nb[l]
                    elif c == space_id:
                        lm = 1.0 if lang_model is None else \
                            np.power(lang_model(ids_list, vocabulary), alpha)
                        probs_nb_cur[l_plus] += lm * prob[c] * (
                            probs_b[l] + probs_nb[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (
                            probs_b[l] + probs_nb[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[l_plus] + \
                        probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
            probs_nb_cur)

        ## store the top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda entry: entry[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for (seq, prob) in prefix_set_prev.items():
        if prob > 0.0:
            # drop the leading start_id before mapping ids to tokens
            ids_list = ids_str2list(seq)[1:]
            result = ''.join([vocabulary[ids] for ids in ids_list])
            log_prob = np.log(prob)
            beam_result.append([log_prob, result])

    ## output the top num_results_per_sample decoding results
    beam_result = sorted(beam_result, key=lambda entry: entry[0], reverse=True)
    if num_results_per_sample < beam_size:
        beam_result = beam_result[:num_results_per_sample]
    return beam_result


def simple_test():

    input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]]

    beam_result = ctc_beam_search_decoder(
        input_probs_matrix=input_probs_matrix,
        beam_size=20,
        vocabulary=vocab,
        lang_model=None,  # disable the LM so the test has no file dependency
        blank_id=0,
        space_id=1, )

    print("\nbeam search output:")
    for result in beam_result:
        # result[1] is already a token string, so print it directly
        print("%6f\t%s" % (result[0], result[1]))


if __name__ == '__main__':
    simple_test()

@@ -1,69 +0,0 @@
from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

import ctc_beam_search_decoder as tested_decoder


def test_beam_search_decoder():
    max_time_steps = 6
    beam_size = 20
    num_results_per_sample = 20

    input_prob_matrix_0 = np.asarray(
        [
            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
            # random entry added in at time=5
            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
        ],
        dtype=np.float32)

    # adding an arbitrary offset here would be fine, since decoding is
    # invariant to a constant shift of the log probabilities
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)  #+ 2.0

    # len max_time_steps array of batch_size x depth matrices
    inputs = ([
        input_log_prob_matrix_0[t, :][np.newaxis, :]
        for t in range(max_time_steps)
    ])

    inputs_t = [ops.convert_to_tensor(x) for x in inputs]
    inputs_t = array_ops.stack(inputs_t)
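
    # inputs_t now has shape [max_time_steps, batch_size, depth], the
    # time-major layout that tf.nn.ctc_beam_search_decoder expects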

    # run the CTC beam search decoder in TensorFlow
    with tf.Session() as sess:
        decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(
            inputs_t, [max_time_steps],
            beam_width=beam_size,
            top_paths=num_results_per_sample,
            merge_repeated=False)
        tf_decoded = sess.run(decoded)
        tf_log_probs = sess.run(log_probabilities)

    # run the tested CTC beam search decoder; token ids are rendered as
    # digit strings, since any vocabulary works for this comparison
    beam_result = tested_decoder.ctc_beam_search_decoder(
        input_probs_matrix=input_prob_matrix_0,
        beam_size=beam_size,
        vocabulary=[str(c) for c in range(6)],
        lang_model=None,  # no LM, matching the tensorflow decoder
        blank_id=5,  # default blank_id in the tensorflow decoder is (num classes-1)
        space_id=4,  # doesn't matter with the LM disabled
        max_time_steps=max_time_steps,
        num_results_per_sample=num_results_per_sample)

    # compare decoding results
    print("{tf_decoder log probs} \t {tested_decoder log probs}: "
          "{tf_decoder result} {tested_decoder result}")
    for index in range(len(beam_result)):
        print(('%6f\t%6f: ') % (tf_log_probs[0][index], beam_result[index][0]),
              tf_decoded[index].values, ' ', beam_result[index][1])


if __name__ == '__main__':
    test_beam_search_decoder()