update annotations

pull/2/head
Yibing Liu 8 years ago
parent c943ca79ac
commit cfe9d22866

@ -10,12 +10,6 @@ import numpy as np
vocab = ['-', '_', 'a']
def ids_str2list(ids_str):
    '''
    Convert a whitespace-separated string of integer ids into a list.

    :param ids_str: ids joined by whitespace, e.g. '0 2 1'
    :type ids_str: str
    :return: the ids as integers, in their original order; an empty
             list when the string holds no ids.
    :rtype: list
    :raises ValueError: if any token is not an integer literal.
    '''
    # split() without a separator tolerates repeated, leading or trailing
    # whitespace and yields no tokens for an empty string, whereas
    # split(' ') would hand '' to int() and raise.
    return [int(token) for token in ids_str.split()]
def ids_list2str(ids_list):
ids_str = [str(elem) for elem in ids_list]
ids_str = ' '.join(ids_str)
@ -39,21 +33,45 @@ def ctc_beam_search_decoder(input_probs_matrix,
space_id=1,
num_results_per_sample=None):
'''
beam search decoder for CTC-trained network, called outside of the recurrent group.
adapted from Algorithm 1 in https://arxiv.org/abs/1408.2873.
Beam search decoder for CTC-trained network, adapted from Algorithm 1
in https://arxiv.org/abs/1408.2873.
:param input_probs_matrix: probs matrix for input sequence, row major
:type input_probs_matrix: 2D matrix.
:param beam_size: width for beam search
:type beam_size: int
:param max_time_steps: maximum number of time steps for input sequence,
<=len(input_probs_matrix)
:type max_time_steps: int
:param lang_model: language model for scoring
:type lang_model: function
:param alpha: parameter associated with language model.
:type alpha: float
:param beta: parameter associated with word count
:type beta: float
:param blank_id: id of blank, default 0.
:type blank_id: int
:param space_id: id of space, default 1.
:type space_id: int
:param num_results_per_sample: the number of output decoding results
per given sample, <=beam_size.
:type num_results_per_sample: int
'''
param input_probs_matrix: probs matrix for input sequence, row major
type input_probs_matrix: 2D matrix.
param beam_size: width for beam search
type beam_size: int
max_time_steps: maximum steps' number for input sequence, <=len(input_probs_matrix)
type max_time_steps: int
lang_model: language model for scoring
type lang_model: function
# function to convert ids in string to list
def ids_str2list(ids_str):
    '''
    Convert a whitespace-separated string of integer ids into a list.

    :param ids_str: ids joined by whitespace, e.g. '0 2 1'
    :type ids_str: str
    :return: the ids as integers, in their original order; an empty
             list when the string holds no ids.
    :rtype: list
    :raises ValueError: if any token is not an integer literal.
    '''
    # split() without a separator tolerates repeated, leading or trailing
    # whitespace and yields no tokens for an empty string, whereas
    # split(' ') would hand '' to int() and raise.
    return [int(token) for token in ids_str.split()]
......
# counting words in a character list
def word_count(ids_list):
    '''
    Return how many elements of ids_list equal space_id.

    :param ids_list: sequence of ids to scan
    :type ids_list: list
    :return: number of elements that compare equal to space_id
    :rtype: int
    '''
    # NOTE(review): space_id is not defined in this block; presumably it is
    # the `space_id` argument of the enclosing decoder -- confirm.
    return sum(1 for item in ids_list if item == space_id)
'''
if num_results_per_sample is None:
num_results_per_sample = beam_size
assert num_results_per_sample <= beam_size

Loading…
Cancel
Save