decoder end detection, when no more hyps score larger than before by a margin meet M count

pull/882/head
Hui Zhang 3 years ago
parent 2430545d45
commit 54b31d35f1

@ -0,0 +1,34 @@
__all__ = ["end_detect"]
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
described in Eq. (50) of S. Watanabe et al
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
:param ended_hyps: dict
:param i: int
:param M: int
:param D_end: float
:return: bool
"""
if len(ended_hyps) == 0:
return False
count = 0
best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
for m in range(M):
# get ended_hyps with their length is i - m
hyp_length = i - m
hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
if len(hyps_same_length) > 0:
best_hyp_same_length = sorted(
hyps_same_length, key=lambda x: x["score"], reverse=True
)[0]
if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
count += 1
if count == M:
return True
else:
return False
Loading…
Cancel
Save