You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.6 KiB
50 lines
1.6 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "end_detect",
]
|
|
|
|
|
|
def end_detect(ended_hyps, i, M=3, D_end=-10.0):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    The lengths ``i, i - 1, ..., i - M + 1`` are examined. Detection fires
    only if every one of those lengths satisfies both of these conditions:

    * at least one ended hypothesis has a ``"yseq"`` of that length;
    * the best such hypothesis scores strictly less than
      ``best_overall_score + D_end``. With the default negative ``D_end``,
      this means it is much worse than the best ended hypothesis.

    :param ended_hyps: sized collection (e.g. list) of hypothesis dicts;
        each dict must provide ``"score"`` (comparable number) and
        ``"yseq"`` (sequence supporting ``len``).
    :param i: reference length; presumably the current decoding step
        (confirm against caller).
    :param M: number of consecutive lengths, counting down from ``i``,
        that must all satisfy the condition.
    :param D_end: score-difference threshold. The default ``-10.0`` equals
        the previous default ``np.log(1 * np.exp(-10))``. It is written as
        a literal so defining the function does not require numpy.
    :return: bool, True when the end condition holds for all ``M`` lengths.
    """
    if not ended_hyps:
        return False

    # Best score among all ended hypotheses, computed once rather than per m.
    best_score = max(hyp["score"] for hyp in ended_hyps)

    count = 0
    for m in range(M):
        hyp_length = i - m
        # Scores of ended hypotheses whose output sequence has this length.
        same_length_scores = [
            hyp["score"] for hyp in ended_hyps
            if len(hyp["yseq"]) == hyp_length
        ]
        if same_length_scores and \
                max(same_length_scores) - best_score < D_end:
            count += 1

    # Comparing with M (rather than returning early) preserves the original
    # results for M == 0 (True) and negative M (False).
    return count == M
|