__all__ = ["end_detect"]


def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    Checks each of the ``M`` lengths ``i, i - 1, ..., i - M + 1``, where a
    hypothesis's length is ``len(hyp["yseq"])``. For each length, take the
    best-scoring ended hypothesis of that length and compare it with the best
    score among all ended hypotheses. The length "passes" when
    ``best_same_length - best_overall < D_end``, i.e. when it is more than
    ``-D_end`` below the overall best. End is detected only if all ``M``
    lengths pass. A length with no ended hypothesis does not pass.

    :param ended_hyps: sequence of ended hypotheses; each is a dict with at
        least a ``"score"`` entry and a ``"yseq"`` sequence
    :param int i: length (compared against ``len(hyp["yseq"])``) at which the
        window of ``M`` checked lengths ends
    :param int M: number of consecutive lengths, ending at ``i``, to inspect
    :param float D_end: score-margin threshold; defaults to
        ``log(1 * exp(-10))``, i.e. approximately -10
    :return: True if end of decoding is detected, False otherwise
        (always False when ``ended_hyps`` is empty)
    :rtype: bool
    """
    if not ended_hyps:
        return False

    # Only the best score is needed, so take the max instead of sorting.
    best_score = max(hyp["score"] for hyp in ended_hyps)

    count = 0
    for m in range(M):
        # Scores of the ended hypotheses whose length is exactly i - m.
        hyp_length = i - m
        same_length_scores = [
            hyp["score"] for hyp in ended_hyps if len(hyp["yseq"]) == hyp_length
        ]
        if same_length_scores and max(same_length_scores) - best_score < D_end:
            count += 1

    # Each m contributes at most once, so count == M means every length passed.
    return count == M