diff --git a/ctc_beam_search_decoder.py b/ctc_beam_search_decoder.py
index 3223c0c25..f66d545aa 100644
--- a/ctc_beam_search_decoder.py
+++ b/ctc_beam_search_decoder.py
@@ -23,10 +23,26 @@ def ids_id2token(ids_list):
     return ids_str
 
 
+def language_model(ids_list, vocabulary):
+    # score the last decoded word: 1.0 if it is in the PTB vocabulary, else 0.0
+    ptb_vocab_path = "./data/ptb_vocab.txt"
+    sentence = ''.join([vocabulary[ids] for ids in ids_list])
+    words = sentence.split(' ')
+    last_word = words[-1]
+    with open(ptb_vocab_path, 'r') as ptb_vocab:
+        line = ptb_vocab.readline()
+        while line:
+            if line.strip() == last_word:
+                return 1.0
+            line = ptb_vocab.readline()
+    return 0.0
+
+
 def ctc_beam_search_decoder(input_probs_matrix,
                             beam_size,
+                            vocabulary,
                             max_time_steps=None,
-                            lang_model=None,
+                            lang_model=language_model,
                             alpha=1.0,
                             beta=1.0,
                             blank_id=0,
@@ -120,7 +136,7 @@
                     probs_nb_cur[l] += prob[c] * probs_nb[l]
                 elif c == space_id:
                     lm = 1.0 if lang_model is None \
-                        else np.power(lang_model(ids_list), alpha)
+                        else np.power(lang_model(ids_list, vocabulary), alpha)
                     probs_nb_cur[l_plus] += lm * prob[c] * (
                         probs_b[l] + probs_nb[l])
                 else:
@@ -145,9 +161,10 @@
     beam_result = []
     for (seq, prob) in prefix_set_prev.items():
         if prob > 0.0:
-            ids_list = ids_str2list(seq)
+            ids_list = ids_str2list(seq)[1:]
+            result = ''.join([vocabulary[ids] for ids in ids_list])
             log_prob = np.log(prob)
-            beam_result.append([log_prob, ids_list[1:]])
+            beam_result.append([log_prob, result])
 
     ## output top beam_size decoding results
     beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
@@ -156,11 +173,6 @@
     return beam_result
 
 
-def language_model(input):
-    # TODO
-    return random.uniform(0, 1)
-
-
 def simple_test():
     input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7],
                           [0.5, 0.2, 0.3]]
diff --git a/decoder.py b/decoder.py
index 7c4b95263..34e1715c3 100755
--- a/decoder.py
+++ b/decoder.py
@@ -4,6 +4,7 @@
 
 from itertools import groupby
 import numpy as np
+from ctc_beam_search_decoder import *
 
 
 def ctc_best_path_decode(probs_seq, vocabulary):
@@ -36,7 +37,11 @@ def ctc_best_path_decode(probs_seq, vocabulary):
     return ''.join([vocabulary[index] for index in index_list])
 
 
-def ctc_decode(probs_seq, vocabulary, method):
+def ctc_decode(probs_seq,
+               vocabulary,
+               method,
+               beam_size=None,
+               num_results_per_sample=None):
     """
     CTC-like sequence decoding from a sequence of likelihood probablilites.
 
@@ -56,5 +61,12 @@ def ctc_decode(probs_seq, vocabulary, method):
         raise ValueError("probs dimension mismatchedd with vocabulary")
     if method == "best_path":
         return ctc_best_path_decode(probs_seq, vocabulary)
+    elif method == "beam_search":
+        return ctc_beam_search_decoder(
+            input_probs_matrix=probs_seq,
+            vocabulary=vocabulary,
+            beam_size=beam_size,
+            blank_id=len(vocabulary),
+            num_results_per_sample=num_results_per_sample)
     else:
-        raise ValueError("Decoding method [%s] is not supported.")
+        raise ValueError("Decoding method [%s] is not supported." % method)
diff --git a/infer.py b/infer.py
index 598c348b0..e5ecf6f35 100644
--- a/infer.py
+++ b/infer.py
@@ -57,6 +57,23 @@ parser.add_argument(
     default='data/eng_vocab.txt',
     type=str,
     help="Vocabulary filepath. (default: %(default)s)")
+parser.add_argument(
+    "--decode_method",
+    default='best_path',
+    type=str,
+    help="Method for CTC decoding: best_path or beam_search. (default: %(default)s)"
+)
+parser.add_argument(
+    "--beam_size",
+    default=50,
+    type=int,
+    help="Width for beam search decoding. (default: %(default)d)")
+parser.add_argument(
+    "--num_result_per_sample",
+    default=2,
+    type=int,
+    help="Number of decoding results to return per sample in beam search. (default: %(default)d)"
+)
 args = parser.parse_args()
@@ -120,12 +137,22 @@ def infer():
 
     # decode and print
     for i, probs in enumerate(probs_split):
-        output_transcription = ctc_decode(
+        best_path_transcription = ctc_decode(
             probs_seq=probs, vocabulary=vocab_list, method="best_path")
         target_transcription = ''.join(
             [vocab_list[index] for index in infer_data[i][1]])
-        print("Target Transcription: %s \nOutput Transcription: %s \n" %
-              (target_transcription, output_transcription))
+        print("\nTarget Transcription: %s \nBest_path Transcription: %s" %
+              (target_transcription, best_path_transcription))
+        beam_search_transcription = ctc_decode(
+            probs_seq=probs,
+            vocabulary=vocab_list,
+            method="beam_search",
+            beam_size=args.beam_size,
+            num_results_per_sample=args.num_result_per_sample)
+        for index in range(len(beam_search_transcription)):
+            print("LM No. %d - %4f: %s " %
+                  (index, beam_search_transcription[index][0],
+                   beam_search_transcription[index][1]))
 
 
 def main():