add annotations

pull/2/head
Yibing Liu 8 years ago
parent 51f35a5372
commit ac370eca85

@ -68,9 +68,9 @@ class Scorer(object):
# execute evaluation # execute evaluation
def evaluate(self, sentence, bos=True, eos=False): def evaluate(self, sentence, bos=True, eos=False):
lm = self.language_model_score(sentence, bos, eos) lm = self.language_model_score(sentence, bos, eos)
word_count = self.word_count(sentence) word_cnt = self.word_count(sentence)
score = np.power(lm, self._alpha) \ score = np.power(lm, self._alpha) \
* np.power(word_count, self._beta) * np.power(word_cnt, self._beta)
return score return score
@ -104,19 +104,18 @@ def ctc_beam_search_decoder(probs_seq,
:rtype: list :rtype: list
''' '''
# dimension check
for prob_list in probs_seq: for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1: if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary") raise ValueError("probs dimension mismatchedd with vocabulary")
max_time_steps = len(probs_seq) max_time_steps = len(probs_seq)
if not max_time_steps > 0:
raise ValueError("probs_seq shouldn't be empty")
# blank_id check
probs_dim = len(probs_seq[0]) probs_dim = len(probs_seq[0])
if not blank_id < probs_dim: if not blank_id < probs_dim:
raise ValueError("blank_id shouldn't be greater than probs dimension") raise ValueError("blank_id shouldn't be greater than probs dimension")
# assign space_id
if ' ' not in vocabulary: if ' ' not in vocabulary:
raise ValueError("space doesn't exist in vocabulary") raise ValueError("space doesn't exist in vocabulary")
space_id = vocabulary.index(' ') space_id = vocabulary.index(' ')

@ -77,7 +77,7 @@ parser.add_argument(
"--language_model_path", "--language_model_path",
default="./data/1Billion.klm", default="./data/1Billion.klm",
type=str, type=str,
help="Path for language model. (default: %(default)d)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--alpha", "--alpha",
default=0.0, default=0.0,
@ -93,7 +93,7 @@ args = parser.parse_args()
def infer(): def infer():
""" """
Max-ctc-decoding for DeepSpeech2. Inference for DeepSpeech2.
""" """
# initialize data generator # initialize data generator
data_generator = DataGenerator( data_generator = DataGenerator(
@ -174,6 +174,7 @@ def infer():
print("\nTarget Transcription:\t%s" % target_transcription) print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample): for index in range(args.num_results_per_sample):
result = beam_search_result[index] result = beam_search_result[index]
# output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1])) print("Beam %d: %f \t%s" % (index, result[0], result[1]))
else: else:
raise ValueError("Decoding method [%s] is not supported." % method) raise ValueError("Decoding method [%s] is not supported." % method)

Loading…
Cancel
Save