From ac370eca850825cc3cd075f47903722e2805fc5a Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 7 Jun 2017 09:06:58 +0800 Subject: [PATCH] add annotations --- decoder.py | 11 +++++------ infer.py | 5 +++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/decoder.py b/decoder.py index 91dbfc347..e16d10544 100755 --- a/decoder.py +++ b/decoder.py @@ -68,9 +68,9 @@ class Scorer(object): # execute evaluation def evaluate(self, sentence, bos=True, eos=False): lm = self.language_model_score(sentence, bos, eos) - word_count = self.word_count(sentence) + word_cnt = self.word_count(sentence) score = np.power(lm, self._alpha) \ - * np.power(word_count, self._beta) + * np.power(word_cnt, self._beta) return score @@ -104,19 +104,18 @@ def ctc_beam_search_decoder(probs_seq, :rtype: list ''' - + # dimension check for prob_list in probs_seq: if not len(prob_list) == len(vocabulary) + 1: raise ValueError("probs dimension mismatchedd with vocabulary") - max_time_steps = len(probs_seq) - if not max_time_steps > 0: - raise ValueError("probs_seq shouldn't be empty") + # blank_id check probs_dim = len(probs_seq[0]) if not blank_id < probs_dim: raise ValueError("blank_id shouldn't be greater than probs dimension") + # assign space_id if ' ' not in vocabulary: raise ValueError("space doesn't exist in vocabulary") space_id = vocabulary.index(' ') diff --git a/infer.py b/infer.py index dc46b83e9..be7ecad9f 100644 --- a/infer.py +++ b/infer.py @@ -77,7 +77,7 @@ parser.add_argument( "--language_model_path", default="./data/1Billion.klm", type=str, - help="Path for language model. (default: %(default)d)") + help="Path for language model. (default: %(default)s)") parser.add_argument( "--alpha", default=0.0, @@ -93,7 +93,7 @@ args = parser.parse_args() def infer(): """ - Max-ctc-decoding for DeepSpeech2. + Inference for DeepSpeech2. """ # initialize data generator data_generator = DataGenerator( @@ -174,6 +174,7 @@ def infer(): print("\nTarget Transcription:\t%s" % target_transcription) for index in range(args.num_results_per_sample): result = beam_search_result[index] + #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) else: raise ValueError("Decoding method [%s] is not supported." % method)