pull/882/head
parent
2430545d45
commit
54b31d35f1
@ -0,0 +1,34 @@
|
|||||||
|
|
||||||
|
__all__ = ["end_detect"]
|
||||||
|
|
||||||
|
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
|
||||||
|
"""End detection.
|
||||||
|
|
||||||
|
described in Eq. (50) of S. Watanabe et al
|
||||||
|
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
|
||||||
|
|
||||||
|
:param ended_hyps: dict
|
||||||
|
:param i: int
|
||||||
|
:param M: int
|
||||||
|
:param D_end: float
|
||||||
|
:return: bool
|
||||||
|
"""
|
||||||
|
if len(ended_hyps) == 0:
|
||||||
|
return False
|
||||||
|
count = 0
|
||||||
|
best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
|
||||||
|
for m in range(M):
|
||||||
|
# get ended_hyps with their length is i - m
|
||||||
|
hyp_length = i - m
|
||||||
|
hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
|
||||||
|
if len(hyps_same_length) > 0:
|
||||||
|
best_hyp_same_length = sorted(
|
||||||
|
hyps_same_length, key=lambda x: x["score"], reverse=True
|
||||||
|
)[0]
|
||||||
|
if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if count == M:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
Loading…
Reference in new issue