|
|
|
@ -81,7 +81,8 @@ class TestDecoders(unittest.TestCase):
|
|
|
|
|
probs_split=[self.probs_seq1, self.probs_seq2],
|
|
|
|
|
beam_size=self.beam_size,
|
|
|
|
|
vocabulary=self.vocab_list,
|
|
|
|
|
blank_id=len(self.vocab_list))
|
|
|
|
|
blank_id=len(self.vocab_list),
|
|
|
|
|
num_processes=24)
|
|
|
|
|
self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
|
|
|
|
|
self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
|
|
|
|
|
|
|
|
|
|