From c943ca79acefc85e605a6e414e90239ee56f98be Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 4 Jun 2017 19:15:09 +0800 Subject: [PATCH 01/28] mv ctc_beam_search_decoder into deep_speech_2/ --- ctc_beam_search_decoder.py | 162 ++++++++++++++++++++++++++++++++ test_ctc_beam_search_decoder.py | 69 ++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 ctc_beam_search_decoder.py create mode 100644 test_ctc_beam_search_decoder.py diff --git a/ctc_beam_search_decoder.py b/ctc_beam_search_decoder.py new file mode 100644 index 00000000..873121b1 --- /dev/null +++ b/ctc_beam_search_decoder.py @@ -0,0 +1,162 @@ +## This is a prototype of ctc beam search decoder + +import copy +import random +import numpy as np + +# vocab = blank + space + English characters +#vocab = ['-', ' '] + [chr(i) for i in range(97, 123)] + +vocab = ['-', '_', 'a'] + + +def ids_str2list(ids_str): + ids_str = ids_str.split(' ') + ids_list = [int(elem) for elem in ids_str] + return ids_list + + +def ids_list2str(ids_list): + ids_str = [str(elem) for elem in ids_list] + ids_str = ' '.join(ids_str) + return ids_str + + +def ids_id2token(ids_list): + ids_str = '' + for ids in ids_list: + ids_str += vocab[ids] + return ids_str + + +def ctc_beam_search_decoder(input_probs_matrix, + beam_size, + max_time_steps=None, + lang_model=None, + alpha=1.0, + beta=1.0, + blank_id=0, + space_id=1, + num_results_per_sample=None): + ''' + beam search decoder for CTC-trained network, called outside of the recurrent group. + adapted from Algorithm 1 in https://arxiv.org/abs/1408.2873. + + param input_probs_matrix: probs matrix for input sequence, row major + type input_probs_matrix: 2D matrix. + param beam_size: width for beam search + type beam_size: int + max_time_steps: maximum steps' number for input sequence, <=len(input_probs_matrix) + type max_time_steps: int + lang_model: language model for scoring + type lang_model: function + + ...... 
+ + ''' + if num_results_per_sample is None: + num_results_per_sample = beam_size + assert num_results_per_sample <= beam_size + + if max_time_steps is None: + max_time_steps = len(input_probs_matrix) + else: + max_time_steps = min(max_time_steps, len(input_probs_matrix)) + assert max_time_steps > 0 + + vocab_dim = len(input_probs_matrix[0]) + assert blank_id < vocab_dim + assert space_id < vocab_dim + + ## initialize + start_id = -1 + # the set containing selected prefixes + prefix_set_prev = {str(start_id): 1.0} + probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0} + + ## extend prefix in loop + for time_step in range(max_time_steps): + # the set containing candidate prefixes + prefix_set_next = {} + probs_b_cur, probs_nb_cur = {}, {} + for l in prefix_set_prev: + prob = input_probs_matrix[time_step] + + # convert ids in string to list + ids_list = ids_str2list(l) + end_id = ids_list[-1] + if not prefix_set_next.has_key(l): + probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0 + + # extend prefix by travering vocabulary + for c in range(0, vocab_dim): + if c == blank_id: + probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l]) + else: + l_plus = l + ' ' + str(c) + if not prefix_set_next.has_key(l_plus): + probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 + + if c == end_id: + probs_nb_cur[l_plus] += prob[c] * probs_b[l] + probs_nb_cur[l] += prob[c] * probs_nb[l] + elif c == space_id: + lm = 1.0 if lang_model is None \ + else np.power(lang_model(ids_list), alpha) + probs_nb_cur[l_plus] += lm * prob[c] * ( + probs_b[l] + probs_nb[l]) + else: + probs_nb_cur[l_plus] += prob[c] * ( + probs_b[l] + probs_nb[l]) + # add l_plus into prefix_set_next + prefix_set_next[l_plus] = probs_nb_cur[ + l_plus] + probs_b_cur[l_plus] + # add l into prefix_set_next + prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] + # update probs + probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy( + probs_nb_cur) + + ## store top beam_size prefixes + prefix_set_prev = sorted( + prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True) + if beam_size < len(prefix_set_prev): + prefix_set_prev = prefix_set_prev[:beam_size] + prefix_set_prev = dict(prefix_set_prev) + + beam_result = [] + for (seq, prob) in prefix_set_prev.items(): + if prob > 0.0: + ids_list = ids_str2list(seq) + log_prob = np.log(prob) + beam_result.append([log_prob, ids_list[1:]]) + + ## output top beam_size decoding results + beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) + if num_results_per_sample < beam_size: + beam_result = beam_result[:num_results_per_sample] + return beam_result + + +def language_model(input): + # TODO + return random.uniform(0, 1) + + +def simple_test(): + + input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]] + + beam_result = ctc_beam_search_decoder( + input_probs_matrix=input_probs_matrix, + beam_size=20, + blank_id=0, + space_id=1, ) + + print "\nbeam search output:" + for result in beam_result: + print("%6f\t%s" % (result[0], ids_id2token(result[1]))) + + +if __name__ == '__main__': + simple_test() diff --git a/test_ctc_beam_search_decoder.py b/test_ctc_beam_search_decoder.py new file mode 100644 index 00000000..f7970444 --- /dev/null +++ b/test_ctc_beam_search_decoder.py @@ -0,0 +1,69 @@ +from __future__ import absolute_import +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +import ctc_beam_search_decoder as tested_decoder + + 
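+# This test cross-checks the prototype decoder against TensorFlow's
+# tf.nn.ctc_beam_search_decoder (run with merge_repeated=False) on the
+# same probability matrix, printing both outputs side by side.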
+def test_beam_search_decoder(): + max_time_steps = 6 + beam_size = 20 + num_results_per_sample = 20 + + input_prob_matrix_0 = np.asarray( + [ + [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908], + [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517], + [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763], + [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655], + [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878], + # Random entry added in at time=5 + [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671] + ], + dtype=np.float32) + + # Add arbitrary offset - this is fine + input_log_prob_matrix_0 = np.log(input_prob_matrix_0) #+ 2.0 + + # len max_time_steps array of batch_size x depth matrices + inputs = ([ + input_log_prob_matrix_0[t, :][np.newaxis, :] + for t in range(max_time_steps) + ]) + + inputs_t = [ops.convert_to_tensor(x) for x in inputs] + inputs_t = array_ops.stack(inputs_t) + + # run CTC beam search decoder in tensorflow + with tf.Session() as sess: + decoded, log_probabilities = tf.nn.ctc_beam_search_decoder( + inputs_t, [max_time_steps], + beam_width=beam_size, + top_paths=num_results_per_sample, + merge_repeated=False) + tf_decoded = sess.run(decoded) + tf_log_probs = sess.run(log_probabilities) + + # run tested CTC beam search decoder + beam_result = tested_decoder.ctc_beam_search_decoder( + input_probs_matrix=input_prob_matrix_0, + beam_size=beam_size, + blank_id=5, # default blank_id in tensorflow decoder is (num classes-1) + space_id=4, # doesn't matter + max_time_steps=max_time_steps, + num_results_per_sample=num_results_per_sample) + + # compare decoding result + print( + "{tf_decoder log probs} \t {tested_decoder log probs}: {tf_decoder result} {tested_decoder result}" + ) + for index in range(len(beam_result)): + print(('%6f\t%6f: ') % (tf_log_probs[0][index], beam_result[index][0]), + tf_decoded[index].values, ' ', beam_result[index][1]) + + +if __name__ == '__main__': + test_beam_search_decoder() From cfe9d22866e4e94802f25033c6217dee8f509c6a Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 4 Jun 2017 19:19:36 +0800 Subject: [PATCH 02/28] update annotations --- ctc_beam_search_decoder.py | 54 +++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/ctc_beam_search_decoder.py b/ctc_beam_search_decoder.py index 873121b1..3223c0c2 100644 --- a/ctc_beam_search_decoder.py +++ b/ctc_beam_search_decoder.py @@ -10,12 +10,6 @@ import numpy as np vocab = ['-', '_', 'a'] -def ids_str2list(ids_str): - ids_str = ids_str.split(' ') - ids_list = [int(elem) for elem in ids_str] - return ids_list - - def ids_list2str(ids_list): ids_str = [str(elem) for elem in ids_list] ids_str = ' '.join(ids_str) @@ -39,21 +33,45 @@ def ctc_beam_search_decoder(input_probs_matrix, space_id=1, num_results_per_sample=None): ''' - beam search decoder for CTC-trained network, called outside of the recurrent group. - adapted from Algorithm 1 in https://arxiv.org/abs/1408.2873. + Beam search decoder for CTC-trained network, adapted from Algorithm 1 + in https://arxiv.org/abs/1408.2873. + + :param input_probs_matrix: probs matrix for input sequence, row major + :type input_probs_matrix: 2D matrix. 
+ :param beam_size: width for beam search + :type beam_size: int + :max_time_steps: maximum steps' number for input sequence, + <=len(input_probs_matrix) + :type max_time_steps: int + :lang_model: language model for scoring + :type lang_model: function + :param alpha: parameter associated with language model. + :type alpha: float + :param beta: parameter associated with word count + :type beta: float + :param blank_id: id of blank, default 0. + :type blank_id: int + :param space_id: id of space, default 1. + :type space_id: int + :param num_result_per_sample: the number of output decoding results + per given sample, <=beam_size. + :type num_result_per_sample: int + ''' - param input_probs_matrix: probs matrix for input sequence, row major - type input_probs_matrix: 2D matrix. - param beam_size: width for beam search - type beam_size: int - max_time_steps: maximum steps' number for input sequence, <=len(input_probs_matrix) - type max_time_steps: int - lang_model: language model for scoring - type lang_model: function + # function to convert ids in string to list + def ids_str2list(ids_str): + ids_str = ids_str.split(' ') + ids_list = [int(elem) for elem in ids_str] + return ids_list - ...... + # counting words in a character list + def word_count(ids_list): + cnt = 0 + for elem in ids_list: + if elem == space_id: + cnt += 1 + return cnt - ''' if num_results_per_sample is None: num_results_per_sample = beam_size assert num_results_per_sample <= beam_size From dedbfb2654254e6c45b32221c6c6a09c2de09f9a Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 5 Jun 2017 08:48:54 +0800 Subject: [PATCH 03/28] enable ctc beam search decoder --- ctc_beam_search_decoder.py | 30 +++++++++++++++++++++--------- decoder.py | 16 ++++++++++++++-- infer.py | 33 ++++++++++++++++++++++++++++++--- 3 files changed, 65 insertions(+), 14 deletions(-) diff --git a/ctc_beam_search_decoder.py b/ctc_beam_search_decoder.py index 3223c0c2..f66d545a 100644 --- a/ctc_beam_search_decoder.py +++ b/ctc_beam_search_decoder.py @@ -23,10 +23,26 @@ def ids_id2token(ids_list): return ids_str +def language_model(ids_list, vocabulary): + # lookup ptb vocabulary + ptb_vocab_path = "./data/ptb_vocab.txt" + sentence = ''.join([vocabulary[ids] for ids in ids_list]) + words = sentence.split(' ') + last_word = words[-1] + with open(ptb_vocab_path, 'r') as ptb_vocab: + f = ptb_vocab.readline() + while f: + if f == last_word: + return 1.0 + f = ptb_vocab.readline() + return 0.0 + + def ctc_beam_search_decoder(input_probs_matrix, beam_size, + vocabulary, max_time_steps=None, - lang_model=None, + lang_model=language_model, alpha=1.0, beta=1.0, blank_id=0, @@ -120,7 +136,7 @@ def ctc_beam_search_decoder(input_probs_matrix, probs_nb_cur[l] += prob[c] * probs_nb[l] elif c == space_id: lm = 1.0 if lang_model is None \ - else np.power(lang_model(ids_list), alpha) + else np.power(lang_model(ids_list, vocabulary), alpha) probs_nb_cur[l_plus] += lm * prob[c] * ( probs_b[l] + probs_nb[l]) else: @@ -145,9 +161,10 @@ def ctc_beam_search_decoder(input_probs_matrix, beam_result = [] for (seq, prob) in prefix_set_prev.items(): if prob > 0.0: - ids_list = ids_str2list(seq) + ids_list = ids_str2list(seq)[1:] + result = ''.join([vocabulary[ids] for ids in ids_list]) log_prob = np.log(prob) - beam_result.append([log_prob, ids_list[1:]]) + beam_result.append([log_prob, result]) ## output top beam_size decoding results beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) @@ -156,11 +173,6 @@ def ctc_beam_search_decoder(input_probs_matrix, return 
beam_result
 
 
-def language_model(input):
-    # TODO
-    return random.uniform(0, 1)
-
-
 def simple_test():
 
     input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]]
diff --git a/decoder.py b/decoder.py
index 7c4b9526..34e1715c 100755
--- a/decoder.py
+++ b/decoder.py
@@ -4,6 +4,7 @@
 from itertools import groupby
 import numpy as np
+from ctc_beam_search_decoder import *
 
 
 def ctc_best_path_decode(probs_seq, vocabulary):
@@ -36,7 +37,11 @@ def ctc_best_path_decode(probs_seq, vocabulary):
     return ''.join([vocabulary[index] for index in index_list])
 
 
-def ctc_decode(probs_seq, vocabulary, method):
+def ctc_decode(probs_seq,
+               vocabulary,
+               method,
+               beam_size=None,
+               num_results_per_sample=None):
     """
     CTC-like sequence decoding from a sequence of likelihood probabilities.
 
@@ -56,5 +61,12 @@ def ctc_decode(probs_seq, vocabulary, method):
         raise ValueError("probs dimension mismatched with vocabulary")
     if method == "best_path":
         return ctc_best_path_decode(probs_seq, vocabulary)
+    elif method == "beam_search":
+        return ctc_beam_search_decoder(
+            input_probs_matrix=probs_seq,
+            vocabulary=vocabulary,
+            beam_size=beam_size,
+            blank_id=len(vocabulary),
+            num_results_per_sample=num_results_per_sample)
     else:
-        raise ValueError("Decoding method [%s] is not supported.")
+        raise ValueError("Decoding method [%s] is not supported." % method)
diff --git a/infer.py b/infer.py
index 598c348b..e5ecf6f3 100644
--- a/infer.py
+++ b/infer.py
@@ -57,6 +57,23 @@ parser.add_argument(
     default='data/eng_vocab.txt',
     type=str,
     help="Vocabulary filepath. (default: %(default)s)")
+parser.add_argument(
+    "--decode_method",
+    default='best_path',
+    type=str,
+    help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
+)
+parser.add_argument(
+    "--beam_size",
+    default=50,
+    type=int,
+    help="Width for beam search decoding. (default: %(default)d)")
+parser.add_argument(
+    "--num_result_per_sample",
+    default=2,
+    type=int,
+    help="Number of results per given sample in beam search. 
(default: %(default)d)" +) args = parser.parse_args() @@ -120,12 +137,22 @@ def infer(): # decode and print for i, probs in enumerate(probs_split): - output_transcription = ctc_decode( + best_path_transcription = ctc_decode( probs_seq=probs, vocabulary=vocab_list, method="best_path") target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) - print("Target Transcription: %s \nOutput Transcription: %s \n" % - (target_transcription, output_transcription)) + print("\nTarget Transcription: %s \nBst_path Transcription: %s" % + (target_transcription, best_path_transcription)) + beam_search_transcription = ctc_decode( + probs_seq=probs, + vocabulary=vocab_list, + method="beam_search", + beam_size=args.beam_size, + num_results_per_sample=args.num_result_per_sample) + for index in range(len(beam_search_transcription)): + print("LM No, %d - %4f: %s " % + (index, beam_search_transcription[index][0], + beam_search_transcription[index][1])) def main(): From 51f35a53723779f042498e14786abb791d278c50 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 7 Jun 2017 02:21:35 +0800 Subject: [PATCH 04/28] code clean & add external scorer --- ctc_beam_search_decoder.py | 192 -------------------------------- decoder.py | 184 +++++++++++++++++++++++++----- infer.py | 72 ++++++++---- test_ctc_beam_search_decoder.py | 69 ------------ 4 files changed, 205 insertions(+), 312 deletions(-) delete mode 100644 ctc_beam_search_decoder.py delete mode 100644 test_ctc_beam_search_decoder.py diff --git a/ctc_beam_search_decoder.py b/ctc_beam_search_decoder.py deleted file mode 100644 index f66d545a..00000000 --- a/ctc_beam_search_decoder.py +++ /dev/null @@ -1,192 +0,0 @@ -## This is a prototype of ctc beam search decoder - -import copy -import random -import numpy as np - -# vocab = blank + space + English characters -#vocab = ['-', ' '] + [chr(i) for i in range(97, 123)] - -vocab = ['-', '_', 'a'] - - -def ids_list2str(ids_list): - ids_str = [str(elem) for elem in ids_list] - ids_str = ' '.join(ids_str) - return ids_str - - -def ids_id2token(ids_list): - ids_str = '' - for ids in ids_list: - ids_str += vocab[ids] - return ids_str - - -def language_model(ids_list, vocabulary): - # lookup ptb vocabulary - ptb_vocab_path = "./data/ptb_vocab.txt" - sentence = ''.join([vocabulary[ids] for ids in ids_list]) - words = sentence.split(' ') - last_word = words[-1] - with open(ptb_vocab_path, 'r') as ptb_vocab: - f = ptb_vocab.readline() - while f: - if f == last_word: - return 1.0 - f = ptb_vocab.readline() - return 0.0 - - -def ctc_beam_search_decoder(input_probs_matrix, - beam_size, - vocabulary, - max_time_steps=None, - lang_model=language_model, - alpha=1.0, - beta=1.0, - blank_id=0, - space_id=1, - num_results_per_sample=None): - ''' - Beam search decoder for CTC-trained network, adapted from Algorithm 1 - in https://arxiv.org/abs/1408.2873. - - :param input_probs_matrix: probs matrix for input sequence, row major - :type input_probs_matrix: 2D matrix. - :param beam_size: width for beam search - :type beam_size: int - :max_time_steps: maximum steps' number for input sequence, - <=len(input_probs_matrix) - :type max_time_steps: int - :lang_model: language model for scoring - :type lang_model: function - :param alpha: parameter associated with language model. - :type alpha: float - :param beta: parameter associated with word count - :type beta: float - :param blank_id: id of blank, default 0. - :type blank_id: int - :param space_id: id of space, default 1. 
- :type space_id: int - :param num_result_per_sample: the number of output decoding results - per given sample, <=beam_size. - :type num_result_per_sample: int - ''' - - # function to convert ids in string to list - def ids_str2list(ids_str): - ids_str = ids_str.split(' ') - ids_list = [int(elem) for elem in ids_str] - return ids_list - - # counting words in a character list - def word_count(ids_list): - cnt = 0 - for elem in ids_list: - if elem == space_id: - cnt += 1 - return cnt - - if num_results_per_sample is None: - num_results_per_sample = beam_size - assert num_results_per_sample <= beam_size - - if max_time_steps is None: - max_time_steps = len(input_probs_matrix) - else: - max_time_steps = min(max_time_steps, len(input_probs_matrix)) - assert max_time_steps > 0 - - vocab_dim = len(input_probs_matrix[0]) - assert blank_id < vocab_dim - assert space_id < vocab_dim - - ## initialize - start_id = -1 - # the set containing selected prefixes - prefix_set_prev = {str(start_id): 1.0} - probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0} - - ## extend prefix in loop - for time_step in range(max_time_steps): - # the set containing candidate prefixes - prefix_set_next = {} - probs_b_cur, probs_nb_cur = {}, {} - for l in prefix_set_prev: - prob = input_probs_matrix[time_step] - - # convert ids in string to list - ids_list = ids_str2list(l) - end_id = ids_list[-1] - if not prefix_set_next.has_key(l): - probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0 - - # extend prefix by travering vocabulary - for c in range(0, vocab_dim): - if c == blank_id: - probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l]) - else: - l_plus = l + ' ' + str(c) - if not prefix_set_next.has_key(l_plus): - probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 - - if c == end_id: - probs_nb_cur[l_plus] += prob[c] * probs_b[l] - probs_nb_cur[l] += prob[c] * probs_nb[l] - elif c == space_id: - lm = 1.0 if lang_model is None \ - else np.power(lang_model(ids_list, vocabulary), alpha) - probs_nb_cur[l_plus] += lm * prob[c] * ( - probs_b[l] + probs_nb[l]) - else: - probs_nb_cur[l_plus] += prob[c] * ( - probs_b[l] + probs_nb[l]) - # add l_plus into prefix_set_next - prefix_set_next[l_plus] = probs_nb_cur[ - l_plus] + probs_b_cur[l_plus] - # add l into prefix_set_next - prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] - # update probs - probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy( - probs_nb_cur) - - ## store top beam_size prefixes - prefix_set_prev = sorted( - prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True) - if beam_size < len(prefix_set_prev): - prefix_set_prev = prefix_set_prev[:beam_size] - prefix_set_prev = dict(prefix_set_prev) - - beam_result = [] - for (seq, prob) in prefix_set_prev.items(): - if prob > 0.0: - ids_list = ids_str2list(seq)[1:] - result = ''.join([vocabulary[ids] for ids in ids_list]) - log_prob = np.log(prob) - beam_result.append([log_prob, result]) - - ## output top beam_size decoding results - beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) - if num_results_per_sample < beam_size: - beam_result = beam_result[:num_results_per_sample] - return beam_result - - -def simple_test(): - - input_probs_matrix = [[0.1, 0.3, 0.6], [0.2, 0.1, 0.7], [0.5, 0.2, 0.3]] - - beam_result = ctc_beam_search_decoder( - input_probs_matrix=input_probs_matrix, - beam_size=20, - blank_id=0, - space_id=1, ) - - print "\nbeam search output:" - for result in beam_result: - print("%6f\t%s" % (result[0], ids_id2token(result[1]))) - - -if __name__ == '__main__': - 
simple_test()
diff --git a/decoder.py b/decoder.py
index 34e1715c..91dbfc34 100755
--- a/decoder.py
+++ b/decoder.py
@@ -4,7 +4,8 @@
 from itertools import groupby
 import numpy as np
-from ctc_beam_search_decoder import *
+import copy
+import kenlm
 
 
 def ctc_best_path_decode(probs_seq, vocabulary):
@@ -37,36 +38,165 @@ def ctc_best_path_decode(probs_seq, vocabulary):
     return ''.join([vocabulary[index] for index in index_list])
 
 
-def ctc_decode(probs_seq,
-               vocabulary,
-               method,
-               beam_size=None,
-               num_results_per_sample=None):
+class Scorer(object):
     """
-    CTC-like sequence decoding from a sequence of likelihood probabilities.
+    Externally defined scorer to evaluate a sentence in beam search
+    decoding, consisting of language model and word count.
 
-    :param probs_seq: 2-D list of probabilities over the vocabulary for each
-                      character. Each element is a list of float probabilities
-                      for one character.
-    :type probs_seq: list
+    :param alpha: Parameter associated with language model.
+    :type alpha: float
+    :param beta: Parameter associated with word count.
+    :type beta: float
+    :param model_path: Path to load language model.
+    :type model_path: basestring
+    """
+
+    def __init__(self, alpha, beta, model_path):
+
+        self._alpha = alpha
+        self._beta = beta
+        self._language_model = kenlm.LanguageModel(model_path)
+
+    def language_model_score(self, sentence, bos=True, eos=False):
+        log_prob = self._language_model.score(sentence, bos, eos)
+        return np.power(10, log_prob)
+
+    def word_count(self, sentence):
+        words = sentence.strip().split(' ')
+        return len(words)
+
+    # execute evaluation
+    def evaluate(self, sentence, bos=True, eos=False):
+        lm = self.language_model_score(sentence, bos, eos)
+        word_count = self.word_count(sentence)
+        score = np.power(lm, self._alpha) \
+                * np.power(word_count, self._beta)
+        return score
+
+
+def ctc_beam_search_decoder(probs_seq,
+                            beam_size,
+                            vocabulary,
+                            ext_scoring_func=None,
+                            blank_id=0):
+    '''
+    Beam search decoder for CTC-trained network, using beam search with width
+    beam_size to find many paths to one label, return beam_size labels in
+    the order of probabilities. The implementation is based on Prefix Beam
+    Search(https://arxiv.org/abs/1408.2873), and the unclear parts are
+    redesigned and need to be verified.
+
+    :param probs_seq: 2-D list with length max_time_steps, each element
+                      is a list of normalized probabilities over vocabulary
+                      and blank for one time step.
+    :type probs_seq: 2-D list
+    :param beam_size: Width for beam search.
+    :type beam_size: int
+    :param vocabulary: Vocabulary list.
+    :type vocabulary: list
+    :param ext_scoring_func: Externally defined scoring function for
+                             partially decoded sentence, e.g. word count
+                             and language model.
+    :type ext_scoring_func: function
+    :param blank_id: id of blank, default 0.
+    :type blank_id: int
+    :return: Decoding log probability and result string.
+    :rtype: list
+
+    '''
+
     for prob_list in probs_seq:
         if not len(prob_list) == len(vocabulary) + 1:
             raise ValueError("probs dimension mismatched with vocabulary")
+
+    max_time_steps = len(probs_seq)
+    if not max_time_steps > 0:
+        raise ValueError("probs_seq shouldn't be empty")
+
+    probs_dim = len(probs_seq[0])
+    if not blank_id < probs_dim:
+        raise ValueError("blank_id shouldn't be greater than probs dimension")
+
+    if ' ' not in vocabulary:
+        raise ValueError("space doesn't exist in vocabulary")
+    space_id = vocabulary.index(' ')
+
+    # function to convert ids in string to list
+    def ids_str2list(ids_str):
+        ids_str = ids_str.split(' ')
+        ids_list = [int(elem) for elem in ids_str]
+        return ids_list
+
+    # function to convert ids list to sentence
+    def ids2sentence(ids_list, vocab):
+        return ''.join([vocab[ids] for ids in ids_list])
+
+    ## initialize
+    # the set containing selected prefixes
+    prefix_set_prev = {'-1': 1.0}
+    probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0}
+
+    ## extend prefix in loop
+    for time_step in range(max_time_steps):
+        # the set containing candidate prefixes
+        prefix_set_next = {}
+        probs_b_cur, probs_nb_cur = {}, {}
+        for l in prefix_set_prev:
+            prob = probs_seq[time_step]
+
+            # convert ids in string to list
+            ids_list = ids_str2list(l)
+            end_id = ids_list[-1]
+            if not prefix_set_next.has_key(l):
+                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
+
+            # extend prefix by traversing vocabulary
+            for c in range(0, probs_dim):
+                if c == blank_id:
+                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
+                else:
+                    l_plus = l + ' ' + str(c)
+                    if not prefix_set_next.has_key(l_plus):
+                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
+
+                    if c == end_id:
+                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
+                        probs_nb_cur[l] += prob[c] * probs_nb[l]
+                    elif c == space_id:
+                        if ext_scoring_func is None:
+                            score = 1.0
+                        else:
+                            prefix_sent = ids2sentence(ids_list, vocabulary)
+                            score = ext_scoring_func(prefix_sent)
+                        probs_nb_cur[l_plus] += score * prob[c] * (
+                            probs_b[l] + probs_nb[l])
+                    else:
+                        probs_nb_cur[l_plus] += prob[c] * (
+                            probs_b[l] + probs_nb[l])
+                    # add l_plus into prefix_set_next
+                    prefix_set_next[l_plus] = probs_nb_cur[
+                        l_plus] + probs_b_cur[l_plus]
+            # add l into prefix_set_next
+            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
+        # update probs
+        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
+            probs_nb_cur)
+
+        ## store top beam_size prefixes
+        prefix_set_prev = sorted(
+            prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
+        if beam_size < len(prefix_set_prev):
+            prefix_set_prev = prefix_set_prev[:beam_size]
+        prefix_set_prev = dict(prefix_set_prev)
+
+    beam_result = []
+    for (seq, prob) in prefix_set_prev.items():
+        if prob > 0.0:
+            ids_list = ids_str2list(seq)[1:]
+            result = ids2sentence(ids_list, vocabulary)
+            log_prob = np.log(prob)
+            beam_result.append([log_prob, result])
+
+    ## output top beam_size decoding results
+    beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
+    return beam_result
diff --git a/infer.py b/infer.py
index e5ecf6f3..dc46b83e 100644
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,7 @@
 import argparse
 import gzip
 from 
audio_data_utils import DataGenerator from model import deep_speech2 -from decoder import ctc_decode +from decoder import * parser = argparse.ArgumentParser( description='Simplified version of DeepSpeech2 inference.') @@ -59,7 +59,7 @@ parser.add_argument( help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( "--decode_method", - default='best_path', + default='beam_search', type=str, help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" ) @@ -69,11 +69,25 @@ parser.add_argument( type=int, help="Width for beam search decoding. (default: %(default)d)") parser.add_argument( - "--num_result_per_sample", - default=2, + "--num_results_per_sample", + default=1, type=int, - help="Number of results per given sample in beam search. (default: %(default)d)" -) + help="Number of output per sample in beam search. (default: %(default)d)") +parser.add_argument( + "--language_model_path", + default="./data/1Billion.klm", + type=str, + help="Path for language model. (default: %(default)d)") +parser.add_argument( + "--alpha", + default=0.0, + type=float, + help="Parameter associated with language model. (default: %(default)f)") +parser.add_argument( + "--beta", + default=0.0, + type=float, + help="Parameter associated with word count. (default: %(default)f)") args = parser.parse_args() @@ -135,24 +149,34 @@ def infer(): for i in xrange(0, len(infer_data)) ] - # decode and print - for i, probs in enumerate(probs_split): - best_path_transcription = ctc_decode( - probs_seq=probs, vocabulary=vocab_list, method="best_path") - target_transcription = ''.join( - [vocab_list[index] for index in infer_data[i][1]]) - print("\nTarget Transcription: %s \nBst_path Transcription: %s" % - (target_transcription, best_path_transcription)) - beam_search_transcription = ctc_decode( - probs_seq=probs, - vocabulary=vocab_list, - method="beam_search", - beam_size=args.beam_size, - num_results_per_sample=args.num_result_per_sample) - for index in range(len(beam_search_transcription)): - print("LM No, %d - %4f: %s " % - (index, beam_search_transcription[index][0], - beam_search_transcription[index][1])) + ## decode and print + # best path decode + if args.decode_method == "best_path": + for i, probs in enumerate(probs_split): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + best_path_transcription = ctc_best_path_decode( + probs_seq=probs, vocabulary=vocab_list) + print("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target_transcription, best_path_transcription)) + # beam search decode + elif args.decode_method == "beam_search": + for i, probs in enumerate(probs_split): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + beam_search_result = ctc_beam_search_decoder( + probs_seq=probs, + vocabulary=vocab_list, + beam_size=args.beam_size, + ext_scoring_func=ext_scorer.evaluate, + blank_id=len(vocab_list)) + print("\nTarget Transcription:\t%s" % target_transcription) + for index in range(args.num_results_per_sample): + result = beam_search_result[index] + print("Beam %d: %f \t%s" % (index, result[0], result[1])) + else: + raise ValueError("Decoding method [%s] is not supported." 
% args.decode_method)
 
 
 def main():
diff --git a/test_ctc_beam_search_decoder.py b/test_ctc_beam_search_decoder.py
deleted file mode 100644
index f7970444..00000000
--- a/test_ctc_beam_search_decoder.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from __future__ import absolute_import
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-import ctc_beam_search_decoder as tested_decoder
-
-
-def test_beam_search_decoder():
-    max_time_steps = 6
-    beam_size = 20
-    num_results_per_sample = 20
-
-    input_prob_matrix_0 = np.asarray(
-        [
-            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
-            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
-            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
-            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
-            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
-            # Random entry added in at time=5
-            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
-        ],
-        dtype=np.float32)
-
-    # Add arbitrary offset - this is fine
-    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)  #+ 2.0
-
-    # len max_time_steps array of batch_size x depth matrices
-    inputs = ([
-        input_log_prob_matrix_0[t, :][np.newaxis, :]
-        for t in range(max_time_steps)
-    ])
-
-    inputs_t = [ops.convert_to_tensor(x) for x in inputs]
-    inputs_t = array_ops.stack(inputs_t)
-
-    # run CTC beam search decoder in tensorflow
-    with tf.Session() as sess:
-        decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(
-            inputs_t, [max_time_steps],
-            beam_width=beam_size,
-            top_paths=num_results_per_sample,
-            merge_repeated=False)
-        tf_decoded = sess.run(decoded)
-        tf_log_probs = sess.run(log_probabilities)
-
-    # run tested CTC beam search decoder
-    beam_result = tested_decoder.ctc_beam_search_decoder(
-        input_probs_matrix=input_prob_matrix_0,
-        beam_size=beam_size,
-        blank_id=5,  # default blank_id in tensorflow decoder is (num classes-1)
-        space_id=4,  # doesn't matter
-        max_time_steps=max_time_steps,
-        num_results_per_sample=num_results_per_sample)
-
-    # compare decoding result
-    print(
-        "{tf_decoder log probs} \t {tested_decoder log probs}: {tf_decoder result} {tested_decoder result}"
-    )
-    for index in range(len(beam_result)):
-        print(('%6f\t%6f: ') % (tf_log_probs[0][index], beam_result[index][0]),
-              tf_decoded[index].values, ' ', beam_result[index][1])
-
-
-if __name__ == '__main__':
-    test_beam_search_decoder()
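A minimal usage sketch of the decoder API introduced in this patch (the probabilities are toy values and the KenLM model path is illustrative; without a language model, pass ext_scoring_func=None):

    from decoder import Scorer, ctc_beam_search_decoder

    vocab = ['a', 'b', ' ']
    # each row: probabilities for ['a', 'b', ' ', <blank>] at one time step
    probs_seq = [[0.1, 0.2, 0.3, 0.4],
                 [0.3, 0.2, 0.1, 0.4],
                 [0.5, 0.2, 0.2, 0.1]]
    # ext_scorer = Scorer(alpha=0.5, beta=0.5, model_path='./data/lm.klm')
    for log_prob, sentence in ctc_beam_search_decoder(
            probs_seq=probs_seq,
            beam_size=5,
            vocabulary=vocab,
            ext_scoring_func=None,  # or ext_scorer.evaluate to enable the LM
            blank_id=len(vocab)):
        print("%f\t%s" % (log_prob, sentence))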
From ac370eca850825cc3cd075f47903722e2805fc5a Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 7 Jun 2017 09:06:58 +0800
Subject: [PATCH 05/28] add annotations

---
 decoder.py | 11 +++++------
 infer.py   |  5 +++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/decoder.py b/decoder.py
index 91dbfc34..e16d1054 100755
--- a/decoder.py
+++ b/decoder.py
@@ -68,9 +68,9 @@ class Scorer(object):
     # execute evaluation
     def evaluate(self, sentence, bos=True, eos=False):
         lm = self.language_model_score(sentence, bos, eos)
-        word_count = self.word_count(sentence)
+        word_cnt = self.word_count(sentence)
         score = np.power(lm, self._alpha) \
-                * np.power(word_count, self._beta)
+                * np.power(word_cnt, self._beta)
         return score
 
 
@@ -104,19 +104,18 @@ def ctc_beam_search_decoder(probs_seq,
     :rtype: list
 
     '''
-
+    # dimension check
     for prob_list in probs_seq:
         if not len(prob_list) == len(vocabulary) + 1:
             raise ValueError("probs dimension mismatched with vocabulary")
-    max_time_steps = len(probs_seq)
-    if not max_time_steps > 0:
-        raise ValueError("probs_seq shouldn't be empty")
+    max_time_steps = len(probs_seq)
 
+    # blank_id check
     probs_dim = len(probs_seq[0])
     if not blank_id < probs_dim:
         raise ValueError("blank_id shouldn't be greater than probs dimension")
 
+    # assign space_id
     if ' ' not in vocabulary:
         raise ValueError("space doesn't exist in vocabulary")
     space_id = vocabulary.index(' ')
diff --git a/infer.py b/infer.py
index dc46b83e..be7ecad9 100644
--- a/infer.py
+++ b/infer.py
@@ -77,7 +77,7 @@ parser.add_argument(
     "--language_model_path",
     default="./data/1Billion.klm",
     type=str,
-    help="Path for language model. (default: %(default)d)")
+    help="Path for language model. (default: %(default)s)")
 parser.add_argument(
     "--alpha",
     default=0.0,
@@ -93,7 +93,7 @@ args = parser.parse_args()
 
 def infer():
     """
-    Max-ctc-decoding for DeepSpeech2.
+    Inference for DeepSpeech2.
     """
     # initialize data generator
     data_generator = DataGenerator(
@@ -174,6 +174,7 @@ def infer():
         print("\nTarget Transcription:\t%s" % target_transcription)
         for index in range(args.num_results_per_sample):
             result = beam_search_result[index]
+            # output: index, log prob, beam result
             print("Beam %d: %f \t%s" % (index, result[0], result[1]))
     else:
         raise ValueError("Decoding method [%s] is not supported." % args.decode_method)
From 21ff590e6d905c9b8d0bba5159d996b8ba23e599 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 7 Jun 2017 14:57:04 +0800
Subject: [PATCH 06/28] modify language model scoring

---
 decoder.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/decoder.py b/decoder.py
index e16d1054..458cd9ad 100755
--- a/decoder.py
+++ b/decoder.py
@@ -52,13 +52,19 @@ class Scorer(object):
     """
 
     def __init__(self, alpha, beta, model_path):
-
         self._alpha = alpha
         self._beta = beta
         self._language_model = kenlm.LanguageModel(model_path)
 
     def language_model_score(self, sentence, bos=True, eos=False):
-        log_prob = self._language_model.score(sentence, bos, eos)
+        words = sentence.strip().split(' ')
+        length = len(words)
+        if length == 1:
+            log_prob = self._language_model.score(sentence, bos, eos)
+        else:
+            prefix_sent = ' '.join(words[0:length - 1])
+            log_prob = self._language_model.score(sentence, bos, eos) \
+                       - self._language_model.score(prefix_sent, bos, eos)
         return np.power(10, log_prob)
 
     def word_count(self, sentence):
From 44efbed798966f1d57276e5fde3d8541e8fddc48 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 7 Jun 2017 16:59:11 +0800
Subject: [PATCH 07/28] rename variables in decoder

---
 decoder.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/decoder.py b/decoder.py
index 458cd9ad..d5bd72f6 100755
--- a/decoder.py
+++ b/decoder.py
@@ -92,7 +92,7 @@ def ctc_beam_search_decoder(probs_seq,
     Search(https://arxiv.org/abs/1408.2873), and the unclear parts are
     redesigned and need to be verified.
 
-    :param probs_seq: 2-D list with length max_time_steps, each element
+    :param probs_seq: 2-D list with length num_time_steps, each element
                       is a list of normalized probabilities over vocabulary
                       and blank for one time step.
:type probs_seq: 2-D list
@@ -114,7 +114,7 @@ def ctc_beam_search_decoder(probs_seq,
     for prob_list in probs_seq:
         if not len(prob_list) == len(vocabulary) + 1:
             raise ValueError("probs dimension mismatched with vocabulary")
-    max_time_steps = len(probs_seq)
+    num_time_steps = len(probs_seq)
 
     # blank_id check
@@ -139,10 +139,10 @@ def ctc_beam_search_decoder(probs_seq,
     ## initialize
     # the set containing selected prefixes
     prefix_set_prev = {'-1': 1.0}
-    probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0}
+    probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0}
 
     ## extend prefix in loop
-    for time_step in range(max_time_steps):
+    for time_step in range(num_time_steps):
         # the set containing candidate prefixes
         prefix_set_next = {}
         probs_b_cur, probs_nb_cur = {}, {}
@@ -158,33 +158,34 @@ def ctc_beam_search_decoder(probs_seq,
             # extend prefix by traversing vocabulary
             for c in range(0, probs_dim):
                 if c == blank_id:
-                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
+                    probs_b_cur[l] += prob[c] * (
+                        probs_b_prev[l] + probs_nb_prev[l])
                 else:
                     l_plus = l + ' ' + str(c)
                     if not prefix_set_next.has_key(l_plus):
                         probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
 
                     if c == end_id:
-                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
-                        probs_nb_cur[l] += prob[c] * probs_nb[l]
+                        probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
+                        probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
                     elif c == space_id:
                         if ext_scoring_func is None:
                             score = 1.0
                         else:
-                            prefix_sent = ids2sentence(ids_list, vocabulary)
-                            score = ext_scoring_func(prefix_sent)
+                            prefix = ids2sentence(ids_list, vocabulary)
+                            score = ext_scoring_func(prefix)
                         probs_nb_cur[l_plus] += score * prob[c] * (
-                            probs_b[l] + probs_nb[l])
+                            probs_b_prev[l] + probs_nb_prev[l])
                     else:
                         probs_nb_cur[l_plus] += prob[c] * (
-                            probs_b[l] + probs_nb[l])
+                            probs_b_prev[l] + probs_nb_prev[l])
                     # add l_plus into prefix_set_next
                     prefix_set_next[l_plus] = probs_nb_cur[
                         l_plus] + probs_b_cur[l_plus]
             # add l into prefix_set_next
             prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
         # update probs
-        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
+        probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy(
             probs_nb_cur)
 
         ## store top beam_size prefixes
From b046e651e7d41a8332fc49096383d5777f2dc2c2 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 7 Jun 2017 17:43:12 +0800
Subject: [PATCH 08/28] tiny modify to pass CI

---
 decoder.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/decoder.py b/decoder.py
index d5bd72f6..b7ed4045 100755
--- a/decoder.py
+++ b/decoder.py
@@ -56,6 +56,7 @@ class Scorer(object):
         self._beta = beta
         self._language_model = kenlm.LanguageModel(model_path)
 
+    # language model scoring
     def language_model_score(self, sentence, bos=True, eos=False):
         words = sentence.strip().split(' ')
         length = len(words)
@@ -67,6 +68,7 @@ class Scorer(object):
                    - self._language_model.score(prefix_sent, bos, eos)
         return np.power(10, log_prob)
 
+    # word insertion term
    def word_count(self, sentence):
         words = sentence.strip().split(' ')
         return len(words)
From 9fda521ee3e067291560c7f4816d0540d808fb22 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 7 Jun 2017 19:24:04 +0800
Subject: [PATCH 09/28] improve external scorer

---
 decoder.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/decoder.py b/decoder.py
index b7ed4045..05400d1b 100755
--- a/decoder.py
+++ b/decoder.py
@@ -6,6 +6,7 @@
 from itertools import groupby
 import numpy as np
 import copy
 import kenlm
+import os
 
 
 def 
ctc_best_path_decode(probs_seq, vocabulary):
@@ -54,19 +55,16 @@ class Scorer(object):
     def __init__(self, alpha, beta, model_path):
         self._alpha = alpha
         self._beta = beta
+        if not os.path.isfile(model_path):
+            raise IOError("Invalid language model path: %s" % model_path)
         self._language_model = kenlm.LanguageModel(model_path)
 
-    # language model scoring
-    def language_model_score(self, sentence, bos=True, eos=False):
-        words = sentence.strip().split(' ')
-        length = len(words)
-        if length == 1:
-            log_prob = self._language_model.score(sentence, bos, eos)
-        else:
-            prefix_sent = ' '.join(words[0:length - 1])
-            log_prob = self._language_model.score(sentence, bos, eos) \
-                       - self._language_model.score(prefix_sent, bos, eos)
-        return np.power(10, log_prob)
+    # n-gram language model scoring
+    def language_model_score(self, sentence):
+        # log prob of last word
+        log_cond_prob = list(
+            self._language_model.full_scores(sentence, eos=False))[-1][0]
+        return np.power(10, log_cond_prob)
 
     # word insertion term
     def word_count(self, sentence):
         words = sentence.strip().split(' ')
         return len(words)
 
     # execute evaluation
-    def evaluate(self, sentence, bos=True, eos=False):
-        lm = self.language_model_score(sentence, bos, eos)
+    def evaluate(self, sentence):
+        lm = self.language_model_score(sentence)
         word_cnt = self.word_count(sentence)
         score = np.power(lm, self._alpha) \
                 * np.power(word_cnt, self._beta)
         return score
From 453f038df91fc56ea24ff09e85def14194f32ee7 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Thu, 8 Jun 2017 16:05:40 +0800
Subject: [PATCH 10/28] optimize the efficiency of beam search

---
 decoder.py | 41 +++++++++++------------------------------
 1 file changed, 11 insertions(+), 30 deletions(-)

diff --git a/decoder.py b/decoder.py
index 05400d1b..0eab3651 100755
--- a/decoder.py
+++ b/decoder.py
@@ -121,25 +121,10 @@ def ctc_beam_search_decoder(probs_seq,
     if not blank_id < probs_dim:
         raise ValueError("blank_id shouldn't be greater than probs dimension")
 
-    # assign space_id
-    if ' ' not in vocabulary:
-        raise ValueError("space doesn't exist in vocabulary")
-    space_id = vocabulary.index(' ')
-
-    # function to convert ids in string to list
-    def ids_str2list(ids_str):
-        ids_str = ids_str.split(' ')
-        ids_list = [int(elem) for elem in ids_str]
-        return ids_list
-
-    # function to convert ids list to sentence
-    def ids2sentence(ids_list, vocab):
-        return ''.join([vocab[ids] for ids in ids_list])
-
     ## initialize
     # the set containing selected prefixes
-    prefix_set_prev = {'-1': 1.0}
-    probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0}
+    prefix_set_prev = {'\t': 1.0}
+    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
 
     ## extend prefix in loop
     for time_step in range(num_time_steps):
@@ -148,10 +133,6 @@ def ctc_beam_search_decoder(probs_seq,
         probs_b_cur, probs_nb_cur = {}, {}
         for l in prefix_set_prev:
             prob = probs_seq[time_step]
-
-            # convert ids in string to list
-            ids_list = ids_str2list(l)
-            end_id = ids_list[-1]
             if not prefix_set_next.has_key(l):
                 probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
 
@@ -161,18 +142,20 @@ def ctc_beam_search_decoder(probs_seq,
                     probs_b_cur[l] += prob[c] * (
                         probs_b_prev[l] + probs_nb_prev[l])
                 else:
-                    l_plus = l + ' ' + str(c)
+                    last_char = l[-1]
+                    new_char = vocabulary[c]
+                    l_plus = l + new_char
                     if not prefix_set_next.has_key(l_plus):
                         probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
 
-                    if c == end_id:
+                    if new_char == last_char:
                         probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
                         probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
-                    elif c == space_id:
-                        if ext_scoring_func is None:
+                    elif new_char == ' ':
+                        if (ext_scoring_func is None) or (len(l) == 1):
                             score = 1.0
                         else:
-                            prefix = ids2sentence(ids_list, vocabulary)
+                            prefix = l[1:]
                             score = ext_scoring_func(prefix)
                         probs_nb_cur[l_plus] += score * prob[c] * (
                             probs_b_prev[l] + probs_nb_prev[l])
                     else:
                         probs_nb_cur[l_plus] += prob[c] * (
                             probs_b_prev[l] + probs_nb_prev[l])
                     # add l_plus into prefix_set_next
                     prefix_set_next[l_plus] = probs_nb_cur[
                         l_plus] + probs_b_cur[l_plus]
             # add l into prefix_set_next
             prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
         # update probs
-        probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy(
-            probs_nb_cur)
+        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur
 
         ## store top beam_size prefixes
         prefix_set_prev = sorted(
@@ -198,8 +180,7 @@ def ctc_beam_search_decoder(probs_seq,
     beam_result = []
     for (seq, prob) in prefix_set_prev.items():
         if prob > 0.0:
-            ids_list = ids_str2list(seq)[1:]
-            result = ids2sentence(ids_list, vocabulary)
+            result = seq[1:]
             log_prob = np.log(prob)
             beam_result.append([log_prob, result])
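A one-step trace of the updated prefix extension, with illustrative numbers: take vocabulary ['a'], blank_id = 1 and a single frame with P('a') = 0.6 and P(blank) = 0.4. Starting from the sentinel prefix '\t' (probs_b_prev = 1.0, probs_nb_prev = 0.0):

    prefix '\t'  : probs_b_cur  = 0.4 * (1.0 + 0.0) = 0.4   (blank branch)
    prefix '\ta' : probs_nb_cur = 0.6 * (1.0 + 0.0) = 0.6   (new character)

Both prefixes enter prefix_set_next; after the leading '\t' is stripped on output, the top result is 'a' with log probability log(0.6).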
From ae83a25affafda71f004538b72309c5043f6667b Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 12 Jun 2017 17:13:48 +0800
Subject: [PATCH 11/28] add beam search decoder using multiprocesses

---
 decoder.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 infer.py   | 43 ++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 93 insertions(+), 4 deletions(-)

diff --git a/decoder.py b/decoder.py
index 0eab3651..fc746c70 100755
--- a/decoder.py
+++ b/decoder.py
@@ -2,11 +2,12 @@
 CTC-like decoder utilities.
 """
 
+import os
 from itertools import groupby
 import numpy as np
 import copy
 import kenlm
-import os
+import multiprocessing
 
 
 def ctc_best_path_decode(probs_seq, vocabulary):
@@ -187,3 +188,54 @@ def ctc_beam_search_decoder(probs_seq,
     ## output top beam_size decoding results
     beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
     return beam_result
+
+
+def ctc_beam_search_decoder_nproc(probs_split,
+                                  beam_size,
+                                  vocabulary,
+                                  ext_scoring_func=None,
+                                  blank_id=0,
+                                  num_processes=None):
+    '''
+    Beam search decoder using multiple processes.
+
+    :param probs_seq: 3-D list with length num_time_steps, each element
+                      is a 2-D list of probabilities can be used by
+                      ctc_beam_search_decoder.
+
+    :type probs_seq: 3-D list
+    :param beam_size: Width for beam search.
+    :type beam_size: int
+    :param vocabulary: Vocabulary list.
+    :type vocabulary: list
+    :param ext_scoring_func: External defined scoring function for
+                             partially decoded sentence, e.g. word count
+                             and language model.
+    :type external_scoring_function: function
+    :param blank_id: id of blank, default 0.
+    :type blank_id: int
+    :param num_processes: Number of processes, default None, equal to the
+                          number of CPUs.
+    :type num_processes: int
+    :return: Decoding log probability and result string.
+ :rtype: list + + ''' + + if num_processes is None: + num_processes = multiprocessing.cpu_count() + if not num_processes > 0: + raise ValueError("Number of processes must be positive!") + + pool = multiprocessing.Pool(processes=num_processes) + results = [] + for i, probs_list in enumerate(probs_split): + args = (probs_list, beam_size, vocabulary, ext_scoring_func, blank_id) + results.append(pool.apply_async(ctc_beam_search_decoder, args)) + + pool.close() + pool.join() + beam_search_results = [] + for result in results: + beam_search_results.append(result.get()) + return beam_search_results diff --git a/infer.py b/infer.py index be7ecad9..377aeb73 100644 --- a/infer.py +++ b/infer.py @@ -9,6 +9,7 @@ import gzip from audio_data_utils import DataGenerator from model import deep_speech2 from decoder import * +from error_rate import wer parser = argparse.ArgumentParser( description='Simplified version of DeepSpeech2 inference.') @@ -59,9 +60,9 @@ parser.add_argument( help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( "--decode_method", - default='beam_search', + default='beam_search_nproc', type=str, - help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" + help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)" ) parser.add_argument( "--beam_size", @@ -151,6 +152,7 @@ def infer(): ## decode and print # best path decode + wer_sum, wer_counter = 0, 0 if args.decode_method == "best_path": for i, probs in enumerate(probs_split): target_transcription = ''.join( @@ -159,12 +161,17 @@ def infer(): probs_seq=probs, vocabulary=vocab_list) print("\nTarget Transcription: %s\nOutput Transcription: %s" % (target_transcription, best_path_transcription)) + wer_cur = wer(target_transcription, best_path_transcription) + wer_sum += wer_cur + wer_counter += 1 + print("cur wer = %f, average wer = %f" % + (wer_cur, wer_sum / wer_counter)) # beam search decode elif args.decode_method == "beam_search": + ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) for i, probs in enumerate(probs_split): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) beam_search_result = ctc_beam_search_decoder( probs_seq=probs, vocabulary=vocab_list, @@ -172,10 +179,40 @@ def infer(): ext_scoring_func=ext_scorer.evaluate, blank_id=len(vocab_list)) print("\nTarget Transcription:\t%s" % target_transcription) + + for index in range(args.num_results_per_sample): + result = beam_search_result[index] + #output: index, log prob, beam result + print("Beam %d: %f \t%s" % (index, result[0], result[1])) + wer_cur = wer(target_transcription, beam_search_result[0][1]) + wer_sum += wer_cur + wer_counter += 1 + print("cur wer = %f , average wer = %f" % + (wer_cur, wer_sum / wer_counter)) + # beam search in multiple processes + elif args.decode_method == "beam_search_nproc": + ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + beam_search_nproc_results = ctc_beam_search_decoder_nproc( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=args.beam_size, + #ext_scoring_func=ext_scorer.evaluate, + ext_scoring_func=None, + blank_id=len(vocab_list)) + for i, beam_search_result in enumerate(beam_search_nproc_results): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + print("\nTarget Transcription:\t%s" % target_transcription) + for index in range(args.num_results_per_sample): 
result = beam_search_result[index]
                 # output: index, log prob, beam result
                 print("Beam %d: %f \t%s" % (index, result[0], result[1]))
+            wer_cur = wer(target_transcription, beam_search_result[0][1])
+            wer_sum += wer_cur
+            wer_counter += 1
+            print("cur wer = %f , average wer = %f" %
+                  (wer_cur, wer_sum / wer_counter))
     else:
         raise ValueError("Decoding method [%s] is not supported." % args.decode_method)
From bb34e90398b71fca0c1e9ff88ab21e069db001ba Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 12 Jun 2017 17:20:22 +0800
Subject: [PATCH 12/28] correct typos in annotations

---
 decoder.py | 2 +-
 infer.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/decoder.py b/decoder.py
index fc746c70..96e91181 100755
--- a/decoder.py
+++ b/decoder.py
@@ -199,7 +199,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
     '''
     Beam search decoder using multiple processes.
 
-    :param probs_seq: 3-D list with length num_time_steps, each element
+    :param probs_seq: 3-D list with length batch_size, each element
                       is a 2-D list of probabilities can be used by
                       ctc_beam_search_decoder.
diff --git a/infer.py b/infer.py
index 377aeb73..0be89e61 100644
--- a/infer.py
+++ b/infer.py
@@ -189,7 +189,7 @@ def infer():
             wer_counter += 1
             print("cur wer = %f , average wer = %f" %
                   (wer_cur, wer_sum / wer_counter))
-    # beam search in multiple processes
+    # beam search using multiple processes
     elif args.decode_method == "beam_search_nproc":
         ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
         beam_search_nproc_results = ctc_beam_search_decoder_nproc(
From 7db13ca9dbec998d5fff8f69c5fb5ec3d546352f Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Tue, 13 Jun 2017 14:16:54 +0800
Subject: [PATCH 13/28] enable lm in multiprocessing decoder & add script for
 params tuning

---
 decoder.py |  23 ++++--
 infer.py   |   9 ++-
 tune.py    | 234 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 257 insertions(+), 9 deletions(-)
 create mode 100644 tune.py

diff --git a/decoder.py b/decoder.py
index 96e91181..824ac970 100755
--- a/decoder.py
+++ b/decoder.py
@@ -73,7 +73,7 @@ class Scorer(object):
         return len(words)
 
     # execute evaluation
-    def evaluate(self, sentence):
+    def __call__(self, sentence):
         lm = self.language_model_score(sentence)
         word_cnt = self.word_count(sentence)
         score = np.power(lm, self._alpha) \
@@ -84,8 +84,9 @@ def ctc_beam_search_decoder(probs_seq,
 def ctc_beam_search_decoder(probs_seq,
                             beam_size,
                             vocabulary,
+                            blank_id=0,
                             ext_scoring_func=None,
-                            blank_id=0):
+                            nproc=False):
     '''
     Beam search decoder for CTC-trained network, using beam search with width
     beam_size to find many paths to one label, return beam_size labels in
@@ -107,6 +108,8 @@ def ctc_beam_search_decoder(probs_seq,
     :type ext_scoring_func: function
     :param blank_id: id of blank, default 0.
     :type blank_id: int
+    :param nproc: Whether the decoder is used in multiprocessing.
+    :type nproc: bool
     :return: Decoding log probability and result string.
     :rtype: list
@@ -122,6 +125,12 @@ def ctc_beam_search_decoder(probs_seq,
     if not blank_id < probs_dim:
         raise ValueError("blank_id shouldn't be greater than probs dimension")
 
+    # If the decoder is called from multiprocesses, use the global scorer
+    # instantiated in ctc_beam_search_decoder_nproc().
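+    # (the global is set by ctc_beam_search_decoder_nproc() before the process
+    # pool is created, so fork-started workers inherit it and the KenLM scorer
+    # does not have to be pickled for every task)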
+ if nproc is True: + global ext_nproc_scorer + ext_scoring_func = ext_nproc_scorer + ## initialize # the set containing selected prefixes prefix_set_prev = {'\t': 1.0} @@ -193,8 +202,8 @@ def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder_nproc(probs_split, beam_size, vocabulary, - ext_scoring_func=None, blank_id=0, + ext_scoring_func=None, num_processes=None): ''' Beam search decoder using multiple processes. @@ -202,7 +211,6 @@ def ctc_beam_search_decoder_nproc(probs_split, :param probs_seq: 3-D list with length batch_size, each element is a 2-D list of probabilities can be used by ctc_beam_search_decoder. - :type probs_seq: 3-D list :param beam_size: Width for beam search. :type beam_size: int @@ -227,10 +235,15 @@ def ctc_beam_search_decoder_nproc(probs_split, if not num_processes > 0: raise ValueError("Number of processes must be positive!") + # use global variable to pass the externnal scorer to beam search decoder + global ext_nproc_scorer + ext_nproc_scorer = ext_scoring_func + nproc = True + pool = multiprocessing.Pool(processes=num_processes) results = [] for i, probs_list in enumerate(probs_split): - args = (probs_list, beam_size, vocabulary, ext_scoring_func, blank_id) + args = (probs_list, beam_size, vocabulary, blank_id, None, nproc) results.append(pool.apply_async(ctc_beam_search_decoder, args)) pool.close() diff --git a/infer.py b/infer.py index 0be89e61..0bae1312 100644 --- a/infer.py +++ b/infer.py @@ -9,6 +9,7 @@ import gzip from audio_data_utils import DataGenerator from model import deep_speech2 from decoder import * +import kenlm from error_rate import wer parser = argparse.ArgumentParser( @@ -176,7 +177,7 @@ def infer(): probs_seq=probs, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer.evaluate, + ext_scoring_func=ext_scorer, blank_id=len(vocab_list)) print("\nTarget Transcription:\t%s" % target_transcription) @@ -196,9 +197,9 @@ def infer(): probs_split=probs_split, vocabulary=vocab_list, beam_size=args.beam_size, - #ext_scoring_func=ext_scorer.evaluate, - ext_scoring_func=None, - blank_id=len(vocab_list)) + ext_scoring_func=ext_scorer, + blank_id=len(vocab_list), + num_processes=1) for i, beam_search_result in enumerate(beam_search_nproc_results): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) diff --git a/tune.py b/tune.py new file mode 100644 index 00000000..3eb82648 --- /dev/null +++ b/tune.py @@ -0,0 +1,234 @@ +""" + Tune parameters for beam search decoder in Deep Speech 2. +""" + +import paddle.v2 as paddle +import distutils.util +import argparse +import gzip +from audio_data_utils import DataGenerator +from model import deep_speech2 +from decoder import * +from error_rate import wer + +parser = argparse.ArgumentParser( + description='Parameters tuning script for ctc beam search decoder in Deep Speech 2.' +) +parser.add_argument( + "--num_samples", + default=100, + type=int, + help="Number of samples for parameters tuning. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. 
(default: %(default)s)") +parser.add_argument( + "--normalizer_manifest_path", + default='data/manifest.libri.train-clean-100', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--decode_manifest_path", + default='data/manifest.libri.test-100sample', + type=str, + help="Manifest path for decoding. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='./params.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='data/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +parser.add_argument( + "--decode_method", + default='beam_search_nproc', + type=str, + help="Method for decoding, beam_search or beam_search_nproc. (default: %(default)s)" +) +parser.add_argument( + "--beam_size", + default=500, + type=int, + help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--num_results_per_sample", + default=1, + type=int, + help="Number of outputs per sample in beam search. (default: %(default)d)") +parser.add_argument( + "--language_model_path", + default="./data/1Billion.klm", + type=str, + help="Path for language model. (default: %(default)s)") +parser.add_argument( + "--alpha_from", + default=0.0, + type=float, + help="Where alpha starts from, <= alpha_to. (default: %(default)f)") +parser.add_argument( + "--alpha_stride", + default=0.001, + type=float, + help="Step length for varying alpha. (default: %(default)f)") +parser.add_argument( + "--alpha_to", + default=0.01, + type=float, + help="Where alpha ends with, >= alpha_from. (default: %(default)f)") +parser.add_argument( + "--beta_from", + default=0.0, + type=float, + help="Where beta starts from, <= beta_to. (default: %(default)f)") +parser.add_argument( + "--beta_stride", + default=0.01, + type=float, + help="Step length for varying beta. (default: %(default)f)") +parser.add_argument( + "--beta_to", + default=0.0, + type=float, + help="Where beta ends with, >= beta_from. (default: %(default)f)") +args = parser.parse_args() + + +def tune(): + """ + Tune parameters alpha and beta on one minibatch. 
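tune() below recovers per-utterance probability matrices from PaddlePaddle's flattened inference output by even slicing (see num_steps and probs_split further down); this only works because every utterance in the batch is padded to the same width (padding_to=2000), so the number of output rows is an exact multiple of the batch size. A toy sketch of that slicing with made-up shapes, written with explicit integer division (the Python 2 code below relies on plain /):

    import numpy as np

    # stand-in for the flattened network output: 2 utterances of 3 padded
    # time steps each, over 4 output symbols (3 labels + blank)
    batch_size, steps_per_utt, num_symbols = 2, 3, 4
    infer_results = np.random.rand(batch_size * steps_per_utt, num_symbols)
    infer_results /= infer_results.sum(axis=1, keepdims=True)  # normalize rows

    num_steps = len(infer_results) // batch_size
    probs_split = [
        infer_results[i * num_steps:(i + 1) * num_steps]
        for i in range(batch_size)
    ]
    assert len(probs_split) == batch_size
    assert probs_split[0].shape == (steps_per_utt, num_symbols)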
+ """ + + if not args.alpha_from <= args.alpha_to: + raise ValueError("alpha_from <= alpha_to doesn't satisfy!") + if not args.alpha_stride > 0: + raise ValueError("alpha_stride shouldn't be negative!") + + if not args.beta_from <= args.beta_to: + raise ValueError("beta_from <= beta_to doesn't satisfy!") + if not args.beta_stride > 0: + raise ValueError("beta_stride shouldn't be negative!") + + # initialize data generator + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + normalizer_manifest_path=args.normalizer_manifest_path, + normalizer_num_samples=200, + max_duration=20.0, + min_duration=0.0, + stride_ms=10, + window_ms=20) + + # create network config + dict_size = data_generator.vocabulary_size() + vocab_list = data_generator.vocabulary_list() + audio_data = paddle.layer.data( + name="audio_spectrogram", + height=161, + width=2000, + type=paddle.data_type.dense_vector(322000)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(dict_size)) + output_probs = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=dict_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size, + is_inference=True) + + # load parameters + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.model_filepath)) + + # prepare infer data + feeding = data_generator.data_name_feeding() + test_batch_reader = data_generator.batch_reader_creator( + manifest_path=args.decode_manifest_path, + batch_size=args.num_samples, + padding_to=2000, + flatten=True, + sort_by_duration=False, + shuffle=False) + infer_data = test_batch_reader().next() + + # run inference + infer_results = paddle.infer( + output_layer=output_probs, parameters=parameters, input=infer_data) + num_steps = len(infer_results) / len(infer_data) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(0, len(infer_data)) + ] + + cand_alpha = np.arange(args.alpha_from, args.alpha_to + args.alpha_stride, + args.alpha_stride) + cand_beta = np.arange(args.beta_from, args.beta_to + args.beta_stride, + args.beta_stride) + params_grid = [(alpha, beta) for alpha in cand_alpha for beta in cand_beta] + ## tune parameters in loop + for (alpha, beta) in params_grid: + wer_sum, wer_counter = 0, 0 + ext_scorer = Scorer(alpha, beta, args.language_model_path) + # beam search decode + if args.decode_method == "beam_search": + for i, probs in enumerate(probs_split): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + beam_search_result = ctc_beam_search_decoder( + probs_seq=probs, + vocabulary=vocab_list, + beam_size=args.beam_size, + ext_scoring_func=ext_scorer, + blank_id=len(vocab_list)) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 + # beam search using multiple processes + elif args.decode_method == "beam_search_nproc": + beam_search_nproc_results = ctc_beam_search_decoder_nproc( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=args.beam_size, + ext_scoring_func=ext_scorer, + blank_id=len(vocab_list), + num_processes=1) + for i, beam_search_result in enumerate(beam_search_nproc_results): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 + else: + raise ValueError("Decoding method [%s] is not supported." 
% method) + + print("alpha = %f\tbeta = %f\tWER = %f" % + (alpha, beta, wer_sum / wer_counter)) + + +def main(): + paddle.init(use_gpu=args.use_gpu, trainer_count=1) + tune() + + +if __name__ == '__main__': + main() From a633eb9cc6d81ad9e1d9615be281b5678e256faa Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 13 Jun 2017 15:28:43 +0800 Subject: [PATCH 14/28] change two arguments --- infer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/infer.py b/infer.py index 0bae1312..bb9dfa0a 100644 --- a/infer.py +++ b/infer.py @@ -198,8 +198,7 @@ def infer(): vocabulary=vocab_list, beam_size=args.beam_size, ext_scoring_func=ext_scorer, - blank_id=len(vocab_list), - num_processes=1) + blank_id=len(vocab_list)) for i, beam_search_result in enumerate(beam_search_nproc_results): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) From ff01d048d39854abf075a81320bddddcbc62f1f0 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 18 Jun 2017 16:36:52 +0800 Subject: [PATCH 15/28] final refining on old data provider: enable pruning & add evaluation & code cleanup --- decoder.py | 84 +++++++++++++++------ evaluate.py | 214 ++++++++++++++++++++++++++++++++++++++++++++++++++++ infer.py | 40 ++++++---- tune.py | 73 +++++++++--------- 4 files changed, 339 insertions(+), 72 deletions(-) create mode 100644 evaluate.py diff --git a/decoder.py b/decoder.py index 824ac970..2ee89cbd 100755 --- a/decoder.py +++ b/decoder.py @@ -5,7 +5,6 @@ import os from itertools import groupby import numpy as np -import copy import kenlm import multiprocessing @@ -73,11 +72,25 @@ class Scorer(object): return len(words) # execute evaluation - def __call__(self, sentence): + def __call__(self, sentence, log=False): + """ + Evaluation function + + :param sentence: The input sentence for evalutation + :type sentence: basestring + :param log: Whether return the score in log representation. + :type log: bool + :return: Evaluation score, in the decimal or log. + :rtype: float + """ lm = self.language_model_score(sentence) word_cnt = self.word_count(sentence) - score = np.power(lm, self._alpha) \ - * np.power(word_cnt, self._beta) + if log == False: + score = np.power(lm, self._alpha) \ + * np.power(word_cnt, self._beta) + else: + score = self._alpha * np.log(lm) \ + + self._beta * np.log(word_cnt) return score @@ -85,13 +98,14 @@ def ctc_beam_search_decoder(probs_seq, beam_size, vocabulary, blank_id=0, + cutoff_prob=1.0, ext_scoring_func=None, nproc=False): ''' Beam search decoder for CTC-trained network, using beam search with width beam_size to find many paths to one label, return beam_size labels in - the order of probabilities. The implementation is based on Prefix Beam - Search(https://arxiv.org/abs/1408.2873), and the unclear part is + the descending order of probabilities. The implementation is based on Prefix + Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is redesigned, need to be verified. :param probs_seq: 2-D list with length num_time_steps, each element @@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq, :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list + :param blank_id: ID of blank, default 0. + :type blank_id: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float :param ext_scoring_func: External defined scoring function for partially decoded sentence, e.g. word count and language model. 
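The two branches of __call__ compute the same quantity in different domains: the decimal form multiplies directly into a path probability, while the log form is alpha * log(lm) + beta * log(word_cnt), suitable for decoders that add log scores instead. A quick numeric check with made-up values (the alpha and beta here are the tuned defaults used later in this series):

    import numpy as np

    alpha, beta = 0.26, 0.1       # defaults used elsewhere in this series
    lm_prob, word_cnt = 1e-12, 5  # hypothetical LM probability and word count

    decimal_score = np.power(lm_prob, alpha) * np.power(word_cnt, beta)
    log_score = alpha * np.log(lm_prob) + beta * np.log(word_cnt)

    # the two representations agree up to floating point error
    assert np.isclose(np.log(decimal_score), log_score)
    print(decimal_score, log_score)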
:type external_scoring_function: function - :param blank_id: id of blank, default 0. - :type blank_id: int :param nproc: Whether the decoder used in multiprocesses. :type nproc: bool - :return: Decoding log probability and result string. + :return: Decoding log probabilities and result sentences in descending order. :rtype: list ''' # dimension check for prob_list in probs_seq: if not len(prob_list) == len(vocabulary) + 1: - raise ValueError("probs dimension mismatchedd with vocabulary") + raise ValueError("probs dimension mismatched with vocabulary") num_time_steps = len(probs_seq) # blank_id check @@ -137,19 +154,35 @@ def ctc_beam_search_decoder(probs_seq, probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} ## extend prefix in loop - for time_step in range(num_time_steps): + for time_step in xrange(num_time_steps): # the set containing candidate prefixes prefix_set_next = {} probs_b_cur, probs_nb_cur = {}, {} + prob = probs_seq[time_step] + prob_idx = [[i, prob[i]] for i in xrange(len(prob))] + cutoff_len = len(prob_idx) + #If pruning is enabled + if (cutoff_prob < 1.0): + prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True) + cutoff_len = 0 + cum_prob = 0.0 + for i in xrange(len(prob_idx)): + cum_prob += prob_idx[i][1] + cutoff_len += 1 + if cum_prob >= cutoff_prob: + break + prob_idx = prob_idx[0:cutoff_len] + for l in prefix_set_prev: - prob = probs_seq[time_step] if not prefix_set_next.has_key(l): probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0 - # extend prefix by travering vocabulary - for c in range(0, probs_dim): + # extend prefix by travering prob_idx + for index in xrange(cutoff_len): + c, prob_c = prob_idx[index][0], prob_idx[index][1] + if c == blank_id: - probs_b_cur[l] += prob[c] * ( + probs_b_cur[l] += prob_c * ( probs_b_prev[l] + probs_nb_prev[l]) else: last_char = l[-1] @@ -159,18 +192,18 @@ def ctc_beam_search_decoder(probs_seq, probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 if new_char == last_char: - probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l] - probs_nb_cur[l] += prob[c] * probs_nb_prev[l] + probs_nb_cur[l_plus] += prob_c * probs_b_prev[l] + probs_nb_cur[l] += prob_c * probs_nb_prev[l] elif new_char == ' ': if (ext_scoring_func is None) or (len(l) == 1): score = 1.0 else: prefix = l[1:] score = ext_scoring_func(prefix) - probs_nb_cur[l_plus] += score * prob[c] * ( + probs_nb_cur[l_plus] += score * prob_c * ( probs_b_prev[l] + probs_nb_prev[l]) else: - probs_nb_cur[l_plus] += prob[c] * ( + probs_nb_cur[l_plus] += prob_c * ( probs_b_prev[l] + probs_nb_prev[l]) # add l_plus into prefix_set_next prefix_set_next[l_plus] = probs_nb_cur[ @@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split, beam_size, vocabulary, blank_id=0, + cutoff_prob=1.0, ext_scoring_func=None, num_processes=None): ''' @@ -216,16 +250,19 @@ def ctc_beam_search_decoder_nproc(probs_split, :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list + :param blank_id: ID of blank, default 0. + :type blank_id: int + :param cutoff_prob: Cutoff probability in pruning, + default 0, no pruning. + :type cutoff_prob: float :param ext_scoring_func: External defined scoring function for partially decoded sentence, e.g. word count and language model. :type external_scoring_function: function - :param blank_id: id of blank, default 0. - :type blank_id: int :param num_processes: Number of processes, default None, equal to the number of CPUs. :type num_processes: int - :return: Decoding log probability and result string. 
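The pruning that cutoff_prob enables is a top-p truncation of each time step's output distribution: symbols are sorted by descending probability and only the shortest prefix whose cumulative mass reaches cutoff_prob is extended. A standalone sketch of that step:

    def prune_vocab(step_probs, cutoff_prob=1.0):
        """Keep the shortest descending-probability prefix of symbols whose
        cumulative probability reaches cutoff_prob; return (id, prob) pairs."""
        prob_idx = list(enumerate(step_probs))
        if cutoff_prob >= 1.0:  # no pruning
            return prob_idx
        prob_idx.sort(key=lambda pair: pair[1], reverse=True)
        kept, cum_prob = [], 0.0
        for idx, prob in prob_idx:
            kept.append((idx, prob))
            cum_prob += prob
            if cum_prob >= cutoff_prob:
                break
        return kept

    # with cutoff_prob=0.9 only the two most likely symbols survive
    print(prune_vocab([0.6, 0.03, 0.35, 0.02], cutoff_prob=0.9))
    # -> [(0, 0.6), (2, 0.35)]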
+ :return: Decoding log probabilities and result sentences in descending order. :rtype: list ''' @@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split, pool = multiprocessing.Pool(processes=num_processes) results = [] for i, probs_list in enumerate(probs_split): - args = (probs_list, beam_size, vocabulary, blank_id, None, nproc) + args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None, + nproc) results.append(pool.apply_async(ctc_beam_search_decoder, args)) pool.close() diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 00000000..7c05a309 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,214 @@ +""" + Evaluation for a simplifed version of Baidu DeepSpeech2 model. +""" + +import paddle.v2 as paddle +import distutils.util +import argparse +import gzip +from audio_data_utils import DataGenerator +from model import deep_speech2 +from decoder import * +from error_rate import wer + +parser = argparse.ArgumentParser( + description='Simplified version of DeepSpeech2 evaluation.') +parser.add_argument( + "--num_samples", + default=100, + type=int, + help="Number of samples for evaluation. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--decode_method", + default='beam_search_nproc', + type=str, + help="Method for ctc decoding, best_path, " + "beam_search or beam_search_nproc. (default: %(default)s)") +parser.add_argument( + "--language_model_path", + default="./data/1Billion.klm", + type=str, + help="Path for language model. (default: %(default)s)") +parser.add_argument( + "--alpha", + default=0.26, + type=float, + help="Parameter associated with language model. (default: %(default)f)") +parser.add_argument( + "--beta", + default=0.1, + type=float, + help="Parameter associated with word count. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") +parser.add_argument( + "--beam_size", + default=500, + type=int, + help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--normalizer_manifest_path", + default='data/manifest.libri.train-clean-100', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--decode_manifest_path", + default='data/manifest.libri.test-clean', + type=str, + help="Manifest path for decoding. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='./params.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='data/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +args = parser.parse_args() + + +def evaluate(): + """ + Evaluate on whole test data for DeepSpeech2. 
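evaluate() accumulates per-utterance WER into a corpus-level average. The error_rate module itself is not part of this patch series; the usual definition it would implement is word-level Levenshtein (edit) distance normalized by the number of reference words, sketched here under that assumption:

    def word_error_rate(reference, hypothesis):
        """Word-level edit distance divided by the reference word count."""
        ref, hyp = reference.split(), hypothesis.split()
        if not ref:
            raise ValueError("reference must contain at least one word")
        # dp[j] = edit distance between the first i reference words and the
        # first j hypothesis words, updated row by row
        dp = list(range(len(hyp) + 1))
        for i in range(1, len(ref) + 1):
            prev_diag, dp[0] = dp[0], i
            for j in range(1, len(hyp) + 1):
                cost = 0 if ref[i - 1] == hyp[j - 1] else 1
                prev_diag, dp[j] = dp[j], min(
                    dp[j] + 1,         # deletion
                    dp[j - 1] + 1,     # insertion
                    prev_diag + cost)  # substitution
        return float(dp[-1]) / len(ref)

    print(word_error_rate('how are you', 'how were you'))  # -> 0.333333...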
+ """ + # initialize data generator + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + normalizer_manifest_path=args.normalizer_manifest_path, + normalizer_num_samples=200, + max_duration=20.0, + min_duration=0.0, + stride_ms=10, + window_ms=20) + + # create network config + dict_size = data_generator.vocabulary_size() + vocab_list = data_generator.vocabulary_list() + audio_data = paddle.layer.data( + name="audio_spectrogram", + height=161, + width=2000, + type=paddle.data_type.dense_vector(322000)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(dict_size)) + output_probs = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=dict_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size, + is_inference=True) + + # load parameters + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.model_filepath)) + + # prepare infer data + feeding = data_generator.data_name_feeding() + test_batch_reader = data_generator.batch_reader_creator( + manifest_path=args.decode_manifest_path, + batch_size=args.num_samples, + padding_to=2000, + flatten=True, + sort_by_duration=False, + shuffle=False) + + # define inferer + inferer = paddle.inference.Inference( + output_layer=output_probs, parameters=parameters) + + # initialize external scorer for beam search decoding + if args.decode_method == 'beam_search' or \ + args.decode_method == 'beam_search_nproc': + ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + + wer_counter, wer_sum = 0, 0.0 + for infer_data in test_batch_reader(): + # run inference + infer_results = inferer.infer(input=infer_data) + num_steps = len(infer_results) / len(infer_data) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(0, len(infer_data)) + ] + + # decode and print + # best path decode + if args.decode_method == "best_path": + for i, probs in enumerate(probs_split): + output_transcription = ctc_decode( + probs_seq=probs, vocabulary=vocab_list, method="best_path") + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + wer_sum += wer(target_transcription, output_transcription) + wer_counter += 1 + # beam search decode in single process + elif args.decode_method == "beam_search": + for i, probs in enumerate(probs_split): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + beam_search_result = ctc_beam_search_decoder( + probs_seq=probs, + vocabulary=vocab_list, + beam_size=args.beam_size, + blank_id=len(vocab_list), + ext_scoring_func=ext_scorer, + cutoff_prob=args.cutoff_prob, ) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 + # beam search using multiple processes + elif args.decode_method == "beam_search_nproc": + beam_search_nproc_results = ctc_beam_search_decoder_nproc( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=args.beam_size, + blank_id=len(vocab_list), + ext_scoring_func=ext_scorer, + cutoff_prob=args.cutoff_prob, ) + for i, beam_search_result in enumerate(beam_search_nproc_results): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 + else: + raise ValueError("Decoding method [%s] is not supported." 
% args.decode_method)
+
+        print("Cur WER = %f" % (wer_sum / wer_counter))
+    print("Final WER = %f" % (wer_sum / wer_counter))
+
+
+def main():
+    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
+    evaluate()
+
+
+if __name__ == '__main__':
+    main()

diff --git a/infer.py b/infer.py
index bb9dfa0a..64fe1524 100644
--- a/infer.py
+++ b/infer.py
@@ -9,14 +9,14 @@ import gzip
 from audio_data_utils import DataGenerator
 from model import deep_speech2
 from decoder import *
-import kenlm
 from error_rate import wer
+import time
 
 parser = argparse.ArgumentParser(
     description='Simplified version of DeepSpeech2 inference.')
 parser.add_argument(
     "--num_samples",
-    default=10,
+    default=100,
     type=int,
     help="Number of samples for inference. (default: %(default)s)")
 parser.add_argument(
@@ -46,7 +46,7 @@ parser.add_argument(
     help="Manifest path for normalizer. (default: %(default)s)")
 parser.add_argument(
     "--decode_manifest_path",
-    default='data/manifest.libri.test-clean',
+    default='data/manifest.libri.test-100sample',
     type=str,
     help="Manifest path for decoding. (default: %(default)s)")
 parser.add_argument(
@@ -63,11 +63,13 @@ parser.add_argument(
     "--decode_method",
     default='beam_search_nproc',
     type=str,
-    help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)"
-)
+    help="Method for ctc decoding:"
+    " best_path,"
+    " beam_search, "
+    " or beam_search_nproc. (default: %(default)s)")
 parser.add_argument(
     "--beam_size",
-    default=50,
+    default=500,
     type=int,
     help="Width for beam search decoding. (default: %(default)d)")
 parser.add_argument(
@@ -82,14 +84,20 @@ parser.add_argument(
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
     "--alpha",
-    default=0.0,
+    default=0.26,
     type=float,
     help="Parameter associated with language model. (default: %(default)f)")
 parser.add_argument(
     "--beta",
-    default=0.0,
+    default=0.1,
     type=float,
     help="Parameter associated with word count. (default: %(default)f)")
+parser.add_argument(
+    "--cutoff_prob",
+    default=0.99,
+    type=float,
+    help="The cutoff probability of pruning"
+    "in beam search.
(default: %(default)f)") args = parser.parse_args() @@ -154,6 +162,7 @@ def infer(): ## decode and print # best path decode wer_sum, wer_counter = 0, 0 + total_time = 0.0 if args.decode_method == "best_path": for i, probs in enumerate(probs_split): target_transcription = ''.join( @@ -177,11 +186,12 @@ def infer(): probs_seq=probs, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, - blank_id=len(vocab_list)) + blank_id=len(vocab_list), + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) print("\nTarget Transcription:\t%s" % target_transcription) - for index in range(args.num_results_per_sample): + for index in xrange(args.num_results_per_sample): result = beam_search_result[index] #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) @@ -190,21 +200,21 @@ def infer(): wer_counter += 1 print("cur wer = %f , average wer = %f" % (wer_cur, wer_sum / wer_counter)) - # beam search using multiple processes elif args.decode_method == "beam_search_nproc": ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) beam_search_nproc_results = ctc_beam_search_decoder_nproc( probs_split=probs_split, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, - blank_id=len(vocab_list)) + blank_id=len(vocab_list), + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) for i, beam_search_result in enumerate(beam_search_nproc_results): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) print("\nTarget Transcription:\t%s" % target_transcription) - for index in range(args.num_results_per_sample): + for index in xrange(args.num_results_per_sample): result = beam_search_result[index] #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) diff --git a/tune.py b/tune.py index 3eb82648..58a8a0d1 100644 --- a/tune.py +++ b/tune.py @@ -1,5 +1,5 @@ """ - Tune parameters for beam search decoder in Deep Speech 2. + Parameters tuning for beam search decoder in Deep Speech 2. """ import paddle.v2 as paddle @@ -12,7 +12,7 @@ from decoder import * from error_rate import wer parser = argparse.ArgumentParser( - description='Parameters tuning script for ctc beam search decoder in Deep Speech 2.' + description='Parameters tuning for ctc beam search decoder in Deep Speech 2.' ) parser.add_argument( "--num_samples", @@ -82,34 +82,40 @@ parser.add_argument( help="Path for language model. (default: %(default)s)") parser.add_argument( "--alpha_from", - default=0.0, + default=0.1, type=float, - help="Where alpha starts from, <= alpha_to. (default: %(default)f)") + help="Where alpha starts from. (default: %(default)f)") parser.add_argument( - "--alpha_stride", - default=0.001, - type=float, - help="Step length for varying alpha. (default: %(default)f)") + "--num_alphas", + default=14, + type=int, + help="Number of candidate alphas. (default: %(default)d)") parser.add_argument( "--alpha_to", - default=0.01, + default=0.36, type=float, - help="Where alpha ends with, >= alpha_from. (default: %(default)f)") + help="Where alpha ends with. (default: %(default)f)") parser.add_argument( "--beta_from", - default=0.0, + default=0.05, type=float, - help="Where beta starts from, <= beta_to. (default: %(default)f)") + help="Where beta starts from. (default: %(default)f)") parser.add_argument( - "--beta_stride", - default=0.01, + "--num_betas", + default=20, type=float, - help="Step length for varying beta. 
(default: %(default)f)") + help="Number of candidate betas. (default: %(default)d)") parser.add_argument( "--beta_to", - default=0.0, + default=1.0, type=float, - help="Where beta ends with, >= beta_from. (default: %(default)f)") + help="Where beta ends with. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") args = parser.parse_args() @@ -118,15 +124,11 @@ def tune(): Tune parameters alpha and beta on one minibatch. """ - if not args.alpha_from <= args.alpha_to: - raise ValueError("alpha_from <= alpha_to doesn't satisfy!") - if not args.alpha_stride > 0: - raise ValueError("alpha_stride shouldn't be negative!") + if not args.num_alphas >= 0: + raise ValueError("num_alphas must be non-negative!") - if not args.beta_from <= args.beta_to: - raise ValueError("beta_from <= beta_to doesn't satisfy!") - if not args.beta_stride > 0: - raise ValueError("beta_stride shouldn't be negative!") + if not args.num_betas >= 0: + raise ValueError("num_betas must be non-negative!") # initialize data generator data_generator = DataGenerator( @@ -171,6 +173,7 @@ def tune(): flatten=True, sort_by_duration=False, shuffle=False) + # get one batch data for tuning infer_data = test_batch_reader().next() # run inference @@ -182,11 +185,12 @@ def tune(): for i in xrange(0, len(infer_data)) ] - cand_alpha = np.arange(args.alpha_from, args.alpha_to + args.alpha_stride, - args.alpha_stride) - cand_beta = np.arange(args.beta_from, args.beta_to + args.beta_stride, - args.beta_stride) - params_grid = [(alpha, beta) for alpha in cand_alpha for beta in cand_beta] + # create grid for search + cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) + cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas) + params_grid = [(alpha, beta) for alpha in cand_alphas + for beta in cand_betas] + ## tune parameters in loop for (alpha, beta) in params_grid: wer_sum, wer_counter = 0, 0 @@ -200,8 +204,9 @@ def tune(): probs_seq=probs, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, - blank_id=len(vocab_list)) + blank_id=len(vocab_list), + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) wer_sum += wer(target_transcription, beam_search_result[0][1]) wer_counter += 1 # beam search using multiple processes @@ -210,9 +215,9 @@ def tune(): probs_split=probs_split, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, + cutoff_prob=args.cutoff_prob, blank_id=len(vocab_list), - num_processes=1) + ext_scoring_func=ext_scorer, ) for i, beam_search_result in enumerate(beam_search_nproc_results): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) From 36743d36897082289ab678a744d236699fd69ae3 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 18 Jun 2017 18:11:01 +0800 Subject: [PATCH 16/28] add scoring last word in beam search --- decoder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/decoder.py b/decoder.py index 2ee89cbd..37640aff 100755 --- a/decoder.py +++ b/decoder.py @@ -222,8 +222,11 @@ def ctc_beam_search_decoder(probs_seq, beam_result = [] for (seq, prob) in prefix_set_prev.items(): - if prob > 0.0: + if prob > 0.0 and len(seq) > 1: result = seq[1:] + # score last word by external scorer + if (ext_scoring_func is not None) and (result[-1] != ' '): + prob = prob * ext_scoring_func(result) log_prob = np.log(prob) beam_result.append([log_prob, 
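np.linspace(from, to, num) includes both endpoints, so the replacement grid is fully determined by the three arguments, and the Cartesian product enumerates num_alphas * num_betas candidate pairs. With the new defaults:

    import numpy as np

    cand_alphas = np.linspace(0.1, 0.36, 14)  # both endpoints included
    cand_betas = np.linspace(0.05, 1.0, 20)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

    print(len(params_grid))   # -> 280 candidate (alpha, beta) pairs
    print(params_grid[0])     # -> (0.1, 0.05)
    print(params_grid[-1])    # -> (0.36, 1.0)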
result]) From 0729abe02e787762acc0f0b30e4890b554f20d06 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 20 Jun 2017 12:14:24 +0800 Subject: [PATCH 17/28] tiny adjust --- decoder.py | 6 ++---- infer.py | 1 - 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/decoder.py b/decoder.py index 445831aa..a23fa132 100644 --- a/decoder.py +++ b/decoder.py @@ -1,4 +1,4 @@ -"""Contains various CTC decoder.""" +"""Contains various CTC decoders.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -103,7 +103,7 @@ def ctc_beam_search_decoder(probs_seq, beam_size to find many paths to one label, return beam_size labels in the descending order of probabilities. The implementation is based on Prefix Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is - redesigned, need to be verified. + redesigned. :param probs_seq: 2-D list with length num_time_steps, each element is a list of normalized probabilities over vocabulary @@ -262,9 +262,7 @@ def ctc_beam_search_decoder_nproc(probs_split, :type num_processes: int :return: Decoding log probabilities and result sentences in descending order. :rtype: list - ''' - if num_processes is None: num_processes = multiprocessing.cpu_count() if not num_processes > 0: diff --git a/infer.py b/infer.py index 9f6d91ca..4545f3da 100644 --- a/infer.py +++ b/infer.py @@ -151,7 +151,6 @@ def infer(): ## decode and print # best path decode wer_sum, wer_counter = 0, 0 - total_time = 0.0 if args.decode_method == "best_path": for i, probs in enumerate(probs_split): target_transcription = ''.join([ From 803384561501299d01464f847e5ef9d5a6b38685 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 21 Jun 2017 14:52:06 +0800 Subject: [PATCH 18/28] add unit test for decoders --- decoder.py | 55 -------------------------- evaluate.py | 3 +- infer.py | 5 ++- scorer.py | 62 +++++++++++++++++++++++++++++ tests/test_decoders.py | 90 ++++++++++++++++++++++++++++++++++++++++++ tune.py | 3 +- 6 files changed, 159 insertions(+), 59 deletions(-) create mode 100644 scorer.py create mode 100644 tests/test_decoders.py diff --git a/decoder.py b/decoder.py index a23fa132..00659367 100644 --- a/decoder.py +++ b/decoder.py @@ -3,10 +3,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os from itertools import groupby import numpy as np -import kenlm import multiprocessing @@ -39,59 +37,6 @@ def ctc_best_path_decode(probs_seq, vocabulary): return ''.join([vocabulary[index] for index in index_list]) -class Scorer(object): - """External defined scorer to evaluate a sentence in beam search - decoding, consisting of language model and word count. - - :param alpha: Parameter associated with language model. - :type alpha: float - :param beta: Parameter associated with word count. - :type beta: float - :model_path: Path to load language model. 
- :type model_path: basestring - """ - - def __init__(self, alpha, beta, model_path): - self._alpha = alpha - self._beta = beta - if not os.path.isfile(model_path): - raise IOError("Invaid language model path: %s" % model_path) - self._language_model = kenlm.LanguageModel(model_path) - - # n-gram language model scoring - def language_model_score(self, sentence): - #log prob of last word - log_cond_prob = list( - self._language_model.full_scores(sentence, eos=False))[-1][0] - return np.power(10, log_cond_prob) - - # word insertion term - def word_count(self, sentence): - words = sentence.strip().split(' ') - return len(words) - - # execute evaluation - def __call__(self, sentence, log=False): - """Evaluation function, gathering all the scores. - - :param sentence: The input sentence for evalutation - :type sentence: basestring - :param log: Whether return the score in log representation. - :type log: bool - :return: Evaluation score, in the decimal or log. - :rtype: float - """ - lm = self.language_model_score(sentence) - word_cnt = self.word_count(sentence) - if log == False: - score = np.power(lm, self._alpha) \ - * np.power(word_cnt, self._beta) - else: - score = self._alpha * np.log(lm) \ - + self._beta * np.log(word_cnt) - return score - - def ctc_beam_search_decoder(probs_seq, beam_size, vocabulary, diff --git a/evaluate.py b/evaluate.py index dee85cbd..a7b8e221 100644 --- a/evaluate.py +++ b/evaluate.py @@ -10,6 +10,7 @@ import gzip from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * +from scorer import Scorer from error_rate import wer parser = argparse.ArgumentParser(description=__doc__) @@ -51,7 +52,7 @@ parser.add_argument( "beam_search or beam_search_nproc. (default: %(default)s)") parser.add_argument( "--language_model_path", - default="data/1Billion.klm", + default="data/en.00.UNKNOWN.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( diff --git a/infer.py b/infer.py index b4de2b60..ca18569d 100644 --- a/infer.py +++ b/infer.py @@ -11,6 +11,7 @@ import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * +from scorer import Scorer from error_rate import wer import utils @@ -67,7 +68,7 @@ parser.add_argument( help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( "--decode_method", - default='best_path', + default='beam_search_nproc', type=str, help="Method for ctc decoding:" " best_path," @@ -85,7 +86,7 @@ parser.add_argument( help="Number of output per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="data/1Billion.klm", + default="data/en.00.UNKNOWN.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( diff --git a/scorer.py b/scorer.py new file mode 100644 index 00000000..4f468481 --- /dev/null +++ b/scorer.py @@ -0,0 +1,62 @@ +"""External Scorer for Beam Search Decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import kenlm +import numpy as np + + +class Scorer(object): + """External defined scorer to evaluate a sentence in beam search + decoding, consisting of language model and word count. + + :param alpha: Parameter associated with language model. + :type alpha: float + :param beta: Parameter associated with word count. + :type beta: float + :model_path: Path to load language model. 
+ :type model_path: basestring + """ + + def __init__(self, alpha, beta, model_path): + self._alpha = alpha + self._beta = beta + if not os.path.isfile(model_path): + raise IOError("Invaid language model path: %s" % model_path) + self._language_model = kenlm.LanguageModel(model_path) + + # n-gram language model scoring + def language_model_score(self, sentence): + #log10 prob of last word + log_cond_prob = list( + self._language_model.full_scores(sentence, eos=False))[-1][0] + return np.power(10, log_cond_prob) + + # word insertion term + def word_count(self, sentence): + words = sentence.strip().split(' ') + return len(words) + + # execute evaluation + def __call__(self, sentence, log=False): + """Evaluation function, gathering all the different scores + and return the final one. + + :param sentence: The input sentence for evalutation + :type sentence: basestring + :param log: Whether return the score in log representation. + :type log: bool + :return: Evaluation score, in the decimal or log. + :rtype: float + """ + lm = self.language_model_score(sentence) + word_cnt = self.word_count(sentence) + if log == False: + score = np.power(lm, self._alpha) \ + * np.power(word_cnt, self._beta) + else: + score = self._alpha * np.log(lm) \ + + self._beta * np.log(word_cnt) + return score diff --git a/tests/test_decoders.py b/tests/test_decoders.py new file mode 100644 index 00000000..7fa89c5f --- /dev/null +++ b/tests/test_decoders.py @@ -0,0 +1,90 @@ +"""Test decoders.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +from decoder import * + + +class TestDecoders(unittest.TestCase): + def setUp(self): + self.vocab_list = ["\'", ' ', 'a', 'b', 'c', 'd'] + self.beam_size = 20 + self.probs_seq1 = [[ + 0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254, + 0.18184413, 0.16493624 + ], [ + 0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462, + 0.0094893, 0.06890021 + ], [ + 0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535, + 0.08424043, 0.08120984 + ], [ + 0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305, + 0.05206269, 0.09772094 + ], [ + 0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985, + 0.41317442, 0.01946335 + ], [ + 0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937, + 0.04377724, 0.01457421 + ]] + self.probs_seq2 = [[ + 0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441, + 0.04468023, 0.10903471 + ], [ + 0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123, + 0.10219457, 0.20640612 + ], [ + 0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316, + 0.12298585, 0.01654384 + ], [ + 0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055, + 0.22538587, 0.13483174 + ], [ + 0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313, + 0.07113197, 0.04139363 + ], [ + 0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306, + 0.05294827, 0.22298418 + ]] + self.best_path_result = ["ac'bdc", "b'da"] + self.beam_search_result = ['acdc', "b'a"] + + def test_best_path_decoder_1(self): + bst_result = ctc_best_path_decode(self.probs_seq1, self.vocab_list) + self.assertEqual(bst_result, self.best_path_result[0]) + + def test_best_path_decoder_2(self): + bst_result = ctc_best_path_decode(self.probs_seq2, self.vocab_list) + self.assertEqual(bst_result, self.best_path_result[1]) + + def test_beam_search_decoder_1(self): + beam_result = ctc_beam_search_decoder( + probs_seq=self.probs_seq1, + beam_size=self.beam_size, + vocabulary=self.vocab_list, + 
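Splitting the scorer into its own module makes it usable on its own. Note that kenlm's full_scores yields per-word log10 conditional probabilities and only the last word's score is kept, so calling the scorer at each word boundary effectively chains the whole sentence probability. A hypothetical usage sketch (the .klm path is a placeholder and requires a trained KenLM binary model on disk):

    # hypothetical usage; not runnable without a real language model file
    from scorer import Scorer

    scorer = Scorer(alpha=0.26, beta=0.1, model_path='data/en.00.UNKNOWN.klm')

    # decimal score, as consumed by the beam search at word boundaries
    print(scorer('how are you'))
    # the same quantity in the log domain
    print(scorer('how are you', log=True))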
blank_id=len(self.vocab_list)) + self.assertEqual(beam_result[0][1], self.beam_search_result[0]) + + def test_beam_search_decoder_2(self): + beam_result = ctc_beam_search_decoder( + probs_seq=self.probs_seq2, + beam_size=self.beam_size, + vocabulary=self.vocab_list, + blank_id=len(self.vocab_list)) + self.assertEqual(beam_result[0][1], self.beam_search_result[1]) + + def test_beam_search_nproc_decoder(self): + beam_results = ctc_beam_search_decoder_nproc( + probs_split=[self.probs_seq1, self.probs_seq2], + beam_size=self.beam_size, + vocabulary=self.vocab_list, + blank_id=len(self.vocab_list)) + self.assertEqual(beam_results[0][0][1], self.beam_search_result[0]) + self.assertEqual(beam_results[1][0][1], self.beam_search_result[1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tune.py b/tune.py index 7dae1490..02076349 100644 --- a/tune.py +++ b/tune.py @@ -10,6 +10,7 @@ import gzip from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * +from scorer import Scorer from error_rate import wer parser = argparse.ArgumentParser(description=__doc__) @@ -81,7 +82,7 @@ parser.add_argument( help="Number of outputs per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="data/1Billion.klm", + default="data/en.00.UNKNOWN.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( From 26510f74a63307786f83db3f9faa2f579292e1f4 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 27 Jun 2017 17:42:44 +0800 Subject: [PATCH 19/28] refine ctc_beam_search_decoder --- decoder.py | 128 +++++++++++++++++++---------------- evaluate.py | 89 +++++++++++------------- infer.py | 79 ++++++++------------- lm/__init__.py | 0 scorer.py => lm/lm_scorer.py | 21 +++--- lm/run.sh | 3 + requirements.txt | 1 + tests/test_decoders.py | 6 +- tune.py | 89 +++++++++--------------- 9 files changed, 187 insertions(+), 229 deletions(-) create mode 100644 lm/__init__.py rename scorer.py => lm/lm_scorer.py (73%) create mode 100644 lm/run.sh diff --git a/decoder.py b/decoder.py index 00659367..4676b02b 100644 --- a/decoder.py +++ b/decoder.py @@ -8,8 +8,8 @@ import numpy as np import multiprocessing -def ctc_best_path_decode(probs_seq, vocabulary): - """Best path decoding, also called argmax decoding or greedy decoding. +def ctc_best_path_decoder(probs_seq, vocabulary): + """Best path decoder, also called argmax decoder or greedy decoder. Path consisting of the most probable tokens are further post-processed to remove consecutive repetitions and all blanks. @@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary): def ctc_beam_search_decoder(probs_seq, beam_size, vocabulary, - blank_id=0, + blank_id, cutoff_prob=1.0, ext_scoring_func=None, nproc=False): - '''Beam search decoder for CTC-trained network, using beam search with width - beam_size to find many paths to one label, return beam_size labels in - the descending order of probabilities. The implementation is based on Prefix - Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is - redesigned. - - :param probs_seq: 2-D list with length num_time_steps, each element - is a list of normalized probabilities over vocabulary - and blank for one time step. + """Beam search decoder for CTC-trained network. It utilizes beam search + to approximately select top best decoding labels and returning results + in the descending order. 
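The renamed ctc_best_path_decoder is the greedy baseline the tests compare against: take the argmax symbol at every frame, collapse consecutive repeats, then drop blanks. A minimal equivalent using the same itertools.groupby helper that decoder.py imports (a sketch, not the module's exact body):

    import numpy as np
    from itertools import groupby

    def greedy_ctc_decode(probs_seq, vocabulary, blank_id):
        # most probable symbol index at each time step
        max_index_list = list(np.array(probs_seq).argmax(axis=1))
        # collapse consecutive repetitions, then remove blanks
        index_list = [index for index, _ in groupby(max_index_list)]
        return ''.join(vocabulary[index] for index in index_list
                       if index != blank_id)

    vocab = ["'", ' ', 'a', 'b', 'c', 'd']
    probs = [[0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.1],  # 'a'
             [0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.2],  # 'a' again: collapsed
             [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9],  # blank
             [0.0, 0.0, 0.7, 0.0, 0.0, 0.0, 0.3]]  # 'a' after blank: kept
    print(greedy_ctc_decode(probs, vocab, blank_id=len(vocab)))  # -> 'aa'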
The implementation is based on Prefix
+    Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
+    redesigned. Two important modifications: 1) in the iterative computation
+    of probabilities, the assignment operation is changed to accumulation,
+    because one prefix may come from different paths; 2) the if condition
+    "if l^+ not in A_prev then" after the probabilities' computation is
+    removed, since it is hard to understand and seems unnecessary.
+
+    :param probs_seq: 2-D list of probability distributions over each time
+                      step, with each element being a list of normalized
+                      probabilities over vocabulary and blank.
     :type probs_seq: 2-D list
     :param beam_size: Width for beam search.
     :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
-    :param blank_id: ID of blank, default 0.
+    :param blank_id: ID of blank.
     :type blank_id: int
     :param cutoff_prob: Cutoff probability in pruning,
                         default 1.0, no pruning.
     :type cutoff_prob: float
-    :param ext_scoring_func: External defined scoring function for
+    :param ext_scoring_func: External scoring function for
                              partially decoded sentence, e.g. word count
-                             and language model.
-    :type external_scoring_function: function
+                             or language model.
+    :type external_scoring_func: callable
     :param nproc: Whether the decoder used in multiprocesses.
     :type nproc: bool
-    :return: Decoding log probabilities and result sentences in descending order.
+    :return: List of tuples of log probability and sentence as decoding
+             results, in descending order of the probability.
     :rtype: list
-    '''
+    """
     # dimension check
     for prob_list in probs_seq:
         if not len(prob_list) == len(vocabulary) + 1:
-            raise ValueError("probs dimension mismatched with vocabulary")
-    num_time_steps = len(probs_seq)
+            raise ValueError("The shape of prob_seq does not match with the "
+                             "shape of the vocabulary.")
 
     # blank_id check
-    probs_dim = len(probs_seq[0])
-    if not blank_id < probs_dim:
+    if not blank_id < len(probs_seq[0]):
         raise ValueError("blank_id shouldn't be greater than probs dimension")
 
     # If the decoder called in the multiprocesses, then use the global scorer
-    # instantiated in ctc_beam_search_decoder_nproc().
+    # instantiated in ctc_beam_search_decoder_batch().
if nproc is True: global ext_nproc_scorer ext_scoring_func = ext_nproc_scorer ## initialize - # the set containing selected prefixes - prefix_set_prev = {'\t': 1.0} - probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} + # prefix_set_prev: the set containing selected prefixes + # probs_b_prev: prefixes' probability ending with blank in previous step + # probs_nb_prev: prefixes' probability ending with non-blank in previous step + prefix_set_prev, probs_b_prev, probs_nb_prev = { + '\t': 1.0 + }, { + '\t': 1.0 + }, { + '\t': 0.0 + } ## extend prefix in loop - for time_step in xrange(num_time_steps): - # the set containing candidate prefixes - prefix_set_next = {} - probs_b_cur, probs_nb_cur = {}, {} - prob = probs_seq[time_step] - prob_idx = [[i, prob[i]] for i in xrange(len(prob))] + for time_step in xrange(len(probs_seq)): + # prefix_set_next: the set containing candidate prefixes + # probs_b_cur: prefixes' probability ending with blank in current step + # probs_nb_cur: prefixes' probability ending with non-blank in current step + prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {} + + prob_idx = list(enumerate(probs_seq[time_step])) cutoff_len = len(prob_idx) #If pruning is enabled - if (cutoff_prob < 1.0): + if cutoff_prob < 1.0: prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True) - cutoff_len = 0 - cum_prob = 0.0 + cutoff_len, cum_prob = 0, 0.0 for i in xrange(len(prob_idx)): cum_prob += prob_idx[i][1] cutoff_len += 1 @@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq, prefix_set_prev = dict(prefix_set_prev) beam_result = [] - for (seq, prob) in prefix_set_prev.items(): + for seq, prob in prefix_set_prev.items(): if prob > 0.0 and len(seq) > 1: result = seq[1:] # score last word by external scorer if (ext_scoring_func is not None) and (result[-1] != ' '): prob = prob * ext_scoring_func(result) log_prob = np.log(prob) - beam_result.append([log_prob, result]) + beam_result.append((log_prob, result)) ## output top beam_size decoding results beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) return beam_result -def ctc_beam_search_decoder_nproc(probs_split, +def ctc_beam_search_decoder_batch(probs_split, beam_size, vocabulary, - blank_id=0, + blank_id, + num_processes, cutoff_prob=1.0, - ext_scoring_func=None, - num_processes=None): - '''Beam search decoder using multiple processes. + ext_scoring_func=None): + """CTC beam search decoder using multiple processes. - :param probs_seq: 3-D list with length batch_size, each element - is a 2-D list of probabilities can be used by - ctc_beam_search_decoder. + :param probs_seq: 3-D list with each element as an instance of 2-D list + of probabilities used by ctc_beam_search_decoder(). :type probs_seq: 3-D list :param beam_size: Width for beam search. :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list - :param blank_id: ID of blank, default 0. + :param blank_id: ID of blank. :type blank_id: int + :param num_processes: Number of parallel processes. + :type num_processes: int :param cutoff_prob: Cutoff probability in pruning, - default 0, no pruning. + default 1.0, no pruning. + :param num_processes: Number of parallel processes. + :type num_processes: int :type cutoff_prob: float - :param ext_scoring_func: External defined scoring function for + :param ext_scoring_func: External scoring function for partially decoded sentence, e.g. word count - and language model. 
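The accumulation that modification 1) insists on is visible in the loop above: within one time step the same candidate prefix can gain mass from several of the current beam's prefixes, so += is required where the paper's pseudocode assigns. For example, prefix 'ab' is fed both by extending 'a' with 'b' and by 'ab' continuing its trailing run of 'b' (illustrative numbers):

    prob_b = 0.5                    # network probability of emitting 'b'
    # previous-step probabilities, ending in blank / in non-blank
    p_b = {'a': 0.1, 'ab': 0.0}
    p_nb = {'a': 0.02, 'ab': 0.06}

    probs_nb_cur = {'ab': 0.0}
    probs_nb_cur['ab'] += prob_b * (p_b['a'] + p_nb['a'])  # from 'a' + 'b'
    probs_nb_cur['ab'] += prob_b * p_nb['ab']              # trailing 'b' repeated
    print(probs_nb_cur['ab'])  # -> 0.09; assignment would keep only 0.03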
- :type external_scoring_function: function - :param num_processes: Number of processes, default None, equal to the - number of CPUs. - :type num_processes: int - :return: Decoding log probabilities and result sentences in descending order. + or language model. + :type external_scoring_function: callable + :return: List of tuples of log probability and sentence as decoding + results, in descending order of the probability. :rtype: list - ''' - if num_processes is None: - num_processes = multiprocessing.cpu_count() + """ if not num_processes > 0: raise ValueError("Number of processes must be positive!") @@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split, pool.close() pool.join() - beam_search_results = [] - for result in results: - beam_search_results.append(result.get()) + beam_search_results = [result.get() for result in results] return beam_search_results diff --git a/evaluate.py b/evaluate.py index a7b8e221..7ef32ad1 100644 --- a/evaluate.py +++ b/evaluate.py @@ -3,22 +3,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import paddle.v2 as paddle import distutils.util import argparse import gzip +import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * -from scorer import Scorer +from lm.lm_scorer import LmScorer from error_rate import wer parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--num_samples", + "--batch_size", default=100, type=int, - help="Number of samples for evaluation. (default: %(default)s)") + help="Minibatch size for evaluation. (default: %(default)s)") parser.add_argument( "--num_conv_layers", default=2, @@ -39,6 +39,16 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -46,10 +56,10 @@ parser.add_argument( help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_method", - default='beam_search_nproc', + default='beam_search', type=str, - help="Method for ctc decoding, best_path, " - "beam_search or beam_search_nproc. (default: %(default)s)") + help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" +) parser.add_argument( "--language_model_path", default="data/en.00.UNKNOWN.klm", @@ -76,11 +86,6 @@ parser.add_argument( default=500, type=int, help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', - type=str, - help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", default='data/manifest.libri.test-clean', @@ -88,7 +93,7 @@ parser.add_argument( help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='./params.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. 
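With ctc_beam_search_decoder_batch in place above, a toy end-to-end call looks like the following (a sketch: random rows stand in for network output, decoder.py is assumed importable from the working directory, and no external scorer is attached):

    import numpy as np
    from decoder import ctc_beam_search_decoder_batch

    if __name__ == '__main__':
        np.random.seed(0)
        vocab = ["'", ' ', 'a', 'b', 'c', 'd']
        # two toy utterances, 6 time steps each, over 6 labels + blank
        probs_split = []
        for _ in range(2):
            p = np.random.rand(6, len(vocab) + 1)
            probs_split.append((p / p.sum(axis=1, keepdims=True)).tolist())

        results = ctc_beam_search_decoder_batch(
            probs_split=probs_split,
            beam_size=20,
            vocabulary=vocab,
            blank_id=len(vocab),
            num_processes=2)
        for candidates in results:
            log_prob, sentence = candidates[0]  # best candidate first
            print(log_prob, sentence)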
(default: %(default)s)") parser.add_argument( @@ -101,12 +106,12 @@ args = parser.parse_args() def evaluate(): """Evaluate on whole test data for DeepSpeech2.""" - # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config='{}') + augmentation_config='{}', + num_threads=args.num_threads_data) # create network config # paddle.data_type.dense_array is used for variable batch input. @@ -133,7 +138,7 @@ def evaluate(): # prepare infer data batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, - batch_size=args.num_samples, + batch_size=args.batch_size, sortagrad=False, shuffle_method=None) @@ -142,9 +147,8 @@ def evaluate(): output_layer=output_probs, parameters=parameters) # initialize external scorer for beam search decoding - if args.decode_method == 'beam_search' or \ - args.decode_method == 'beam_search_nproc': - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + if args.decode_method == 'beam_search': + ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) wer_counter, wer_sum = 0, 0.0 for infer_data in batch_reader(): @@ -155,56 +159,39 @@ def evaluate(): infer_results[i * num_steps:(i + 1) * num_steps] for i in xrange(0, len(infer_data)) ] - + # target transcription + target_transcription = [ + ''.join([ + data_generator.vocab_list[index] for index in infer_data[i][1] + ]) for i, probs in enumerate(probs_split) + ] # decode and print # best path decode if args.decode_method == "best_path": for i, probs in enumerate(probs_split): - output_transcription = ctc_best_path_decode( + output_transcription = ctc_best_path_decoder( probs_seq=probs, vocabulary=data_generator.vocab_list) - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - wer_sum += wer(target_transcription, output_transcription) + wer_sum += wer(target_transcription[i], output_transcription) wer_counter += 1 - # beam search decode in single process + # beam search decode elif args.decode_method == "beam_search": - for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - beam_search_result = ctc_beam_search_decoder( - probs_seq=probs, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - ext_scoring_func=ext_scorer, - cutoff_prob=args.cutoff_prob, ) - wer_sum += wer(target_transcription, beam_search_result[0][1]) - wer_counter += 1 - # beam search using multiple processes - elif args.decode_method == "beam_search_nproc": - beam_search_nproc_results = ctc_beam_search_decoder_nproc( + # beam search using multiple processes + beam_search_results = ctc_beam_search_decoder_batch( probs_split=probs_split, vocabulary=data_generator.vocab_list, beam_size=args.beam_size, blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, ext_scoring_func=ext_scorer, cutoff_prob=args.cutoff_prob, ) - for i, beam_search_result in enumerate(beam_search_nproc_results): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - wer_sum += wer(target_transcription, beam_search_result[0][1]) + for i, beam_search_result in enumerate(beam_search_results): + wer_sum += wer(target_transcription[i], + beam_search_result[0][1]) wer_counter += 1 else: raise ValueError("Decoding method [%s] is not 
supported." % decode_method) - print("Cur WER = %f" % (wer_sum / wer_counter)) print("Final WER = %f" % (wer_sum / wer_counter)) diff --git a/infer.py b/infer.py index 069b9e3e..5f0f268a 100644 --- a/infer.py +++ b/infer.py @@ -11,14 +11,14 @@ import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * -from scorer import Scorer +from lm.lm_scorer import LmScorer from error_rate import wer import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_samples", - default=100, + default=10, type=int, help="Number of samples for inference. (default: %(default)s)") parser.add_argument( @@ -46,6 +46,11 @@ parser.add_argument( default=multiprocessing.cpu_count(), type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -53,12 +58,12 @@ parser.add_argument( help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", - default='data/manifest.libri.test-100sample', + default='datasets/manifest.test', type=str, help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='checkpoints/params.latest.tar.gz', + default='checkpoints/params.tar.gz.41', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( @@ -68,12 +73,10 @@ parser.add_argument( help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( "--decode_method", - default='beam_search_nproc', + default='beam_search', type=str, - help="Method for ctc decoding:" - " best_path," - " beam_search, " - " or beam_search_nproc. (default: %(default)s)") + help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)" +) parser.add_argument( "--beam_size", default=500, @@ -86,7 +89,7 @@ parser.add_argument( help="Number of output per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="data/en.00.UNKNOWN.klm", + default="lm/data/en.00.UNKNOWN.klm", type=str, help="Path for language model. 
(default: %(default)s)") parser.add_argument( @@ -143,6 +146,7 @@ def infer(): batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.num_samples, + min_batch_size=1, sortagrad=False, shuffle_method=None) infer_data = batch_reader().next() @@ -156,68 +160,45 @@ def infer(): for i in xrange(len(infer_data)) ] + # targe transcription + target_transcription = [ + ''.join( + [data_generator.vocab_list[index] for index in infer_data[i][1]]) + for i, probs in enumerate(probs_split) + ] + ## decode and print # best path decode wer_sum, wer_counter = 0, 0 if args.decode_method == "best_path": for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) - best_path_transcription = ctc_best_path_decode( + best_path_transcription = ctc_best_path_decoder( probs_seq=probs, vocabulary=data_generator.vocab_list) print("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target_transcription, best_path_transcription)) - wer_cur = wer(target_transcription, best_path_transcription) + (target_transcription[i], best_path_transcription)) + wer_cur = wer(target_transcription[i], best_path_transcription) wer_sum += wer_cur wer_counter += 1 print("cur wer = %f, average wer = %f" % (wer_cur, wer_sum / wer_counter)) # beam search decode elif args.decode_method == "beam_search": - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) - for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) - beam_search_result = ctc_beam_search_decoder( - probs_seq=probs, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - cutoff_prob=args.cutoff_prob, - ext_scoring_func=ext_scorer, ) - print("\nTarget Transcription:\t%s" % target_transcription) - - for index in xrange(args.num_results_per_sample): - result = beam_search_result[index] - #output: index, log prob, beam result - print("Beam %d: %f \t%s" % (index, result[0], result[1])) - wer_cur = wer(target_transcription, beam_search_result[0][1]) - wer_sum += wer_cur - wer_counter += 1 - print("cur wer = %f , average wer = %f" % - (wer_cur, wer_sum / wer_counter)) - elif args.decode_method == "beam_search_nproc": - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) - beam_search_nproc_results = ctc_beam_search_decoder_nproc( + ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) + beam_search_batch_results = ctc_beam_search_decoder_batch( probs_split=probs_split, vocabulary=data_generator.vocab_list, beam_size=args.beam_size, blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, cutoff_prob=args.cutoff_prob, ext_scoring_func=ext_scorer, ) - for i, beam_search_result in enumerate(beam_search_nproc_results): - target_transcription = ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) - print("\nTarget Transcription:\t%s" % target_transcription) - + for i, beam_search_result in enumerate(beam_search_batch_results): + print("\nTarget Transcription:\t%s" % target_transcription[i]) for index in xrange(args.num_results_per_sample): result = beam_search_result[index] #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) - wer_cur = wer(target_transcription, beam_search_result[0][1]) + wer_cur = wer(target_transcription[i], beam_search_result[0][1]) 
wer_sum += wer_cur wer_counter += 1 print("cur wer = %f , average wer = %f" % diff --git a/lm/__init__.py b/lm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scorer.py b/lm/lm_scorer.py similarity index 73% rename from scorer.py rename to lm/lm_scorer.py index 4f468481..1c029e97 100644 --- a/scorer.py +++ b/lm/lm_scorer.py @@ -8,13 +8,16 @@ import kenlm import numpy as np -class Scorer(object): - """External defined scorer to evaluate a sentence in beam search - decoding, consisting of language model and word count. +class LmScorer(object): + """External scorer to evaluate a prefix or whole sentence in + beam search decoding, including the score from n-gram language + model and word count. - :param alpha: Parameter associated with language model. + :param alpha: Parameter associated with language model. Don't use + language model when alpha = 0. :type alpha: float - :param beta: Parameter associated with word count. + :param beta: Parameter associated with word count. Don't use word + count when beta = 0. :type beta: float :model_path: Path to load language model. :type model_path: basestring @@ -28,14 +31,14 @@ class Scorer(object): self._language_model = kenlm.LanguageModel(model_path) # n-gram language model scoring - def language_model_score(self, sentence): + def _language_model_score(self, sentence): #log10 prob of last word log_cond_prob = list( self._language_model.full_scores(sentence, eos=False))[-1][0] return np.power(10, log_cond_prob) # word insertion term - def word_count(self, sentence): + def _word_count(self, sentence): words = sentence.strip().split(' ') return len(words) @@ -51,8 +54,8 @@ class Scorer(object): :return: Evaluation score, in the decimal or log. :rtype: float """ - lm = self.language_model_score(sentence) - word_cnt = self.word_count(sentence) + lm = self._language_model_score(sentence) + word_cnt = self._word_count(sentence) if log == False: score = np.power(lm, self._alpha) \ * np.power(word_cnt, self._beta) diff --git a/lm/run.sh b/lm/run.sh new file mode 100644 index 00000000..bf523740 --- /dev/null +++ b/lm/run.sh @@ -0,0 +1,3 @@ +echo "Downloading language model." 
+ +wget -c ftp://xxx/xxx/en.00.UNKNOWN.klm -P ./data diff --git a/requirements.txt b/requirements.txt index 0183ecf0..ce024591 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ SoundFile==0.9.0.post1 wget==3.2 scipy==0.13.1 +https://github.com/kpu/kenlm/archive/master.zip diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 7fa89c5f..4435355c 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -53,11 +53,11 @@ class TestDecoders(unittest.TestCase): self.beam_search_result = ['acdc', "b'a"] def test_best_path_decoder_1(self): - bst_result = ctc_best_path_decode(self.probs_seq1, self.vocab_list) + bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list) self.assertEqual(bst_result, self.best_path_result[0]) def test_best_path_decoder_2(self): - bst_result = ctc_best_path_decode(self.probs_seq2, self.vocab_list) + bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list) self.assertEqual(bst_result, self.best_path_result[1]) def test_beam_search_decoder_1(self): @@ -77,7 +77,7 @@ class TestDecoders(unittest.TestCase): self.assertEqual(beam_result[0][1], self.beam_search_result[1]) def test_beam_search_nproc_decoder(self): - beam_results = ctc_beam_search_decoder_nproc( + beam_results = ctc_beam_search_decoder_batch( probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, vocabulary=self.vocab_list, diff --git a/tune.py b/tune.py index 02076349..9cea66b9 100644 --- a/tune.py +++ b/tune.py @@ -3,14 +3,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import paddle.v2 as paddle import distutils.util import argparse import gzip +import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * -from scorer import Scorer +from lm.lm_scorer import LmScorer from error_rate import wer parser = argparse.ArgumentParser(description=__doc__) @@ -39,24 +39,29 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', type=str, help="Manifest path for normalizer. (default: %(default)s)") -parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', - type=str, - help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", - default='data/manifest.libri.test-100sample', + default='datasets/manifest.test', type=str, help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='./params.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( @@ -64,25 +69,14 @@ parser.add_argument( default='datasets/vocab/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") -parser.add_argument( - "--decode_method", - default='beam_search_nproc', - type=str, - help="Method for decoding, beam_search or beam_search_nproc. 
(default: %(default)s)" -) parser.add_argument( "--beam_size", default=500, type=int, help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--num_results_per_sample", - default=1, - type=int, - help="Number of outputs per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="data/en.00.UNKNOWN.klm", + default="lm/data/en.00.UNKNOWN.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( @@ -137,7 +131,8 @@ def tune(): data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config='{}') + augmentation_config='{}', + num_threads=args.num_threads_data) # create network config # paddle.data_type.dense_array is used for variable batch input. @@ -188,42 +183,22 @@ def tune(): ## tune parameters in loop for (alpha, beta) in params_grid: wer_sum, wer_counter = 0, 0 - ext_scorer = Scorer(alpha, beta, args.language_model_path) - # beam search decode - if args.decode_method == "beam_search": - for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - beam_search_result = ctc_beam_search_decoder( - probs_seq=probs, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - cutoff_prob=args.cutoff_prob, - ext_scoring_func=ext_scorer, ) - wer_sum += wer(target_transcription, beam_search_result[0][1]) - wer_counter += 1 + ext_scorer = LmScorer(alpha, beta, args.language_model_path) # beam search using multiple processes - elif args.decode_method == "beam_search_nproc": - beam_search_nproc_results = ctc_beam_search_decoder_nproc( - probs_split=probs_split, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - blank_id=len(data_generator.vocab_list), - ext_scoring_func=ext_scorer, ) - for i, beam_search_result in enumerate(beam_search_nproc_results): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - wer_sum += wer(target_transcription, beam_search_result[0][1]) - wer_counter += 1 - else: - raise ValueError("Decoding method [%s] is not supported." % - decode_method) + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=data_generator.vocab_list, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, + ext_scoring_func=ext_scorer, ) + for i, beam_search_result in enumerate(beam_search_results): + target_transcription = ''.join([ + data_generator.vocab_list[index] for index in infer_data[i][1] + ]) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 print("alpha = %f\tbeta = %f\tWER = %f" % (alpha, beta, wer_sum / wer_counter)) From 8ba98cb518d494a2f7a63a748cf7f8a82759c3bc Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 27 Jun 2017 18:35:49 +0800 Subject: [PATCH 20/28] fix decoders' unittest --- infer.py | 2 +- tests/test_decoders.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/infer.py b/infer.py index 5f0f268a..686f2822 100644 --- a/infer.py +++ b/infer.py @@ -63,7 +63,7 @@ parser.add_argument( help="Manifest path for decoding. 
(default: %(default)s)") parser.add_argument( "--model_filepath", - default='checkpoints/params.tar.gz.41', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 4435355c..a5e19b08 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -81,7 +81,8 @@ class TestDecoders(unittest.TestCase): probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, vocabulary=self.vocab_list, - blank_id=len(self.vocab_list)) + blank_id=len(self.vocab_list), + num_processes=24) self.assertEqual(beam_results[0][0][1], self.beam_search_result[0]) self.assertEqual(beam_results[1][0][1], self.beam_search_result[1]) From aeccd9851b4d7137bcfa32ebf437298bfb9e478f Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 27 Jun 2017 20:22:47 +0800 Subject: [PATCH 21/28] append README.md --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/README.md b/README.md index 2912ff31..41acf102 100644 --- a/README.md +++ b/README.md @@ -77,3 +77,29 @@ More help for arguments: ``` python infer.py --help ``` + +### Evaluating + +``` +CUDA_VISIBLE_DEVICES=0 python evaluate.py +``` + +More help for arguments: + +``` +python evaluate.py --help +``` + +### Parameters tuning + +Parameters tuning for the CTC beam search decoder + +``` +CUDA_VISIBLE_DEVICES=0 python tune.py +``` + +More help for arguments: + +``` +python tune.py --help +``` From 37e98df74df04bd266913ec1b2665f696a8ba1ca Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 4 Jul 2017 19:48:56 +0800 Subject: [PATCH 22/28] enable resetting params in scorer --- lm/lm_scorer.py | 5 +++++ tests/test_decoders.py | 2 +- tune.py | 8 ++++++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lm/lm_scorer.py b/lm/lm_scorer.py index 1c029e97..de41754f 100644 --- a/lm/lm_scorer.py +++ b/lm/lm_scorer.py @@ -42,6 +42,11 @@ class LmScorer(object): words = sentence.strip().split(' ') return len(words) + # reset alpha and beta + def reset_params(self, alpha, beta): + self._alpha = alpha + self._beta = beta + # execute evaluation def __call__(self, sentence, log=False): """Evaluation function, gathering all the different scores diff --git a/tests/test_decoders.py b/tests/test_decoders.py index a5e19b08..99d8a828 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -76,7 +76,7 @@ class TestDecoders(unittest.TestCase): blank_id=len(self.vocab_list)) self.assertEqual(beam_result[0][1], self.beam_search_result[1]) - def test_beam_search_nproc_decoder(self): + def test_beam_search_decoder_batch(self): beam_results = ctc_beam_search_decoder_batch( probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, diff --git a/tune.py b/tune.py index 9cea66b9..e26bc45c 100644 --- a/tune.py +++ b/tune.py @@ -12,6 +12,7 @@ from model import deep_speech2 from decoder import * from lm.lm_scorer import LmScorer from error_rate import wer +import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -180,10 +181,13 @@ def tune(): params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas] + ext_scorer = LmScorer(args.alpha_from, args.beta_from, + args.language_model_path) ## tune parameters in loop - for (alpha, beta) in params_grid: + for alpha, beta in params_grid: wer_sum, wer_counter = 0, 0 - ext_scorer = LmScorer(alpha, beta, args.language_model_path) + # reset scorer + ext_scorer.reset_params(alpha, beta) # beam search 
using multiple processes beam_search_results = ctc_beam_search_decoder_batch( probs_split=probs_split, From 0cadc56a8417c9f1f0d3e21d3fec4363cf0b00a2 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 5 Jul 2017 11:05:26 +0800 Subject: [PATCH 23/28] follow comments in code format --- decoder.py | 12 ++++-------- evaluate.py | 4 ++-- infer.py | 2 +- lm/lm_scorer.py | 6 ++---- tune.py | 2 +- 5 files changed, 10 insertions(+), 16 deletions(-) diff --git a/decoder.py b/decoder.py index 4676b02b..a1fadc2c 100644 --- a/decoder.py +++ b/decoder.py @@ -5,6 +5,7 @@ from __future__ import print_function from itertools import groupby import numpy as np +from math import log import multiprocessing @@ -97,13 +98,8 @@ def ctc_beam_search_decoder(probs_seq, # prefix_set_prev: the set containing selected prefixes # probs_b_prev: prefixes' probability ending with blank in previous step # probs_nb_prev: prefixes' probability ending with non-blank in previous step - prefix_set_prev, probs_b_prev, probs_nb_prev = { - '\t': 1.0 - }, { - '\t': 1.0 - }, { - '\t': 0.0 - } + prefix_set_prev = {'\t': 1.0} + probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} ## extend prefix in loop for time_step in xrange(len(probs_seq)): @@ -179,7 +175,7 @@ def ctc_beam_search_decoder(probs_seq, # score last word by external scorer if (ext_scoring_func is not None) and (result[-1] != ' '): prob = prob * ext_scoring_func(result) - log_prob = np.log(prob) + log_prob = log(prob) beam_result.append((log_prob, result)) ## output top beam_size decoding results diff --git a/evaluate.py b/evaluate.py index 7ef32ad1..a4f2a690 100644 --- a/evaluate.py +++ b/evaluate.py @@ -62,7 +62,7 @@ parser.add_argument( ) parser.add_argument( "--language_model_path", - default="data/en.00.UNKNOWN.klm", + default="lm/data/1Billion.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( @@ -88,7 +88,7 @@ parser.add_argument( help="Width for beam search decoding. (default: %(default)d)") parser.add_argument( "--decode_manifest_path", - default='data/manifest.libri.test-clean', + default='datasets/manifest.test', type=str, help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( diff --git a/infer.py b/infer.py index 686f2822..dc143080 100644 --- a/infer.py +++ b/infer.py @@ -89,7 +89,7 @@ parser.add_argument( help="Number of output per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="lm/data/en.00.UNKNOWN.klm", + default="lm/data/1Billion.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( diff --git a/lm/lm_scorer.py b/lm/lm_scorer.py index de41754f..463e96d6 100644 --- a/lm/lm_scorer.py +++ b/lm/lm_scorer.py @@ -62,9 +62,7 @@ class LmScorer(object): lm = self._language_model_score(sentence) word_cnt = self._word_count(sentence) if log == False: - score = np.power(lm, self._alpha) \ - * np.power(word_cnt, self._beta) + score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta) else: - score = self._alpha * np.log(lm) \ - + self._beta * np.log(word_cnt) + score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt) return score diff --git a/tune.py b/tune.py index e26bc45c..4e9e268f 100644 --- a/tune.py +++ b/tune.py @@ -77,7 +77,7 @@ parser.add_argument( help="Width for beam search decoding. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="lm/data/en.00.UNKNOWN.klm", + default="lm/data/1Billion.klm", type=str, help="Path for language model. 
(default: %(default)s)") parser.add_argument( From d15c48d616b299f7a91ad71a49673cc2eeb7ab56 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 11 Jul 2017 13:32:35 +0800 Subject: [PATCH 24/28] upload the language model --- README.md | 38 ++++++++++++++++++++++++++++++++++++-- evaluate.py | 3 ++- infer.py | 2 +- lm/run.sh | 20 ++++++++++++++++++-- tune.py | 2 +- 5 files changed, 58 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 41acf102..48f4b0db 100644 --- a/README.md +++ b/README.md @@ -66,12 +66,36 @@ More help for arguments: python train.py --help ``` -### Inferencing +### Preparing language model + +The following steps, inference, parameters tuning and evaluating, will require a language model during decoding. +A compressed language model is provided and can be accessed by + +``` +cd ./lm +sh run.sh +``` + +After the downloading is completed, then + +``` +cd .. +``` + +### Inference + +For GPU inference ``` CUDA_VISIBLE_DEVICES=0 python infer.py ``` +For CPU inference + +``` +python infer.py --use_gpu=False +``` + More help for arguments: ``` @@ -92,14 +116,24 @@ python evaluate.py --help ### Parameters tuning -Parameters tuning for the CTC beam search decoder +Usually, the parameters $\alpha$ and $\beta$ for the CTC [prefix beam search](https://arxiv.org/abs/1408.2873) decoder need to be tuned after retraining the acoustic model. + +For GPU tuning ``` CUDA_VISIBLE_DEVICES=0 python tune.py ``` +For CPU tuning + +``` +python tune.py --use_gpu=False +``` + More help for arguments: ``` python tune.py --help ``` + +Then reset parameters with the tuning result before inference or evaluating. diff --git a/evaluate.py b/evaluate.py index a4f2a690..00516dcb 100644 --- a/evaluate.py +++ b/evaluate.py @@ -62,7 +62,7 @@ parser.add_argument( ) parser.add_argument( "--language_model_path", - default="lm/data/1Billion.klm", + default="lm/data/common_crawl_00.prune01111.trie.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( @@ -139,6 +139,7 @@ def evaluate(): batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.batch_size, + min_batch_size=1, sortagrad=False, shuffle_method=None) diff --git a/infer.py b/infer.py index dc143080..bb81feac 100644 --- a/infer.py +++ b/infer.py @@ -89,7 +89,7 @@ parser.add_argument( help="Number of output per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="lm/data/1Billion.klm", + default="lm/data/common_crawl_00.prune01111.trie.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( diff --git a/lm/run.sh b/lm/run.sh index bf523740..2108ea55 100644 --- a/lm/run.sh +++ b/lm/run.sh @@ -1,3 +1,19 @@ -echo "Downloading language model." +echo "Downloading language model ..." + +mkdir data + +LM=common_crawl_00.prune01111.trie.klm +MD5="099a601759d467cd0a8523ff939819c5" + +wget -c http://paddlepaddle.bj.bcebos.com/model_zoo/speech/$LM -P ./data + +echo "Checking md5sum ..." +md5_tmp=`md5sum ./data/$LM | awk -F[' '] '{print $1}'` + +if [ $MD5 != $md5_tmp ]; then + echo "Fail to download the language model!" + exit 1 +fi + + -wget -c ftp://xxx/xxx/en.00.UNKNOWN.klm -P ./data diff --git a/tune.py b/tune.py index 4e9e268f..19a2d559 100644 --- a/tune.py +++ b/tune.py @@ -77,7 +77,7 @@ parser.add_argument( help="Width for beam search decoding. 
(default: %(default)d)") parser.add_argument( "--language_model_path", - default="lm/data/1Billion.klm", + default="lm/data/common_crawl_00.prune01111.trie.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( From 8ce954671084c55a83e0008aee54395ab76c9670 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 11 Jul 2017 14:56:37 +0800 Subject: [PATCH 25/28] modify README.md --- README.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/README.md b/README.md index 48f4b0db..3b20bf49 100644 --- a/README.md +++ b/README.md @@ -74,11 +74,6 @@ A compressed language model is provided and can be accessed by ``` cd ./lm sh run.sh -``` - -After the downloading is completed, then - -``` cd .. ``` From ee5abbe37d5a3e1fd8629a55d4d149ab5612c740 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 18 Jul 2017 10:14:37 +0800 Subject: [PATCH 26/28] add mfcc feature for DS2 --- README.md | 6 ++- compute_mean_std.py | 8 +++- data_utils/featurizer/audio_featurizer.py | 48 ++++++++++++++++++++-- data_utils/featurizer/speech_featurizer.py | 15 +++---- data_utils/normalizer.py | 2 +- requirements.txt | 1 + train.py | 7 ++++ 7 files changed, 74 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 3b20bf49..a92b671c 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,11 @@ python datasets/librispeech/librispeech.py --help python compute_mean_std.py ``` -`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. +`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, currently the mfcc feature is also supported. To train and infer based on mfcc feature, you can regenerate this file by + +``` +python compute_mean_std.py --specgram_type mfcc +``` More help for arguments: diff --git a/compute_mean_std.py b/compute_mean_std.py index 9c301c93..0cc84e73 100644 --- a/compute_mean_std.py +++ b/compute_mean_std.py @@ -10,6 +10,12 @@ from data_utils.featurizer.audio_featurizer import AudioFeaturizer parser = argparse.ArgumentParser( description='Computing mean and stddev for feature normalizer.') +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. 
(default: %(default)s)") parser.add_argument( "--manifest_path", default='datasets/manifest.train', @@ -39,7 +45,7 @@ args = parser.parse_args() def main(): augmentation_pipeline = AugmentationPipeline(args.augmentation_config) - audio_featurizer = AudioFeaturizer() + audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type) def augment_and_featurize(audio_segment): augmentation_pipeline.transform_audio(audio_segment) diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py index 4b4d02c6..271e535b 100644 --- a/data_utils/featurizer/audio_featurizer.py +++ b/data_utils/featurizer/audio_featurizer.py @@ -6,13 +6,15 @@ from __future__ import print_function import numpy as np from data_utils import utils from data_utils.audio import AudioSegment +from python_speech_features import mfcc +from python_speech_features import delta class AudioFeaturizer(object): """Audio featurizer, for extracting features from audio contents of AudioSegment or SpeechSegment. - Currently, it only supports feature type of linear spectrogram. + Currently, it supports feature types of linear spectrogram and mfcc. :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str @@ -20,9 +22,10 @@ class AudioFeaturizer(object): :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_freq: Used when specgram_type is 'linear', only FFT bins + :param max_freq: When specgram_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned. + returned; when specgram_type is 'mfcc', max_feq is the + highest band edge of mel filters. :types max_freq: None|float :param target_sample_rate: Audio are resampled (if upsampling or downsampling is allowed) to this before @@ -91,6 +94,9 @@ class AudioFeaturizer(object): return self._compute_linear_specgram( samples, sample_rate, self._stride_ms, self._window_ms, self._max_freq) + elif self._specgram_type == 'mfcc': + return self._compute_mfcc(samples, sample_rate, self._stride_ms, + self._window_ms, self._max_freq) else: raise ValueError("Unknown specgram_type %s. " "Supported values: linear." 
% self._specgram_type) @@ -142,3 +148,39 @@ class AudioFeaturizer(object): # prepare fft frequency list freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) return fft, freqs + + def _compute_mfcc(self, + samples, + sample_rate, + stride_ms=10.0, + window_ms=20.0, + max_freq=None): + """Compute mfcc from samples.""" + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must not be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + # compute 13 cepstral coefficients, and the first one is replaced + # by log(frame energy) + mfcc_feat = mfcc( + signal=samples, + samplerate=sample_rate, + winlen=0.001 * window_ms, + winstep=0.001 * stride_ms, + highfreq=max_freq) + # Deltas + d_mfcc_feat = delta(mfcc_feat, 2) + # Deltas-Deltas + dd_mfcc_feat = delta(d_mfcc_feat, 2) + # concat above three features + concat_mfcc_feat = [ + np.concatenate((mfcc_feat[i], d_mfcc_feat[i], dd_mfcc_feat[i])) + for i in xrange(len(mfcc_feat)) + ] + # transpose to be consistent with the linear specgram situation + concat_mfcc_feat = np.transpose(concat_mfcc_feat) + return concat_mfcc_feat diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py index 26283892..a947588d 100644 --- a/data_utils/featurizer/speech_featurizer.py +++ b/data_utils/featurizer/speech_featurizer.py @@ -11,23 +11,24 @@ class SpeechFeaturizer(object): """Speech featurizer, for extracting features from both audio and transcript contents of SpeechSegment. - Currently, for audio parts, it only supports feature type of linear - spectrogram; for transcript parts, it only supports char-level tokenizing - and conversion into a list of token indices. Note that the token indexing - order follows the given vocabulary file. + Currently, for audio parts, it supports feature types of linear + spectrogram and mfcc; for transcript parts, it only supports char-level + tokenizing and conversion into a list of token indices. Note that the + token indexing order follows the given vocabulary file. :param vocab_filepath: Filepath to load vocabulary for token indices conversion. :type specgram_type: basestring :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'. :type specgram_type: str :param stride_ms: Striding size (in milliseconds) for generating frames. :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_freq: Used when specgram_type is 'linear', only FFT bins + :param max_freq: When specgram_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned. + returned; when specgram_type is 'mfcc', max_freq is the + highest band edge of mel filters. :types max_freq: None|float :param target_sample_rate: Speech are resampled (if upsampling or downsampling is allowed) to this before diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py index c123d25d..1f4aae9a 100644 --- a/data_utils/normalizer.py +++ b/data_utils/normalizer.py @@ -16,7 +16,7 @@ class FeatureNormalizer(object): if mean_std_filepath is provided (not None), the normalizer will directly initilize from the file. Otherwise, both manifest_path and featurize_func should be given for on-the-fly mean and stddev computing. 
- + :param mean_std_filepath: File containing the pre-computed mean and stddev. :type mean_std_filepath: None|basestring :param manifest_path: Manifest of instances for computing mean and stddev. diff --git a/requirements.txt b/requirements.txt index 2ae7d089..721fa281 100755 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ wget==3.2 scipy==0.13.1 resampy==0.1.5 https://github.com/kpu/kenlm/archive/master.zip +python_speech_features diff --git a/train.py b/train.py index 3a2d0cad..6481074c 100644 --- a/train.py +++ b/train.py @@ -53,6 +53,12 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--max_duration", default=27.0, @@ -130,6 +136,7 @@ def train(): augmentation_config=args.augmentation_config, max_duration=args.max_duration, min_duration=args.min_duration, + specgram_type=args.specgram_type, num_threads=args.num_threads_data) train_generator = data_generator() From 724ef185966a379ceca0caa1d0b2200e42bf32f3 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 19 Jul 2017 22:40:01 +0800 Subject: [PATCH 27/28] update several scripts to support mfcc --- README.md | 2 ++ evaluate.py | 7 +++++++ infer.py | 7 +++++++ tune.py | 7 +++++++ 4 files changed, 23 insertions(+) diff --git a/README.md b/README.md index a92b671c..24f0b3c3 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,8 @@ python compute_mean_std.py python compute_mean_std.py --specgram_type mfcc ``` +and specify the ```specgram_type``` to ```mfcc``` in each step, including training, inference etc. + More help for arguments: ``` diff --git a/evaluate.py b/evaluate.py index 00516dcb..19eabf4e 100644 --- a/evaluate.py +++ b/evaluate.py @@ -86,6 +86,12 @@ parser.add_argument( default=500, type=int, help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", default='datasets/manifest.test', @@ -111,6 +117,7 @@ def evaluate(): vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', + specgram_type=args.specgram_type, num_threads=args.num_threads_data) # create network config diff --git a/infer.py b/infer.py index bb81feac..81752630 100644 --- a/infer.py +++ b/infer.py @@ -51,6 +51,12 @@ parser.add_argument( default=multiprocessing.cpu_count(), type=int, help="Number of cpu processes for beam search. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -118,6 +124,7 @@ def infer(): vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', + specgram_type=args.specgram_type, num_threads=args.num_threads_data) # create network config diff --git a/tune.py b/tune.py index 19a2d559..2fcca486 100644 --- a/tune.py +++ b/tune.py @@ -50,6 +50,12 @@ parser.add_argument( default=multiprocessing.cpu_count(), type=int, help="Number of cpu processes for beam search. 
(default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -133,6 +139,7 @@ def tune(): vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', + specgram_type=args.specgram_type, num_threads=args.num_threads_data) # create network config From cb0680e8c49ffa23d2fb7857d1a3fd39d6e48ac1 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 20 Jul 2017 11:47:46 +0800 Subject: [PATCH 28/28] follow comments to modify README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 24f0b3c3..3010c0e5 100644 --- a/README.md +++ b/README.md @@ -38,13 +38,13 @@ python datasets/librispeech/librispeech.py --help python compute_mean_std.py ``` -`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, currently the mfcc feature is also supported. To train and infer based on mfcc feature, you can regenerate this file by +It will compute mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, and the mfcc feature is also supported. To train and infer based on mfcc feature, please generate this file by ``` python compute_mean_std.py --specgram_type mfcc ``` -and specify the ```specgram_type``` to ```mfcc``` in each step, including training, inference etc. +and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluator.py or tune.py. More help for arguments: