@@ -310,7 +310,12 @@ class U2STBaseModel(nn.Layer):
             ys = paddle.ones((len(hyps), i), dtype=paddle.long)
 
             if hyps[0]["cache"] is not None:
-                cache = [paddle.ones((len(hyps), i-1, hyps[0]["cache"][0].shape[-1]), dtype=paddle.float32) for _ in range(len(hyps[0]["cache"]))]
+                cache = [
+                    paddle.ones(
+                        (len(hyps), i - 1, hyp_cache.shape[-1]),
+                        dtype=paddle.float32)
+                    for hyp_cache in hyps[0]["cache"]
+                ]
             for j, hyp in enumerate(hyps):
                 ys[j, :] = paddle.to_tensor(hyp["yseq"])
                 if hyps[0]["cache"] is not None:
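
Review note: the hunk above re-batches the live beam hypotheses at each decoding step i, pre-allocating one (n_hyps, i - 1, d_model) cache tensor per decoder layer before the loop copies each hypothesis's token prefix into ys. Below is a minimal standalone sketch of that batching step; the sizes and the hand-built hyps list are toy assumptions, not the model's real state, and the inner copy of each hypothesis's cached states into row j is my reading of the lines elided between the two hunks.

import paddle

n_hyps, i, d_model, n_layers = 3, 5, 8, 2  # toy sizes (assumed)
# each hypothesis carries one cached tensor per decoder layer (assumed layout)
hyps = [{
    "yseq": list(range(i)),
    "cache": [paddle.randn((i - 1, d_model)) for _ in range(n_layers)]
} for _ in range(n_hyps)]

ys = paddle.ones((n_hyps, i), dtype=paddle.int64)
# one (n_hyps, i - 1, d_model) tensor per layer; row j belongs to hypothesis j
cache = [
    paddle.ones(
        (n_hyps, i - 1, hyp_cache.shape[-1]),
        dtype=paddle.float32)
    for hyp_cache in hyps[0]["cache"]
]
for j, hyp in enumerate(hyps):
    ys[j, :] = paddle.to_tensor(hyp["yseq"])
    for k, layer_cache in enumerate(cache):
        layer_cache[j] = hyps[j]["cache"][k]  # assumed per-layer copy step
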
@@ -319,17 +324,18 @@ class U2STBaseModel(nn.Layer):
             ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
 
             logp, cache = self.st_decoder.forward_one_step(
-                encoder_out.repeat(len(hyps), 1, 1), encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
+                encoder_out.repeat(len(hyps), 1, 1),
+                encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
 
             hyps_best_kept = []
             for j, hyp in enumerate(hyps):
-                top_k_logp, top_k_index = logp[j : j + 1].topk(beam_size)
+                top_k_logp, top_k_index = logp[j:j + 1].topk(beam_size)
 
                 for b in range(beam_size):
                     new_hyp = {}
                     new_hyp["score"] = hyp["score"] + float(top_k_logp[0, b])
                     new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
-                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
+                    new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"]
                     new_hyp["yseq"][len(hyp["yseq"])] = int(top_k_index[0, b])
                     new_hyp["cache"] = [cache_[j] for cache_ in cache]
                 # will be (2 x beam) hyps at most
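
Review note: a similarly hedged sketch of the expansion step in the second hunk. Each live hypothesis is expanded with its beam_size best next tokens, then the candidate pool is pruned back to the beam width (implied by the "(2 x beam) hyps at most" comment). The random logp stands in for the decoder output, all sizes are toy assumptions, and the per-hypothesis cache slice from the diff is omitted since there is no real decoder here.

import paddle
import paddle.nn.functional as F

beam_size, n_hyps, vocab_size = 2, 3, 10  # toy sizes (assumed)
hyps = [{"score": 0.0, "yseq": [0]} for _ in range(n_hyps)]
# stand-in for the decoder's per-hypothesis next-token log-probabilities
logp = F.log_softmax(paddle.randn((n_hyps, vocab_size)), axis=-1)

hyps_best_kept = []
for j, hyp in enumerate(hyps):
    top_k_logp, top_k_index = logp[j:j + 1].topk(beam_size)
    for b in range(beam_size):
        # extend hypothesis j with its b-th best token and accumulate score
        hyps_best_kept.append({
            "score": hyp["score"] + float(top_k_logp[0, b]),
            "yseq": hyp["yseq"] + [int(top_k_index[0, b])],
        })
# prune the candidate pool back down to the beam width
hyps_best_kept = sorted(
    hyps_best_kept, key=lambda h: h["score"], reverse=True)[:beam_size]
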