diff --git a/ctc_beam_search_decoder.py b/ctc_beam_search_decoder.py index 873121b1..3223c0c2 100644 --- a/ctc_beam_search_decoder.py +++ b/ctc_beam_search_decoder.py @@ -10,12 +10,6 @@ import numpy as np vocab = ['-', '_', 'a'] -def ids_str2list(ids_str): - ids_str = ids_str.split(' ') - ids_list = [int(elem) for elem in ids_str] - return ids_list - - def ids_list2str(ids_list): ids_str = [str(elem) for elem in ids_list] ids_str = ' '.join(ids_str) @@ -39,21 +33,45 @@ def ctc_beam_search_decoder(input_probs_matrix, space_id=1, num_results_per_sample=None): ''' - beam search decoder for CTC-trained network, called outside of the recurrent group. - adapted from Algorithm 1 in https://arxiv.org/abs/1408.2873. + Beam search decoder for CTC-trained network, adapted from Algorithm 1 + in https://arxiv.org/abs/1408.2873. + + :param input_probs_matrix: probs matrix for input sequence, row major + :type input_probs_matrix: 2D matrix. + :param beam_size: width for beam search + :type beam_size: int + :max_time_steps: maximum steps' number for input sequence, + <=len(input_probs_matrix) + :type max_time_steps: int + :lang_model: language model for scoring + :type lang_model: function + :param alpha: parameter associated with language model. + :type alpha: float + :param beta: parameter associated with word count + :type beta: float + :param blank_id: id of blank, default 0. + :type blank_id: int + :param space_id: id of space, default 1. + :type space_id: int + :param num_result_per_sample: the number of output decoding results + per given sample, <=beam_size. + :type num_result_per_sample: int + ''' - param input_probs_matrix: probs matrix for input sequence, row major - type input_probs_matrix: 2D matrix. - param beam_size: width for beam search - type beam_size: int - max_time_steps: maximum steps' number for input sequence, <=len(input_probs_matrix) - type max_time_steps: int - lang_model: language model for scoring - type lang_model: function + # function to convert ids in string to list + def ids_str2list(ids_str): + ids_str = ids_str.split(' ') + ids_list = [int(elem) for elem in ids_str] + return ids_list - ...... + # counting words in a character list + def word_count(ids_list): + cnt = 0 + for elem in ids_list: + if elem == space_id: + cnt += 1 + return cnt - ''' if num_results_per_sample is None: num_results_per_sample = beam_size assert num_results_per_sample <= beam_size