add annotations

pull/2/head
Yibing Liu 8 years ago
parent 51f35a5372
commit ac370eca85

@ -68,9 +68,9 @@ class Scorer(object):
# execute evaluation
def evaluate(self, sentence, bos=True, eos=False):
lm = self.language_model_score(sentence, bos, eos)
word_count = self.word_count(sentence)
word_cnt = self.word_count(sentence)
score = np.power(lm, self._alpha) \
* np.power(word_count, self._beta)
* np.power(word_cnt, self._beta)
return score
@ -104,19 +104,18 @@ def ctc_beam_search_decoder(probs_seq,
:rtype: list
'''
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary")
max_time_steps = len(probs_seq)
if not max_time_steps > 0:
raise ValueError("probs_seq shouldn't be empty")
# blank_id check
probs_dim = len(probs_seq[0])
if not blank_id < probs_dim:
raise ValueError("blank_id shouldn't be greater than probs dimension")
# assign space_id
if ' ' not in vocabulary:
raise ValueError("space doesn't exist in vocabulary")
space_id = vocabulary.index(' ')

@ -77,7 +77,7 @@ parser.add_argument(
"--language_model_path",
default="./data/1Billion.klm",
type=str,
help="Path for language model. (default: %(default)d)")
help="Path for language model. (default: %(default)s)")
parser.add_argument(
"--alpha",
default=0.0,
@ -93,7 +93,7 @@ args = parser.parse_args()
def infer():
"""
Max-ctc-decoding for DeepSpeech2.
Inference for DeepSpeech2.
"""
# initialize data generator
data_generator = DataGenerator(
@ -174,6 +174,7 @@ def infer():
print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
else:
raise ValueError("Decoding method [%s] is not supported." % method)

Loading…
Cancel
Save