enable ctc beam search decoder

pull/2/head
Yibing Liu 8 years ago
parent cfe9d22866
commit dedbfb2654

@ -23,10 +23,26 @@ def ids_id2token(ids_list):
return ids_str
def language_model(ids_list, vocabulary):
    """Score a decoded prefix by looking its last word up in the PTB vocabulary.

    The ids are mapped to tokens through ``vocabulary`` and joined into a
    sentence; the sentence is split on single spaces and the final word is
    searched for in ``./data/ptb_vocab.txt``.

    :param ids_list: Sequence of indices into ``vocabulary``.
    :param vocabulary: Indexable collection mapping an id to its token string.
    :return: 1.0 if the last word equals a line of the vocabulary file,
             0.0 otherwise (including when the last word is empty).
    :raises IOError: If the vocabulary file cannot be opened (only attempted
                     when the last word is non-empty).
    """
    # lookup ptb vocabulary
    ptb_vocab_path = "./data/ptb_vocab.txt"
    sentence = ''.join([vocabulary[ids] for ids in ids_list])
    last_word = sentence.split(' ')[-1]
    # An empty word (empty prefix, or a prefix ending in a space) is never a
    # vocabulary entry; bail out so a blank line in the file cannot match it.
    if not last_word:
        return 0.0
    # NOTE: the vocabulary file is re-read on every call.
    with open(ptb_vocab_path, 'r') as ptb_vocab:
        for line in ptb_vocab:
            # Lines read from the file keep their trailing newline; without
            # stripping it the comparison could never succeed.
            if line.strip() == last_word:
                return 1.0
    return 0.0
def ctc_beam_search_decoder(input_probs_matrix,
beam_size,
vocabulary,
max_time_steps=None,
lang_model=None,
lang_model=language_model,
alpha=1.0,
beta=1.0,
blank_id=0,
@ -120,7 +136,7 @@ def ctc_beam_search_decoder(input_probs_matrix,
probs_nb_cur[l] += prob[c] * probs_nb[l]
elif c == space_id:
lm = 1.0 if lang_model is None \
else np.power(lang_model(ids_list), alpha)
else np.power(lang_model(ids_list, vocabulary), alpha)
probs_nb_cur[l_plus] += lm * prob[c] * (
probs_b[l] + probs_nb[l])
else:
@ -145,9 +161,10 @@ def ctc_beam_search_decoder(input_probs_matrix,
beam_result = []
for (seq, prob) in prefix_set_prev.items():
if prob > 0.0:
ids_list = ids_str2list(seq)
ids_list = ids_str2list(seq)[1:]
result = ''.join([vocabulary[ids] for ids in ids_list])
log_prob = np.log(prob)
beam_result.append([log_prob, ids_list[1:]])
beam_result.append([log_prob, result])
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
@ -156,11 +173,6 @@ def ctc_beam_search_decoder(input_probs_matrix,
return beam_result
def language_model(input):
    """Placeholder language model.

    The argument is ignored; the returned score is a float drawn by
    ``random.uniform(0, 1)``.

    :param input: Unused.
    :return: Random float between 0 and 1.
    """
    # TODO
    score = random.uniform(0, 1)
    return score
def simple_test():
input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]]

@ -4,6 +4,7 @@
from itertools import groupby
import numpy as np
from ctc_beam_search_decoder import *
def ctc_best_path_decode(probs_seq, vocabulary):
@ -36,7 +37,11 @@ def ctc_best_path_decode(probs_seq, vocabulary):
return ''.join([vocabulary[index] for index in index_list])
def ctc_decode(probs_seq, vocabulary, method):
def ctc_decode(probs_seq,
vocabulary,
method,
beam_size=None,
num_results_per_sample=None):
"""
CTC-like sequence decoding from a sequence of likelihood probabilities.
@ -56,5 +61,12 @@ def ctc_decode(probs_seq, vocabulary, method):
raise ValueError("probs dimension mismatchedd with vocabulary")
if method == "best_path":
return ctc_best_path_decode(probs_seq, vocabulary)
elif method == "beam_search":
return ctc_beam_search_decoder(
input_probs_matrix=probs_seq,
vocabulary=vocabulary,
beam_size=beam_size,
blank_id=len(vocabulary),
num_results_per_sample=num_results_per_sample)
else:
raise ValueError("Decoding method [%s] is not supported.")
raise ValueError("Decoding method [%s] is not supported." % method)

@ -57,6 +57,23 @@ parser.add_argument(
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
# CTC decoding options: the decoding method, the beam width, and how many
# results to keep per sample in beam search.
parser.add_argument(
    "--decode_method",
    type=str,
    default='best_path',
    help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
)
parser.add_argument(
    "--beam_size",
    type=int,
    default=50,
    help="Width for beam search decoding. (default: %(default)d)"
)
parser.add_argument(
    "--num_result_per_sample",
    type=int,
    default=2,
    help="Number of results per given sample in beam search. (default: %(default)d)"
)
# Parse the command line once all options are registered.
args = parser.parse_args()
@ -120,12 +137,22 @@ def infer():
# decode and print
for i, probs in enumerate(probs_split):
output_transcription = ctc_decode(
best_path_transcription = ctc_decode(
probs_seq=probs, vocabulary=vocab_list, method="best_path")
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
print("Target Transcription: %s \nOutput Transcription: %s \n" %
(target_transcription, output_transcription))
print("\nTarget Transcription: %s \nBst_path Transcription: %s" %
(target_transcription, best_path_transcription))
beam_search_transcription = ctc_decode(
probs_seq=probs,
vocabulary=vocab_list,
method="beam_search",
beam_size=args.beam_size,
num_results_per_sample=args.num_result_per_sample)
for index in range(len(beam_search_transcription)):
print("LM No, %d - %4f: %s " %
(index, beam_search_transcription[index][0],
beam_search_transcription[index][1]))
def main():

Loading…
Cancel
Save