diff --git a/alphabet.py b/alphabet.py
new file mode 100644
index 000000000..475be152f
--- /dev/null
+++ b/alphabet.py
@@ -0,0 +1,42 @@
+import codecs
+
+
+class Alphabet(object):
+    def __init__(self, config_file):
+        self._config_file = config_file
+        self._label_to_str = []
+        self._str_to_label = {}
+        self._size = 0
+        self.blank_token = 1
+        with codecs.open(config_file, 'r', 'utf-8') as fin:
+            for line in fin:
+                if line[0:2] == '\\#':
+                    line = '#\n'
+                elif line[0] == '#':
+                    continue
+                self._label_to_str += line[:-1]  # remove the line ending
+                self._str_to_label[line[:-1]] = self._size
+                self._size += 1
+
+    def string_from_label(self, label):
+        return self._label_to_str[label]
+
+    def label_from_string(self, string):
+        try:
+            return self._str_to_label[string]
+        except KeyError as e:
+            raise KeyError(
+                '''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
+            ).with_traceback(e.__traceback__)
+
+    def decode(self, labels):
+        res = ''
+        for label in labels:
+            res += self.string_from_label(label)
+        return res
+
+    def size(self):
+        return self._size
+
+    def config_file(self):
+        return self._config_file
diff --git a/model_utils/model.py b/model_utils/model.py
index 4b3764bf2..543e246e3 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -1,4 +1,4 @@
-"""Contains DeepSpeech2 model."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -12,9 +12,10 @@ import copy
 import inspect
 from distutils.dir_util import mkpath
 import paddle.v2 as paddle
-from decoders.swig_wrapper import Scorer
-from decoders.swig_wrapper import ctc_greedy_decoder
-from decoders.swig_wrapper import ctc_beam_search_decoder_batch
+
+from ds_ctcdecoder import Scorer, swigwrapper
+from alphabet import Alphabet
+from words import Words
 from model_utils.network import deep_speech_v2_network
 
 logging.basicConfig(
@@ -205,26 +206,8 @@ class DeepSpeech2Model(object):
         ]
         return probs_split
 
-    def decode_batch_greedy(self, probs_split, vocab_list):
-        """Decode by best path for a batch of probs matrix input.
-
-        :param probs_split: List of 2-D probability matrix, and each consists
-                            of prob vectors for one speech utterancce.
-        :param probs_split: List of matrix
-        :param vocab_list: List of tokens in the vocabulary, for decoding.
-        :type vocab_list: list
-        :return: List of transcription texts.
-        :rtype: List of basestring
-        """
-        results = []
-        for i, probs in enumerate(probs_split):
-            output_transcription = ctc_greedy_decoder(
-                probs_seq=probs, vocabulary=vocab_list)
-            results.append(output_transcription)
-        return results
-
-    def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
-                        vocab_list):
+    def init_ext_scorer(self, beam_alpha, beam_beta,
+                        language_model_path, trie_path, alphabet):
         """Initialize the external scorer.
 
         :param beam_alpha: Parameter associated with language model.
@@ -236,39 +219,31 @@ class DeepSpeech2Model(object):
                                     None, and the decoding method will be pure
                                     beam search without scorer.
         :type language_model_path: basestring|None
-        :param vocab_list: List of tokens in the vocabulary, for decoding.
-        :type vocab_list: list
+        :param trie_path: Filepath for trie
+        :type trie_path: basestring|None
+        :param alphabet: Alphabet class for decoding.
+        :type alphabet: Alphabet
         """
         if language_model_path != '':
             self.logger.info("begin to initialize the external scorer "
                              "for decoding")
             self._ext_scorer = Scorer(beam_alpha, beam_beta,
-                                      language_model_path, vocab_list)
-            lm_char_based = self._ext_scorer.is_character_based()
-            lm_max_order = self._ext_scorer.get_max_order()
-            lm_dict_size = self._ext_scorer.get_dict_size()
-            self.logger.info("language model: "
-                             "is_character_based = %d," % lm_char_based +
-                             " max_order = %d," % lm_max_order +
-                             " dict_size = %d" % lm_dict_size)
+                                      language_model_path, trie_path,
+                                      alphabet)
             self.logger.info("end initializing scorer")
         else:
             self._ext_scorer = None
             self.logger.info("no language model provided, "
                              "decoding by pure beam search without scorer.")
 
-    def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
-                                 beam_size, cutoff_prob, cutoff_top_n,
-                                 vocab_list, num_processes):
+    def decode_beam_search(self, probs_split, beam_size,
+                           cutoff_prob, cutoff_top_n,
+                           alphabet):
         """Decode by beam search for a batch of probs matrix input.
 
         :param probs_split: List of 2-D probability matrix, and each consists
                             of prob vectors for one speech utterancce.
         :param probs_split: List of matrix
-        :param beam_alpha: Parameter associated with language model.
-        :type beam_alpha: float
-        :param beam_beta: Parameter associated with word count.
-        :type beam_beta: float
         :param beam_size: Width for Beam search.
         :type beam_size: int
         :param cutoff_prob: Cutoff probability in pruning,
@@ -278,28 +253,24 @@ class DeepSpeech2Model(object):
                              characters with highest probs in vocabulary will be
                              used in beam search, default 40.
         :type cutoff_top_n: int
-        :param vocab_list: List of tokens in the vocabulary, for decoding.
-        :type vocab_list: list
-        :param num_processes: Number of processes (CPU) for decoder.
-        :type num_processes: int
+        :param alphabet: alphabet used for decoding.
+        :type alphabet: Alphabet
         :return: List of transcription texts.
         :rtype: List of basestring
         """
-        if self._ext_scorer != None:
-            self._ext_scorer.reset_params(beam_alpha, beam_beta)
         # beam search decode
-        num_processes = min(num_processes, len(probs_split))
-        beam_search_results = ctc_beam_search_decoder_batch(
-            probs_split=probs_split,
-            vocabulary=vocab_list,
-            beam_size=beam_size,
-            num_processes=num_processes,
-            ext_scoring_func=self._ext_scorer,
-            cutoff_prob=cutoff_prob,
-            cutoff_top_n=cutoff_top_n)
-
-        results = [result[0][1] for result in beam_search_results]
-        return results
+
+        metadata = swigwrapper.ctc_beam_search_decoder(
+            probs_split[0],
+            alphabet.config_file(),
+            beam_size,
+            cutoff_prob,
+            cutoff_top_n,
+            self._ext_scorer)
+
+        words = Words(probs_split[0], metadata[0], alphabet)
+
+        return words
 
     def _adapt_feeding_dict(self, feeding_dict):
         """Adapt feeding dict according to network struct.
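For reference, a minimal sketch of how the reworked decoding path above might be driven from calling code. The `model` instance, the `probs_split` probability matrices, and the file paths and hyperparameter values below are placeholders, not part of this patch:

    from alphabet import Alphabet

    alphabet = Alphabet('data/alphabet.txt')

    # Build the external scorer once; an empty language_model_path skips it.
    model.init_ext_scorer(beam_alpha=2.5, beam_beta=0.3,
                          language_model_path='models/lm.binary',
                          trie_path='models/trie',
                          alphabet=alphabet)

    # decode_beam_search now decodes a single utterance (probs_split[0])
    # and returns a Words object rather than a list of transcripts.
    words = model.decode_beam_search(probs_split,
                                     beam_size=500,
                                     cutoff_prob=1.0,
                                     cutoff_top_n=40,
                                     alphabet=alphabet)
    print(words.raw_output)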
diff --git a/words.py b/words.py
new file mode 100644
index 000000000..4778a6b6a
--- /dev/null
+++ b/words.py
@@ -0,0 +1,53 @@
+import json
+
+
+class Words(object):
+
+    def __init__(self, prob_split, metadata, alphabet, frame_to_sec=.03):
+
+        self.raw_output = ''
+        self.extended_output = []
+
+        word = ''
+        start_step, confidence, num_char = 0, 1.0, 0
+        metadata_size = metadata.tokens.size()
+
+        for i in range(metadata_size):
+            token = metadata.tokens[i]
+            letter = alphabet.string_from_label(token)
+            time_step = metadata.timesteps[i]
+
+            # prepare raw output
+            self.raw_output += letter
+
+            # prepare extended output
+            if token != alphabet.blank_token:
+                word += letter
+                confidence *= prob_split[time_step][token]
+                num_char += 1
+                if len(word) == 1:
+                    start_step = time_step
+
+            if token == alphabet.blank_token or i == metadata_size-1:
+                duration_step = time_step - start_step
+
+                if duration_step < 0:
+                    duration_step = 0
+
+                self.extended_output.append({"word": word,
+                                             "start_time": frame_to_sec * start_step,
+                                             "duration": frame_to_sec * duration_step,
+                                             "confidence": confidence**(1.0/num_char)})
+                # reset
+                word = ''
+                start_step, confidence, num_char = 0, 1.0, 0
+
+    def to_json(self):
+        return json.dumps({"raw_output": self.raw_output,
+                           "extended_output": self.extended_output})
+
+    def save_json(self, file_path):
+        with open(file_path, 'w') as outfile:
+            json.dump({"raw_output": self.raw_output,
+                       "extended_output": self.extended_output},
+                      outfile)
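As a rough usage note for the new Words container, a sketch of consuming its output; `words` refers to the object returned by decode_beam_search in the earlier sketch, and the file name and print formatting are illustrative only:

    # Raw transcript plus per-word timing/confidence, as built in __init__.
    print(words.raw_output)
    for entry in words.extended_output:
        print('%s\tstart=%.2fs dur=%.2fs conf=%.3f' % (
            entry['word'], entry['start_time'],
            entry['duration'], entry['confidence']))

    # Serialize either to a JSON string or straight to disk.
    print(words.to_json())
    words.save_json('transcript.json')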