config and formalize

pull/1350/head
Junkun 3 years ago
parent 43aad7a018
commit f866059b74

@ -1,8 +1,9 @@
batch_size: 5
batch_size: 1
error_rate_type: char-bleu
decoding_method: fullsentence # 'fullsentence', 'simultaneous'
beam_size: 10
word_reward: 0.7
maxlen_ratio: 0.3
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.

@ -1,9 +1,10 @@
batch_size: 5
batch_size: 1
error_rate_type: char-bleu
decoding_method: fullsentence # 'fullsentence', 'simultaneous'
beam_size: 10
word_reward: 0.7
maxlen_ratio: 0.3
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.

@ -310,7 +310,12 @@ class U2STBaseModel(nn.Layer):
ys = paddle.ones((len(hyps), i), dtype=paddle.long)
if hyps[0]["cache"] is not None:
cache = [paddle.ones((len(hyps), i-1, hyps[0]["cache"][0].shape[-1]), dtype=paddle.float32) for _ in range(len(hyps[0]["cache"]))]
cache = [
paddle.ones(
(len(hyps), i - 1, hyp_cache.shape[-1]),
dtype=paddle.float32)
for hyp_cache in hyps[0]["cache"]
]
for j, hyp in enumerate(hyps):
ys[j, :] = paddle.to_tensor(hyp["yseq"])
if hyps[0]["cache"] is not None:
@ -319,17 +324,18 @@ class U2STBaseModel(nn.Layer):
ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
logp, cache = self.st_decoder.forward_one_step(
encoder_out.repeat(len(hyps), 1, 1), encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
encoder_out.repeat(len(hyps), 1, 1),
encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
hyps_best_kept = []
for j, hyp in enumerate(hyps):
top_k_logp, top_k_index = logp[j : j + 1].topk(beam_size)
top_k_logp, top_k_index = logp[j:j + 1].topk(beam_size)
for b in range(beam_size):
new_hyp = {}
new_hyp["score"] = hyp["score"] + float(top_k_logp[0, b])
new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"]
new_hyp["yseq"][len(hyp["yseq"])] = int(top_k_index[0, b])
new_hyp["cache"] = [cache_[j] for cache_ in cache]
# will be (2 x beam) hyps at most

Loading…
Cancel
Save