You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
3.9 KiB
101 lines
3.9 KiB
4 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
8 years ago
|
"""Test decoders."""
|
||
|
import unittest
|
||
4 years ago
|
|
||
3 years ago
|
from paddlespeech.s2t.decoders import decoders_deprecated as decoder
|
||
8 years ago
|
|
||
|
|
||
|
class TestDecoders(unittest.TestCase):
    """Regression tests for the deprecated CTC greedy and beam-search decoders.

    Two fixed 6-frame probability matrices are decoded and the best
    hypothesis is compared against a hard-coded expected transcript.
    """

    def setUp(self):
        # Output symbols. Column i of every probability row maps to
        # vocab_list[i]; this mapping reproduces the expected greedy outputs
        # below. Each row carries one extra trailing column (7 vs. 6 symbols),
        # presumably the CTC blank -- confirm against decoders_deprecated.
        self.vocab_list = ["'", ' ', 'a', 'b', 'c', 'd']
        self.beam_size = 20

        # Each row below sums to 1.0 (one distribution per frame).
        self.probs_seq1 = [
            [0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254,
             0.18184413, 0.16493624],
            [0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462,
             0.0094893, 0.06890021],
            [0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535,
             0.08424043, 0.08120984],
            [0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305,
             0.05206269, 0.09772094],
            [0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985,
             0.41317442, 0.01946335],
            [0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937,
             0.04377724, 0.01457421],
        ]
        self.probs_seq2 = [
            [0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441,
             0.04468023, 0.10903471],
            [0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123,
             0.10219457, 0.20640612],
            [0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316,
             0.12298585, 0.01654384],
            [0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055,
             0.22538587, 0.13483174],
            [0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313,
             0.07113197, 0.04139363],
            [0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
             0.05294827, 0.22298418],
        ]

        # Expected transcripts; index 0 is for probs_seq1, index 1 for
        # probs_seq2.
        self.greedy_result = ["ac'bdc", "b'da"]
        self.beam_search_result = ['acdc', "b'a"]

    def _best_beam_transcript(self, probs_seq):
        """Beam-search decode ``probs_seq`` and return the top entry's text.

        As in the original assertions, the first returned entry is taken as
        the best hypothesis and its element ``[1]`` as the transcript.
        """
        beam_result = decoder.ctc_beam_search_decoder(
            probs_seq=probs_seq,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list)
        return beam_result[0][1]

    def test_greedy_decoder_1(self):
        bst_result = decoder.ctc_greedy_decoder(self.probs_seq1,
                                                self.vocab_list)
        self.assertEqual(bst_result, self.greedy_result[0])

    def test_greedy_decoder_2(self):
        bst_result = decoder.ctc_greedy_decoder(self.probs_seq2,
                                                self.vocab_list)
        self.assertEqual(bst_result, self.greedy_result[1])

    def test_beam_search_decoder_1(self):
        self.assertEqual(
            self._best_beam_transcript(self.probs_seq1),
            self.beam_search_result[0])

    def test_beam_search_decoder_2(self):
        self.assertEqual(
            self._best_beam_transcript(self.probs_seq2),
            self.beam_search_result[1])

    def test_beam_search_decoder_batch(self):
        probs_split = [self.probs_seq1, self.probs_seq2]
        beam_results = decoder.ctc_beam_search_decoder_batch(
            probs_split=probs_split,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            # One worker per utterance. The previous hard-coded 24 presumably
            # requested far more workers than a two-item batch can use.
            num_processes=len(probs_split))
        self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
        self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Executed only when this file is run as a script (not when imported by
    # a test runner): hand control to unittest's CLI entry point.
    unittest.main()
|