""" CTC-like decoder utilitis. """ from itertools import groupby import numpy as np import copy import kenlm def ctc_best_path_decode(probs_seq, vocabulary): """ Best path decoding, also called argmax decoding or greedy decoding. Path consisting of the most probable tokens are further post-processed to remove consecutive repetitions and all blanks. :param probs_seq: 2-D list of probabilities over the vocabulary for each character. Each element is a list of float probabilities for one character. :type probs_seq: list :param vocabulary: Vocabulary list. :type vocabulary: list :return: Decoding result string. :rtype: baseline """ # dimension verification for probs in probs_seq: if not len(probs) == len(vocabulary) + 1: raise ValueError("probs_seq dimension mismatchedd with vocabulary") # argmax to get the best index for each time step max_index_list = list(np.array(probs_seq).argmax(axis=1)) # remove consecutive duplicate indexes index_list = [index_group[0] for index_group in groupby(max_index_list)] # remove blank indexes blank_index = len(vocabulary) index_list = [index for index in index_list if index != blank_index] # convert index list to string return ''.join([vocabulary[index] for index in index_list]) class Scorer(object): """ External defined scorer to evaluate a sentence in beam search decoding, consisting of language model and word count. :param alpha: Parameter associated with language model. :type alpha: float :param beta: Parameter associated with word count. :type beta: float :model_path: Path to load language model. :type model_path: basestring """ def __init__(self, alpha, beta, model_path): self._alpha = alpha self._beta = beta self._language_model = kenlm.LanguageModel(model_path) def language_model_score(self, sentence, bos=True, eos=False): log_prob = self._language_model.score(sentence, bos, eos) return np.power(10, log_prob) def word_count(self, sentence): words = sentence.strip().split(' ') return len(words) # execute evaluation def evaluate(self, sentence, bos=True, eos=False): lm = self.language_model_score(sentence, bos, eos) word_count = self.word_count(sentence) score = np.power(lm, self._alpha) \ * np.power(word_count, self._beta) return score def ctc_beam_search_decoder(probs_seq, beam_size, vocabulary, ext_scoring_func=None, blank_id=0): ''' Beam search decoder for CTC-trained network, using beam search with width beam_size to find many paths to one label, return beam_size labels in the order of probabilities. The implementation is based on Prefix Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is redesigned, need to be verified. :param probs_seq: 2-D list with length max_time_steps, each element is a list of normalized probabilities over vocabulary and blank for one time step. :type probs_seq: 2-D list :param beam_size: Width for beam search. :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list :param ext_scoring_func: External defined scoring function for partially decoded sentence, e.g. word count and language model. :type external_scoring_function: function :param blank_id: id of blank, default 0. :type blank_id: int :return: Decoding log probability and result string. 


class Scorer(object):
    """Externally defined scorer to evaluate a sentence in beam search
    decoding, consisting of a language model and a word count.

    :param alpha: Parameter associated with language model.
    :type alpha: float
    :param beta: Parameter associated with word count.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: str
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        self._language_model = kenlm.LanguageModel(model_path)

    def language_model_score(self, sentence, bos=True, eos=False):
        # kenlm returns a log10 probability; convert to a linear probability
        log_prob = self._language_model.score(sentence, bos, eos)
        return np.power(10, log_prob)

    def word_count(self, sentence):
        words = sentence.strip().split(' ')
        return len(words)

    # execute evaluation: score = lm^alpha * word_count^beta
    def evaluate(self, sentence, bos=True, eos=False):
        lm = self.language_model_score(sentence, bos, eos)
        word_count = self.word_count(sentence)
        score = np.power(lm, self._alpha) \
            * np.power(word_count, self._beta)
        return score


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            ext_scoring_func=None,
                            blank_id=0):
    """Beam search decoder for a CTC-trained network. Searches with width
    beam_size and returns beam_size decoding results in descending order of
    probability. The implementation is based on prefix beam search
    (https://arxiv.org/abs/1408.2873); the unclear parts have been redesigned
    and still need to be verified.

    :param probs_seq: 2-D list with length max_time_steps, each element is a
                      list of normalized probabilities over vocabulary and
                      blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: Externally defined scoring function for partially
                             decoded sentences, e.g. word count or language
                             model.
    :type ext_scoring_func: function
    :param blank_id: ID of blank, default 0.
    :type blank_id: int
    :return: Decoding log probabilities and result strings.
    :rtype: list
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("probs dimension mismatched with vocabulary")
    max_time_steps = len(probs_seq)
    if not max_time_steps > 0:
        raise ValueError("probs_seq shouldn't be empty")
    probs_dim = len(probs_seq[0])
    if not blank_id < probs_dim:
        raise ValueError("blank_id shouldn't exceed the probs dimension")
    if ' ' not in vocabulary:
        raise ValueError("space doesn't exist in vocabulary")
    space_id = vocabulary.index(' ')

    # function to convert ids in string to list
    def ids_str2list(ids_str):
        ids_str = ids_str.split(' ')
        ids_list = [int(elem) for elem in ids_str]
        return ids_list

    # function to convert ids list to sentence
    def ids2sentence(ids_list, vocab):
        return ''.join([vocab[ids] for ids in ids_list])

    ## initialize
    # the set containing selected prefixes; '-1' is a sentinel for the
    # empty prefix
    prefix_set_prev = {'-1': 1.0}
    probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0}

    ## extend prefix in loop
    for time_step in range(max_time_steps):
        # the set containing candidate prefixes
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for l in prefix_set_prev:
            prob = probs_seq[time_step]

            # convert ids in string to list
            ids_list = ids_str2list(l)
            end_id = ids_list[-1]
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing vocabulary
            for c in range(0, probs_dim):
                if c == blank_id:
                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
                else:
                    l_plus = l + ' ' + str(c)
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if c == end_id:
                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
                        probs_nb_cur[l] += prob[c] * probs_nb[l]
                    elif c == space_id:
                        if ext_scoring_func is None:
                            score = 1.0
                        else:
                            prefix_sent = ids2sentence(ids_list, vocabulary)
                            score = ext_scoring_func(prefix_sent)
                        probs_nb_cur[l_plus] += score * prob[c] * (
                            probs_b[l] + probs_nb[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (
                            probs_b[l] + probs_nb[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
            probs_nb_cur)

        ## store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda entry: entry[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for (seq, prob) in prefix_set_prev.items():
        if prob > 0.0:
            # strip the leading '-1' sentinel before converting to a sentence
            ids_list = ids_str2list(seq)[1:]
            result = ids2sentence(ids_list, vocabulary)
            log_prob = np.log(prob)
            beam_result.append([log_prob, result])

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda entry: entry[0], reverse=True)
    return beam_result
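

# Illustrative usage sketch (an assumption-laden demo, not a fixture of the
# original module): the toy vocabulary and probabilities below are made up,
# and 'lm.arpa' stands for any KenLM model file; constructing a Scorer
# requires such a file to exist. Running this module directly exercises the
# beam search decoder on a two-step input without an external scorer.
if __name__ == '__main__':
    demo_vocabulary = ['a', 'b', ' ']  # blank takes the last id, 3
    demo_probs_seq = [
        [0.4, 0.3, 0.1, 0.2],
        [0.3, 0.4, 0.1, 0.2],
    ]
    for log_prob, sentence in ctc_beam_search_decoder(
            demo_probs_seq, beam_size=4, vocabulary=demo_vocabulary,
            blank_id=3):
        print(log_prob, repr(sentence))
    # With a KenLM model available, plug a Scorer in as the external
    # scoring function, e.g.:
    #     scorer = Scorer(alpha=0.5, beta=0.5, model_path='lm.arpa')
    #     ctc_beam_search_decoder(demo_probs_seq, 4, demo_vocabulary,
    #                             ext_scoring_func=scorer.evaluate,
    #                             blank_id=3)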