"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from itertools import groupby
import numpy as np
import multiprocessing


def ctc_best_path_decode(probs_seq, vocabulary):
    """Best path decoding, also called argmax decoding or greedy decoding.
    Path consisting of the most probable tokens are further post-processed to
    remove consecutive repetitions and all blanks.

    The blank token is taken to be the last entry of every probability list,
    i.e. index ``len(vocabulary)``.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string; an empty string for empty ``probs_seq``.
    :rtype: str
    :raises ValueError: If any probability list does not have exactly
                        ``len(vocabulary) + 1`` entries.
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # np.array([]) is 1-D (and a (0, V) array has nothing to reduce), so
    # argmax(axis=1) would fail; nothing to decode anyway.
    if len(probs_seq) == 0:
        return ''
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes (groupby yields (key, group) pairs)
    index_list = [index for index, _ in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])


def _prune_candidates(probs, cutoff_prob):
    """Return the (index, probability) pairs to expand at one time step.

    When ``cutoff_prob`` is not below 1.0 every index is returned in its
    original order. Otherwise the pairs are sorted by descending probability
    and the shortest leading run whose cumulative probability reaches
    ``cutoff_prob`` is kept (all pairs if the total never reaches it).
    """
    candidates = list(enumerate(probs))
    if not cutoff_prob < 1.0:
        return candidates
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    cum_prob = 0.0
    for keep, (_, prob_c) in enumerate(candidates, 1):
        cum_prob += prob_c
        if cum_prob >= cutoff_prob:
            return candidates[:keep]
    return candidates


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id=0,
                            cutoff_prob=1.0,
                            ext_scoring_func=None,
                            nproc=False):
    """Beam search decoder for CTC-trained network, using beam search with
    width beam_size to find many paths to one label, return beam_size labels
    in the descending order of probabilities. The implementation is based on
    Prefix Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part
    is redesigned.

    Each probability list has ``len(vocabulary) + 1`` entries: one for the
    blank at ``blank_id`` and one per vocabulary symbol, in vocabulary order,
    for the remaining indices.

    :param probs_seq: 2-D list with length num_time_steps, each element
                      is a list of normalized probabilities over vocabulary
                      and blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank, default 0.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External defined scoring function for
                             partially decoded sentence, e.g. word count
                             and language model. It is called with the
                             decoded text before each space and on the final
                             text if it does not end with a space; its return
                             value multiplies the prefix probability.
    :type ext_scoring_func: function
    :param nproc: Whether the decoder used in multiprocesses. If True, the
                  scorer stored by ctc_beam_search_decoder_nproc() in the
                  module-level ``ext_nproc_scorer`` replaces
                  ``ext_scoring_func``.
    :type nproc: bool
    :return: List of ``[log_probability, sentence]`` pairs in descending
             order of log probability; empty for empty ``probs_seq``.
    :rtype: list
    :raises ValueError: On a probability-list length mismatch with the
                        vocabulary, or a ``blank_id`` outside
                        ``[0, len(vocabulary) + 1)``.
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("probs dimension mismatched with vocabulary")
    # blank_id check; every probability list has len(vocabulary) + 1 entries
    # (verified above), so this also works when probs_seq is empty.
    probs_dim = len(vocabulary) + 1
    if not 0 <= blank_id < probs_dim:
        raise ValueError("blank_id must lie in [0, probs dimension)")

    # If the decoder is called in the multiprocesses, use the global scorer
    # set by ctc_beam_search_decoder_nproc().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    # Every prefix starts with a '\t' sentinel so prefix[-1] always exists;
    # the sentinel is stripped from the final results.
    prefix_set_prev = {'\t': 1.0}
    # Probability of each prefix ending in blank / in a non-blank symbol.
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    for step_probs in probs_seq:
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        candidates = _prune_candidates(step_probs, cutoff_prob)

        for prefix in prefix_set_prev:
            # A prefix already reached as an extension in this step keeps the
            # mass accumulated so far instead of being reset.
            if prefix not in prefix_set_next:
                probs_b_cur[prefix], probs_nb_cur[prefix] = 0.0, 0.0
            total_prev = probs_b_prev[prefix] + probs_nb_prev[prefix]

            for c, prob_c in candidates:
                if c == blank_id:
                    probs_b_cur[prefix] += prob_c * total_prev
                    continue
                # Indices above the blank are shifted down by one to address
                # the vocabulary, which does not contain the blank.
                new_char = vocabulary[c if c < blank_id else c - 1]
                prefix_plus = prefix + new_char
                if prefix_plus not in prefix_set_next:
                    probs_b_cur[prefix_plus] = 0.0
                    probs_nb_cur[prefix_plus] = 0.0
                if new_char == prefix[-1]:
                    # A repeated symbol only extends the prefix after a blank;
                    # otherwise it collapses into the same prefix.
                    probs_nb_cur[prefix_plus] += prob_c * probs_b_prev[prefix]
                    probs_nb_cur[prefix] += prob_c * probs_nb_prev[prefix]
                elif new_char == ' ':
                    if ext_scoring_func is None or len(prefix) == 1:
                        score = 1.0
                    else:
                        score = ext_scoring_func(prefix[1:])
                    probs_nb_cur[prefix_plus] += score * prob_c * total_prev
                else:
                    probs_nb_cur[prefix_plus] += prob_c * total_prev
                prefix_set_next[prefix_plus] = (
                    probs_nb_cur[prefix_plus] + probs_b_cur[prefix_plus])
            prefix_set_next[prefix] = (
                probs_b_cur[prefix] + probs_nb_cur[prefix])

        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur
        # store top beam_size prefixes
        ranked = sorted(prefix_set_next.items(),
                        key=lambda item: item[1], reverse=True)
        prefix_set_prev = dict(ranked[:beam_size])

    beam_result = []
    for seq, seq_prob in prefix_set_prev.items():
        if seq_prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score last word by external scorer
            if ext_scoring_func is not None and result[-1] != ' ':
                seq_prob = seq_prob * ext_scoring_func(result)
            beam_result.append([np.log(seq_prob), result])
    # output top beam_size decoding results
    return sorted(beam_result, key=lambda item: item[0], reverse=True)


def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id=0,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None,
                                  num_processes=None):
    """Beam search decoder using multiple processes.

    :param probs_split: 3-D list with length batch_size, each element
                        is a 2-D list of probabilities can be used by
                        ctc_beam_search_decoder.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank, default 0.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External defined scoring function for
                             partially decoded sentence, e.g. word count
                             and language model.
    :type ext_scoring_func: function
    :param num_processes: Number of processes, default None, equal to the
                          number of CPUs.
    :type num_processes: int
    :return: One ctc_beam_search_decoder() result per element of
             ``probs_split``, in input order.
    :rtype: list
    :raises ValueError: If ``num_processes`` is not positive.
    """
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # Use a global variable to pass the external scorer to the beam search
    # decoder instead of sending it as a task argument. NOTE(review): workers
    # only see this assignment if they inherit the parent's memory (the
    # "fork" start method); under "spawn" the global is undefined in them.
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        async_results = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_list, beam_size, vocabulary, blank_id,
                              cutoff_prob, None, True))
            for probs_list in probs_split
        ]
        pool.close()
        pool.join()
    finally:
        # Harmless after a clean close()/join(); if anything above raised,
        # this stops the workers instead of leaking them.
        pool.terminate()
    # get() re-raises any exception raised inside a worker.
    return [async_result.get() for async_result in async_results]