|
|
@ -10,12 +10,6 @@ import numpy as np
|
|
|
|
vocab = ['-', '_', 'a']
|
|
|
|
vocab = ['-', '_', 'a']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ids_str2list(ids_str):
|
|
|
|
|
|
|
|
ids_str = ids_str.split(' ')
|
|
|
|
|
|
|
|
ids_list = [int(elem) for elem in ids_str]
|
|
|
|
|
|
|
|
return ids_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ids_list2str(ids_list):
|
|
|
|
def ids_list2str(ids_list):
|
|
|
|
ids_str = [str(elem) for elem in ids_list]
|
|
|
|
ids_str = [str(elem) for elem in ids_list]
|
|
|
|
ids_str = ' '.join(ids_str)
|
|
|
|
ids_str = ' '.join(ids_str)
|
|
|
@ -39,21 +33,45 @@ def ctc_beam_search_decoder(input_probs_matrix,
|
|
|
|
space_id=1,
|
|
|
|
space_id=1,
|
|
|
|
num_results_per_sample=None):
|
|
|
|
num_results_per_sample=None):
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
beam search decoder for CTC-trained network, called outside of the recurrent group.
|
|
|
|
Beam search decoder for CTC-trained network, adapted from Algorithm 1
|
|
|
|
adapted from Algorithm 1 in https://arxiv.org/abs/1408.2873.
|
|
|
|
in https://arxiv.org/abs/1408.2873.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param input_probs_matrix: probs matrix for input sequence, row major
|
|
|
|
|
|
|
|
:type input_probs_matrix: 2D matrix.
|
|
|
|
|
|
|
|
:param beam_size: width for beam search
|
|
|
|
|
|
|
|
:type beam_size: int
|
|
|
|
|
|
|
|
:max_time_steps: maximum steps' number for input sequence,
|
|
|
|
|
|
|
|
<=len(input_probs_matrix)
|
|
|
|
|
|
|
|
:type max_time_steps: int
|
|
|
|
|
|
|
|
:lang_model: language model for scoring
|
|
|
|
|
|
|
|
:type lang_model: function
|
|
|
|
|
|
|
|
:param alpha: parameter associated with language model.
|
|
|
|
|
|
|
|
:type alpha: float
|
|
|
|
|
|
|
|
:param beta: parameter associated with word count
|
|
|
|
|
|
|
|
:type beta: float
|
|
|
|
|
|
|
|
:param blank_id: id of blank, default 0.
|
|
|
|
|
|
|
|
:type blank_id: int
|
|
|
|
|
|
|
|
:param space_id: id of space, default 1.
|
|
|
|
|
|
|
|
:type space_id: int
|
|
|
|
|
|
|
|
:param num_result_per_sample: the number of output decoding results
|
|
|
|
|
|
|
|
per given sample, <=beam_size.
|
|
|
|
|
|
|
|
:type num_result_per_sample: int
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
param input_probs_matrix: probs matrix for input sequence, row major
|
|
|
|
# function to convert ids in string to list
|
|
|
|
type input_probs_matrix: 2D matrix.
|
|
|
|
def ids_str2list(ids_str):
|
|
|
|
param beam_size: width for beam search
|
|
|
|
ids_str = ids_str.split(' ')
|
|
|
|
type beam_size: int
|
|
|
|
ids_list = [int(elem) for elem in ids_str]
|
|
|
|
max_time_steps: maximum steps' number for input sequence, <=len(input_probs_matrix)
|
|
|
|
return ids_list
|
|
|
|
type max_time_steps: int
|
|
|
|
|
|
|
|
lang_model: language model for scoring
|
|
|
|
|
|
|
|
type lang_model: function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
......
|
|
|
|
# counting words in a character list
|
|
|
|
|
|
|
|
def word_count(ids_list):
|
|
|
|
|
|
|
|
cnt = 0
|
|
|
|
|
|
|
|
for elem in ids_list:
|
|
|
|
|
|
|
|
if elem == space_id:
|
|
|
|
|
|
|
|
cnt += 1
|
|
|
|
|
|
|
|
return cnt
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
if num_results_per_sample is None:
|
|
|
|
if num_results_per_sample is None:
|
|
|
|
num_results_per_sample = beam_size
|
|
|
|
num_results_per_sample = beam_size
|
|
|
|
assert num_results_per_sample <= beam_size
|
|
|
|
assert num_results_per_sample <= beam_size
|
|
|
|