|
|
|
@ -264,14 +264,17 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
speech_lengths: paddle.Tensor,
|
|
|
|
|
beam_size: int=10,
|
|
|
|
|
word_reward: float=0.0,
|
|
|
|
|
maxlen_ratio: float=0.5,
|
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
|
simulate_streaming: bool=False, ) -> paddle.Tensor:
|
|
|
|
|
""" Apply beam search on attention decoder
|
|
|
|
|
""" Apply beam search on attention decoder with length penalty
|
|
|
|
|
Args:
|
|
|
|
|
speech (paddle.Tensor): (batch, max_len, feat_dim)
|
|
|
|
|
speech_length (paddle.Tensor): (batch, )
|
|
|
|
|
beam_size (int): beam size for beam search
|
|
|
|
|
word_reward (float): word reward used in beam search
|
|
|
|
|
maxlen_ratio (float): max length ratio to bound the length of translated text
|
|
|
|
|
decoding_chunk_size (int): decoding chunk for dynamic chunk
|
|
|
|
|
trained model.
|
|
|
|
|
<0: for decoding, use full chunk.
|
|
|
|
@ -284,90 +287,84 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
"""
|
|
|
|
|
assert speech.shape[0] == speech_lengths.shape[0]
|
|
|
|
|
assert decoding_chunk_size != 0
|
|
|
|
|
assert speech.shape[0] == 1
|
|
|
|
|
device = speech.place
|
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
|
|
|
|
|
|
# Let's assume B = batch_size and N = beam_size
|
|
|
|
|
# 1. Encoder
|
|
|
|
|
# 1. Encoder and init hypothesis
|
|
|
|
|
encoder_out, encoder_mask = self._forward_encoder(
|
|
|
|
|
speech, speech_lengths, decoding_chunk_size,
|
|
|
|
|
num_decoding_left_chunks,
|
|
|
|
|
simulate_streaming) # (B, maxlen, encoder_dim)
|
|
|
|
|
maxlen = encoder_out.shape[1]
|
|
|
|
|
encoder_dim = encoder_out.shape[2]
|
|
|
|
|
running_size = batch_size * beam_size
|
|
|
|
|
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
|
|
|
|
|
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
|
|
|
|
|
encoder_mask = encoder_mask.unsqueeze(1).repeat(
|
|
|
|
|
1, beam_size, 1, 1).view(running_size, 1,
|
|
|
|
|
maxlen) # (B*N, 1, max_len)
|
|
|
|
|
|
|
|
|
|
hyps = paddle.ones(
|
|
|
|
|
[running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1)
|
|
|
|
|
# log scale score
|
|
|
|
|
scores = paddle.to_tensor(
|
|
|
|
|
[0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
|
|
|
|
|
scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
|
|
|
|
|
device) # (B*N, 1)
|
|
|
|
|
end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1)
|
|
|
|
|
cache: Optional[List[paddle.Tensor]] = None
|
|
|
|
|
|
|
|
|
|
maxlen = max(int(encoder_out.shape[1] * maxlen_ratio), 5)
|
|
|
|
|
|
|
|
|
|
hyp = {"score": 0.0, "yseq": [self.sos], "cache": None}
|
|
|
|
|
hyps = [hyp]
|
|
|
|
|
ended_hyps = []
|
|
|
|
|
cur_best_score = -float("inf")
|
|
|
|
|
cache = None
|
|
|
|
|
|
|
|
|
|
# 2. Decoder forward step by step
|
|
|
|
|
for i in range(1, maxlen + 1):
|
|
|
|
|
# Stop if all batch and all beam produce eos
|
|
|
|
|
# TODO(Hui Zhang): if end_flag.sum() == running_size:
|
|
|
|
|
if end_flag.cast(paddle.int64).sum() == running_size:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# 2.1 Forward decoder step
|
|
|
|
|
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
|
|
|
|
|
running_size, 1, 1).to(device) # (B*N, i, i)
|
|
|
|
|
# logp: (B*N, vocab)
|
|
|
|
|
ys = paddle.ones((len(hyps), i), dtype=paddle.long)
|
|
|
|
|
|
|
|
|
|
if hyps[0]["cache"] is not None:
|
|
|
|
|
cache = [paddle.ones((len(hyps), i-1, hyps[0]["cache"][0].shape[-1]), dtype=paddle.float32) for _ in range(len(hyps[0]["cache"]))]
|
|
|
|
|
for j, hyp in enumerate(hyps):
|
|
|
|
|
ys[j, :] = paddle.to_tensor(hyp["yseq"])
|
|
|
|
|
if hyps[0]["cache"] is not None:
|
|
|
|
|
for k in range(len(cache)):
|
|
|
|
|
cache[k][j] = hyps[j]["cache"][k]
|
|
|
|
|
ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
|
|
|
|
|
|
|
|
|
|
logp, cache = self.st_decoder.forward_one_step(
|
|
|
|
|
encoder_out, encoder_mask, hyps, hyps_mask, cache)
|
|
|
|
|
|
|
|
|
|
# 2.2 First beam prune: select topk best prob at current time
|
|
|
|
|
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
|
|
|
|
|
top_k_logp += word_reward
|
|
|
|
|
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
|
|
|
|
|
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
|
|
|
|
|
|
|
|
|
|
# 2.3 Second beam prune: select topk score with history
|
|
|
|
|
scores = scores + top_k_logp # (B*N, N), broadcast add
|
|
|
|
|
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
|
|
|
|
|
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
|
|
|
|
|
scores = scores.view(-1, 1) # (B*N, 1)
|
|
|
|
|
|
|
|
|
|
# 2.4. Compute base index in top_k_index,
|
|
|
|
|
# regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
|
|
|
|
|
# then find offset_k_index in top_k_index
|
|
|
|
|
base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
|
|
|
|
|
1, beam_size) # (B, N)
|
|
|
|
|
base_k_index = base_k_index * beam_size * beam_size
|
|
|
|
|
best_k_index = base_k_index.view(-1) + offset_k_index.view(
|
|
|
|
|
-1) # (B*N)
|
|
|
|
|
|
|
|
|
|
# 2.5 Update best hyps
|
|
|
|
|
best_k_pred = paddle.index_select(
|
|
|
|
|
top_k_index.view(-1), index=best_k_index, axis=0) # (B*N)
|
|
|
|
|
best_hyps_index = best_k_index // beam_size
|
|
|
|
|
last_best_k_hyps = paddle.index_select(
|
|
|
|
|
hyps, index=best_hyps_index, axis=0) # (B*N, i)
|
|
|
|
|
hyps = paddle.cat(
|
|
|
|
|
(last_best_k_hyps, best_k_pred.view(-1, 1)),
|
|
|
|
|
dim=1) # (B*N, i+1)
|
|
|
|
|
|
|
|
|
|
# 2.6 Update end flag
|
|
|
|
|
end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
|
|
|
|
|
encoder_out.repeat(len(hyps), 1, 1), encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
|
|
|
|
|
|
|
|
|
|
hyps_best_kept = []
|
|
|
|
|
for j, hyp in enumerate(hyps):
|
|
|
|
|
top_k_logp, top_k_index = logp[j : j + 1].topk(beam_size)
|
|
|
|
|
|
|
|
|
|
for b in range(beam_size):
|
|
|
|
|
new_hyp = {}
|
|
|
|
|
new_hyp["score"] = hyp["score"] + float(top_k_logp[0, b])
|
|
|
|
|
new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
|
|
|
|
|
new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
|
|
|
|
|
new_hyp["yseq"][len(hyp["yseq"])] = int(top_k_index[0, b])
|
|
|
|
|
new_hyp["cache"] = [cache_[j] for cache_ in cache]
|
|
|
|
|
# will be (2 x beam) hyps at most
|
|
|
|
|
hyps_best_kept.append(new_hyp)
|
|
|
|
|
|
|
|
|
|
hyps_best_kept = sorted(
|
|
|
|
|
hyps_best_kept, key=lambda x: -x["score"])[:beam_size]
|
|
|
|
|
|
|
|
|
|
# sort and get nbest
|
|
|
|
|
hyps = hyps_best_kept
|
|
|
|
|
if i == maxlen:
|
|
|
|
|
for hyp in hyps:
|
|
|
|
|
hyp["yseq"].append(self.eos)
|
|
|
|
|
|
|
|
|
|
# finalize the ended hypotheses with word reward (by length)
|
|
|
|
|
remained_hyps = []
|
|
|
|
|
for hyp in hyps:
|
|
|
|
|
if hyp["yseq"][-1] == self.eos:
|
|
|
|
|
hyp["score"] += (i - 1) * word_reward
|
|
|
|
|
cur_best_score = max(cur_best_score, hyp["score"])
|
|
|
|
|
ended_hyps.append(hyp)
|
|
|
|
|
else:
|
|
|
|
|
# keep this hypothesis only if it could still beat the best ended one (preserves optimality)
|
|
|
|
|
if hyp["score"] + maxlen * word_reward > cur_best_score:
|
|
|
|
|
remained_hyps.append(hyp)
|
|
|
|
|
|
|
|
|
|
# stop prediction when there is no unended hypothesis
|
|
|
|
|
if not remained_hyps:
|
|
|
|
|
break
|
|
|
|
|
hyps = remained_hyps
|
|
|
|
|
|
|
|
|
|
# 3. Select best of best
|
|
|
|
|
scores = scores.view(batch_size, beam_size)
|
|
|
|
|
# TODO: length normalization
|
|
|
|
|
best_index = paddle.argmax(scores, axis=-1).long() # (B)
|
|
|
|
|
best_hyps_index = best_index + paddle.arange(
|
|
|
|
|
batch_size, dtype=paddle.long) * beam_size
|
|
|
|
|
best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
|
|
|
|
|
best_hyps = best_hyps[:, 1:]
|
|
|
|
|
return best_hyps
|
|
|
|
|
best_hyp = max(ended_hyps, key=lambda x: x["score"])
|
|
|
|
|
|
|
|
|
|
return paddle.to_tensor([best_hyp["yseq"][1:]])
|
|
|
|
|
|
|
|
|
|
# @jit.to_static
|
|
|
|
|
def subsampling_rate(self) -> int:
|
|
|
|
@ -472,6 +469,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
decoding_method: str,
|
|
|
|
|
beam_size: int,
|
|
|
|
|
word_reward: float=0.0,
|
|
|
|
|
maxlen_ratio: float=0.5,
|
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
|
simulate_streaming: bool=False):
|
|
|
|
@ -507,6 +505,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
feats_lengths,
|
|
|
|
|
beam_size=beam_size,
|
|
|
|
|
word_reward=word_reward,
|
|
|
|
|
maxlen_ratio=maxlen_ratio,
|
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
|
simulate_streaming=simulate_streaming)
|
|
|
|
|