support bitransformer decoder

pull/2415/head
tianhao zhang 2 years ago
parent 027535dec1
commit d3e5937591

@ -237,7 +237,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
max_len = paddle.max(ys_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = ys_lens.unsqueeze(1)
@ -279,6 +279,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
eos = paddle.full([1], eos, dtype=r_hyps.dtype)
r_hyps = paddle.where(seq_mask, r_hyps, eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],

@ -361,7 +361,7 @@ class DataLoaderFactory():
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
elif model == 'test' or mode == 'align':
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['dither'] = 0.0

Loading…
Cancel
Save