|
|
|
@ -4,7 +4,7 @@ from __future__ import division
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
from model_utils import decoder
|
|
|
|
|
from decoders import decoders_deprecated as decoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDecoders(unittest.TestCase):
|
|
|
|
@ -66,16 +66,14 @@ class TestDecoders(unittest.TestCase):
|
|
|
|
|
beam_result = decoder.ctc_beam_search_decoder(
|
|
|
|
|
probs_seq=self.probs_seq1,
|
|
|
|
|
beam_size=self.beam_size,
|
|
|
|
|
vocabulary=self.vocab_list,
|
|
|
|
|
blank_id=len(self.vocab_list))
|
|
|
|
|
vocabulary=self.vocab_list)
|
|
|
|
|
self.assertEqual(beam_result[0][1], self.beam_search_result[0])
|
|
|
|
|
|
|
|
|
|
def test_beam_search_decoder_2(self):
|
|
|
|
|
beam_result = decoder.ctc_beam_search_decoder(
|
|
|
|
|
probs_seq=self.probs_seq2,
|
|
|
|
|
beam_size=self.beam_size,
|
|
|
|
|
vocabulary=self.vocab_list,
|
|
|
|
|
blank_id=len(self.vocab_list))
|
|
|
|
|
vocabulary=self.vocab_list)
|
|
|
|
|
self.assertEqual(beam_result[0][1], self.beam_search_result[1])
|
|
|
|
|
|
|
|
|
|
def test_beam_search_decoder_batch(self):
|
|
|
|
@ -83,7 +81,6 @@ class TestDecoders(unittest.TestCase):
|
|
|
|
|
probs_split=[self.probs_seq1, self.probs_seq2],
|
|
|
|
|
beam_size=self.beam_size,
|
|
|
|
|
vocabulary=self.vocab_list,
|
|
|
|
|
blank_id=len(self.vocab_list),
|
|
|
|
|
num_processes=24)
|
|
|
|
|
self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
|
|
|
|
|
self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
|