|
|
|
@ -1,16 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
import numpy as np
|
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
# from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
|
|
|
|
|
import paddlespeech.s2t
|
|
|
|
|
import paddlespeech.s2t # noqa: F401
|
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
|
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import pad_sequence
|
|
|
|
|
|
|
|
|
|
# from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reverse_pad_list(ys_pad: paddle.Tensor,
|
|
|
|
|
ys_lens: paddle.Tensor,
|
|
|
|
|
pad_value: float=-1.0) -> paddle.Tensor:
|
|
|
|
@ -33,12 +44,22 @@ def reverse_pad_list(ys_pad: paddle.Tensor,
|
|
|
|
|
for y, i in zip(ys_pad, ys_lens)], True, pad_value)
|
|
|
|
|
return r_ys_pad
|
|
|
|
|
|
|
|
|
|
def naive_reverse_pad_list_with_sos_eos(r_hyps, r_hyps_lens, sos=5000, eos=5000, ignore_id=-1):
|
|
|
|
|
|
|
|
|
|
def naive_reverse_pad_list_with_sos_eos(r_hyps,
|
|
|
|
|
r_hyps_lens,
|
|
|
|
|
sos=5000,
|
|
|
|
|
eos=5000,
|
|
|
|
|
ignore_id=-1):
|
|
|
|
|
r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(ignore_id))
|
|
|
|
|
r_hyps, _ = add_sos_eos(r_hyps, sos, eos, ignore_id)
|
|
|
|
|
return r_hyps
|
|
|
|
|
|
|
|
|
|
def reverse_pad_list_with_sos_eos(r_hyps, r_hyps_lens, sos=5000, eos=5000, ignore_id=-1):
|
|
|
|
|
|
|
|
|
|
def reverse_pad_list_with_sos_eos(r_hyps,
|
|
|
|
|
r_hyps_lens,
|
|
|
|
|
sos=5000,
|
|
|
|
|
eos=5000,
|
|
|
|
|
ignore_id=-1):
|
|
|
|
|
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
|
|
|
|
|
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
|
|
|
|
|
max_len = paddle.max(r_hyps_lens)
|
|
|
|
@ -73,8 +94,7 @@ def reverse_pad_list_with_sos_eos(r_hyps, r_hyps_lens, sos=5000, eos=5000, ignor
|
|
|
|
|
x_arange = x_arange.reshape(reshape_shape)
|
|
|
|
|
dim_index = paddle.expand(x_arange, index_shape).flatten()
|
|
|
|
|
nd_index.append(dim_index)
|
|
|
|
|
ind2 = paddle.transpose(paddle.stack(nd_index),
|
|
|
|
|
[1, 0]).astype("int64")
|
|
|
|
|
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
|
|
|
|
|
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
|
|
|
|
|
return paddle_out
|
|
|
|
|
|
|
|
|
@ -106,22 +126,12 @@ class TestU2Model(unittest.TestCase):
|
|
|
|
|
self.sos = 5000
|
|
|
|
|
self.eos = 5000
|
|
|
|
|
self.ignore_id = -1
|
|
|
|
|
self.reverse_hyps = paddle.to_tensor(
|
|
|
|
|
[[ 4, 3, 2, 1, -1],
|
|
|
|
|
[ 5, 4, 3, 2, 1]]
|
|
|
|
|
)
|
|
|
|
|
self.reverse_hyps = paddle.to_tensor([[4, 3, 2, 1, -1],
|
|
|
|
|
[5, 4, 3, 2, 1]])
|
|
|
|
|
self.reverse_hyps_sos_eos = paddle.to_tensor(
|
|
|
|
|
[[self.sos, 4 , 3 , 2 , 1 , self.eos],
|
|
|
|
|
[self.sos, 5 , 4 , 3 , 2 , 1 ]]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.hyps = paddle.to_tensor(
|
|
|
|
|
[
|
|
|
|
|
[1, 2, 3, 4, -1],
|
|
|
|
|
[1, 2, 3, 4, 5]
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
[[self.sos, 4, 3, 2, 1, self.eos], [self.sos, 5, 4, 3, 2, 1]])
|
|
|
|
|
|
|
|
|
|
self.hyps = paddle.to_tensor([[1, 2, 3, 4, -1], [1, 2, 3, 4, 5]])
|
|
|
|
|
|
|
|
|
|
self.hyps_lens = paddle.to_tensor([4, 5], paddle.int32)
|
|
|
|
|
|
|
|
|
@ -130,16 +140,17 @@ class TestU2Model(unittest.TestCase):
|
|
|
|
|
self.assertSequenceEqual(r_hyps.tolist(), self.reverse_hyps.tolist())
|
|
|
|
|
|
|
|
|
|
def test_naive_reverse_pad_list_with_sos_eos(self):
|
|
|
|
|
r_hyps_sos_eos = naive_reverse_pad_list_with_sos_eos(self.hyps, self.hyps_lens)
|
|
|
|
|
self.assertSequenceEqual(r_hyps_sos_eos.tolist(), self.reverse_hyps_sos_eos.tolist())
|
|
|
|
|
r_hyps_sos_eos = naive_reverse_pad_list_with_sos_eos(self.hyps,
|
|
|
|
|
self.hyps_lens)
|
|
|
|
|
self.assertSequenceEqual(r_hyps_sos_eos.tolist(),
|
|
|
|
|
self.reverse_hyps_sos_eos.tolist())
|
|
|
|
|
|
|
|
|
|
def test_static_reverse_pad_list_with_sos_eos(self):
|
|
|
|
|
r_hyps_sos_eos_static = reverse_pad_list_with_sos_eos(self.hyps, self.hyps_lens)
|
|
|
|
|
self.assertSequenceEqual(r_hyps_sos_eos_static.tolist(), self.reverse_hyps_sos_eos.tolist())
|
|
|
|
|
|
|
|
|
|
r_hyps_sos_eos_static = reverse_pad_list_with_sos_eos(self.hyps,
|
|
|
|
|
self.hyps_lens)
|
|
|
|
|
self.assertSequenceEqual(r_hyps_sos_eos_static.tolist(),
|
|
|
|
|
self.reverse_hyps_sos_eos.tolist())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|