|
|
|
@ -4,7 +4,7 @@ from __future__ import division
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
from decoder import *
|
|
|
|
|
from models import decoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDecoders(unittest.TestCase):
|
|
|
|
@ -53,15 +53,17 @@ class TestDecoders(unittest.TestCase):
|
|
|
|
|
self.beam_search_result = ['acdc', "b'a"]
|
|
|
|
|
|
|
|
|
|
def test_greedy_decoder_1(self):
|
|
|
|
|
bst_result = ctc_greedy_decoder(self.probs_seq1, self.vocab_list)
|
|
|
|
|
bst_result = decoder.ctc_greedy_decoder(self.probs_seq1,
|
|
|
|
|
self.vocab_list)
|
|
|
|
|
self.assertEqual(bst_result, self.greedy_result[0])
|
|
|
|
|
|
|
|
|
|
def test_greedy_decoder_2(self):
|
|
|
|
|
bst_result = ctc_greedy_decoder(self.probs_seq2, self.vocab_list)
|
|
|
|
|
bst_result = decoder.ctc_greedy_decoder(self.probs_seq2,
|
|
|
|
|
self.vocab_list)
|
|
|
|
|
self.assertEqual(bst_result, self.greedy_result[1])
|
|
|
|
|
|
|
|
|
|
def test_beam_search_decoder_1(self):
|
|
|
|
|
beam_result = ctc_beam_search_decoder(
|
|
|
|
|
beam_result = decoder.ctc_beam_search_decoder(
|
|
|
|
|
probs_seq=self.probs_seq1,
|
|
|
|
|
beam_size=self.beam_size,
|
|
|
|
|
vocabulary=self.vocab_list,
|
|
|
|
@ -69,7 +71,7 @@ class TestDecoders(unittest.TestCase):
|
|
|
|
|
self.assertEqual(beam_result[0][1], self.beam_search_result[0])
|
|
|
|
|
|
|
|
|
|
def test_beam_search_decoder_2(self):
|
|
|
|
|
beam_result = ctc_beam_search_decoder(
|
|
|
|
|
beam_result = decoder.ctc_beam_search_decoder(
|
|
|
|
|
probs_seq=self.probs_seq2,
|
|
|
|
|
beam_size=self.beam_size,
|
|
|
|
|
vocabulary=self.vocab_list,
|
|
|
|
@ -77,7 +79,7 @@ class TestDecoders(unittest.TestCase):
|
|
|
|
|
self.assertEqual(beam_result[0][1], self.beam_search_result[1])
|
|
|
|
|
|
|
|
|
|
def test_beam_search_decoder_batch(self):
|
|
|
|
|
beam_results = ctc_beam_search_decoder_batch(
|
|
|
|
|
beam_results = decoder.ctc_beam_search_decoder_batch(
|
|
|
|
|
probs_split=[self.probs_seq1, self.probs_seq2],
|
|
|
|
|
beam_size=self.beam_size,
|
|
|
|
|
vocabulary=self.vocab_list,
|
|
|
|
|