You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/server/engine/asr/online/ctc_endpoint.py

120 lines
4.5 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import numpy as np
from paddlespeech.cli.log import logger
@dataclass
class OnlineCTCEndpointRule:
must_contain_nonsilence: bool = True
min_trailing_silence: int = 1000
min_utterance_length: int = 0
@dataclass
class OnlineCTCEndpoingOpt:
frame_shift_in_ms: int = 10
blank: int = 0 # blank id, that we consider as silence for purposes of endpointing.
blank_threshold: float = 0.8 # above blank threshold is silence
# We support three rules. We terminate decoding if ANY of these rules
# evaluates to "true". If you want to add more rules, do it by changing this
# code. If you want to disable a rule, you can set the silence-timeout for
# that rule to a very large number.
# rule1 times out after 5 seconds of silence, even if we decoded nothing.
rule1: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 5000, 0)
# rule2 times out after 1.0 seconds of silence after decoding something,
# even if we did not reach a final-state at all.
rule2: OnlineCTCEndpointRule = OnlineCTCEndpointRule(True, 1000, 0)
# rule3 times out after the utterance is 20 seconds long, regardless of
# anything else.
rule3: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 0, 20000)
class OnlineCTCEndpoint:
"""
[END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)
"""
def __init__(self, opts: OnlineCTCEndpoingOpt):
self.opts = opts
logger.info(f"Endpont Opts: {opts}")
self.frame_shift_in_ms = opts.frame_shift_in_ms
self.num_frames_decoded = 0
self.trailing_silence_frames = 0
self.reset()
def reset(self):
self.num_frames_decoded = 0
self.trailing_silence_frames = 0
def rule_activated(self,
rule: OnlineCTCEndpointRule,
rule_name: str,
decoding_something: bool,
trailine_silence: int,
utterance_length: int) -> bool:
ans = (
decoding_something or (not rule.must_contain_nonsilence)
) and trailine_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
if (ans):
logger.info(f"Endpoint Rule: {rule_name} activated: {rule}")
return ans
def endpoint_detected(self,
ctc_log_probs: np.ndarray,
decoding_something: bool) -> bool:
"""detect endpoint.
Args:
ctc_log_probs (np.ndarray): (T, D)
decoding_something (bool): contain nonsilince.
Returns:
bool: whether endpoint detected.
"""
for logprob in ctc_log_probs:
blank_prob = np.exp(logprob[self.opts.blank])
self.num_frames_decoded += 1
if blank_prob > self.opts.blank_threshold:
self.trailing_silence_frames += 1
else:
self.trailing_silence_frames = 0
assert self.num_frames_decoded >= self.trailing_silence_frames
assert self.frame_shift_in_ms > 0
decoding_something = (self.num_frames_decoded > self.trailing_silence_frames) and decoding_something
utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
trailing_silence, utterance_length):
return True
if self.rule_activated(self.opts.rule2, 'rule2', decoding_something,
trailing_silence, utterance_length):
return True
if self.rule_activated(self.opts.rule3, 'rule3', decoding_something,
trailing_silence, utterance_length):
return True
return False