|
|
|
@ -49,16 +49,16 @@ class TestDecoders(unittest.TestCase):
|
|
|
|
|
0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
|
|
|
|
|
0.05294827, 0.22298418
|
|
|
|
|
]]
|
|
|
|
|
self.best_path_result = ["ac'bdc", "b'da"]
|
|
|
|
|
self.greedy_result = ["ac'bdc", "b'da"]
|
|
|
|
|
self.beam_search_result = ['acdc', "b'a"]
|
|
|
|
|
|
|
|
|
|
def test_best_path_decoder_1(self):
|
|
|
|
|
bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list)
|
|
|
|
|
self.assertEqual(bst_result, self.best_path_result[0])
|
|
|
|
|
def test_greedy_decoder_1(self):
|
|
|
|
|
bst_result = ctc_greedy_decoder(self.probs_seq1, self.vocab_list)
|
|
|
|
|
self.assertEqual(bst_result, self.greedy_result[0])
|
|
|
|
|
|
|
|
|
|
def test_best_path_decoder_2(self):
|
|
|
|
|
bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list)
|
|
|
|
|
self.assertEqual(bst_result, self.best_path_result[1])
|
|
|
|
|
def test_greedy_decoder_2(self):
|
|
|
|
|
bst_result = ctc_greedy_decoder(self.probs_seq2, self.vocab_list)
|
|
|
|
|
self.assertEqual(bst_result, self.greedy_result[1])
|
|
|
|
|
|
|
|
|
|
def test_beam_search_decoder_1(self):
|
|
|
|
|
beam_result = ctc_beam_search_decoder(
|
|
|
|
|