|
|
|
@ -310,7 +310,12 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
ys = paddle.ones((len(hyps), i), dtype=paddle.long)
|
|
|
|
|
|
|
|
|
|
if hyps[0]["cache"] is not None:
|
|
|
|
|
cache = [paddle.ones((len(hyps), i-1, hyps[0]["cache"][0].shape[-1]), dtype=paddle.float32) for _ in range(len(hyps[0]["cache"]))]
|
|
|
|
|
cache = [
|
|
|
|
|
paddle.ones(
|
|
|
|
|
(len(hyps), i - 1, hyp_cache.shape[-1]),
|
|
|
|
|
dtype=paddle.float32)
|
|
|
|
|
for hyp_cache in hyps[0]["cache"]
|
|
|
|
|
]
|
|
|
|
|
for j, hyp in enumerate(hyps):
|
|
|
|
|
ys[j, :] = paddle.to_tensor(hyp["yseq"])
|
|
|
|
|
if hyps[0]["cache"] is not None:
|
|
|
|
@ -319,7 +324,8 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
|
|
|
|
|
|
|
|
|
|
logp, cache = self.st_decoder.forward_one_step(
|
|
|
|
|
encoder_out.repeat(len(hyps), 1, 1), encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
|
|
|
|
|
encoder_out.repeat(len(hyps), 1, 1),
|
|
|
|
|
encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
|
|
|
|
|
|
|
|
|
|
hyps_best_kept = []
|
|
|
|
|
for j, hyp in enumerate(hyps):
|
|
|
|
|