diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index 44dcb52e..e9008f17 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -237,7 +237,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor, # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) B = ys_pad.shape[0] - _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos + _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype) max_len = paddle.max(ys_lens) index_range = paddle.arange(0, max_len, 1) seq_len_expand = ys_lens.unsqueeze(1) @@ -279,6 +279,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor, # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], # >>> [2, 2, 2]]) + eos = paddle.full([1], eos, dtype=r_hyps.dtype) r_hyps = paddle.where(seq_mask, r_hyps, eos) # >>> r_hyps # >>> tensor([[3, 2, 1], diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 4cc8274f..5ba891c3 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -361,7 +361,7 @@ class DataLoaderFactory(): elif mode == 'valid': config['manifest'] = config.dev_manifest config['train_mode'] = False - elif model == 'test' or mode == 'align': + elif mode == 'test' or mode == 'align': config['manifest'] = config.test_manifest config['train_mode'] = False config['dither'] = 0.0