refactor asr online server

pull/2015/head
Hui Zhang 2 years ago
parent f3132ce2d2
commit 8f9b7bba48

@ -51,12 +51,12 @@ repos:
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
#- id: copyright_checker
# name: copyright_checker
# entry: python .pre-commit-hooks/copyright-check.hook
# language: system
# files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
# exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:

@ -5,4 +5,5 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3
paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log &

@ -9,4 +9,5 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa
# read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav

@ -14,3 +14,7 @@
import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
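# the patch above pins the default locale to ('en_US', 'utf8'), so text decoding
# does not depend on the host system locale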

@ -13,6 +13,7 @@
# limitations under the License.
import os
import sys
from typing import ByteString
from typing import Optional
import numpy as np
@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt
from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
@ -54,24 +56,33 @@ class PaddleASRConnectionHanddler:
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
self.init()
self.reset()
def init(self):
# model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
self.model_type = self.asr_engine.executor.model_type
self.sample_rate = self.asr_engine.executor.sample_rate
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
# extract feat; currently only fbank is used by the conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
self.sample_rate, self.preprocess_conf.process[0]['fs'])
self.frame_shift_in_ms = int(
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
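# worked example, assuming the typical 16 kHz fbank config (win_length=400, n_shift=160):
# frame_shift_in_ms = int(160 / 16000 * 1000) = 10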
self.init_decoder()
self.reset()
def init_decoder(self):
if "deepspeech2" in self.model_type:
self.am_predictor = self.asr_engine.executor.am_predictor
# extract feat, new only fbank in conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
enc_n_units=self.model_config.rnn_layer_size * 2,
@ -90,10 +101,6 @@ class PaddleASRConnectionHanddler:
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
elif "conformer" in self.model_type or "transformer" in self.model_type:
# acoustic model
self.model = self.asr_engine.executor.model
@ -102,130 +109,40 @@ class PaddleASRConnectionHanddler:
self.ctc_decode_config = self.asr_engine.executor.config.decode
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
# extract feat, new only fbank in conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
# ctc endpoint
self.endpoint_opt = OnlineCTCEndpoingOpt(
frame_shift_in_ms=self.frame_shift_in_ms, blank=0)
self.endpointer = OnlineCTCEndpoint(self.endpoint_opt)
else:
raise ValueError(f"Not supported: {self.model_type}")
def extract_feat(self, samples):
# we compute the elapsed time of the first char occurring
# and we record the start time at the first pcm sample arriving
if "deepspeech2online" in self.model_type:
# self.remained_wav stores all the samples,
# including the original remained_wav and this package's samples
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}"
)
# fbank
feat = self.preprocessing(self.remained_wav,
**self.preprocess_args)
feat = paddle.to_tensor(
feat, dtype="float32").unsqueeze(axis=0)
if self.cached_feat is None:
self.cached_feat = feat
else:
assert (len(feat.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat(
[self.cached_feat, feat], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = feat.shape[1]
self.num_frames += num_frames
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
elif "conformer_online" in self.model_type:
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.remained_wav stores all the samples,
# including the original remained_wav and this package's samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav,
**self.preprocess_args)
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
def model_reset(self):
if "deepspeech2" in self.model_type:
return
# cur frame step
num_frames = x_chunk.shape[1]
# feature cache
self.cached_feat = None
# global frame step
self.num_frames += num_frames
## conformer
# cache for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
# conformer decoding state
self.offset = 0 # global offset in decoding frame unit
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
## just for record info
self.chunk_num = 0 # global decoding chunk num, not used
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
else:
raise ValueError(f"not supported: {self.model_type}")
def reset_continuous_decoding(self):
"""
when in continuous decoding, reset for the next utterance.
"""
self.global_frame_offset = self.num_frames
self.model_reset()
self.searcher.reset()
self.endpointer.reset()
def reset(self):
if "deepspeech2" in self.model_type:
@ -241,53 +158,110 @@ class PaddleASRConnectionHanddler:
dtype=float32)
self.decoder.reset_decoder(batch_size=1)
if "conformer" in self.model_type or "transformer" in self.model_type:
self.searcher.reset()
self.endpointer.reset()
self.device = None
## common
# global sample and frame step
self.num_samples = 0
self.global_frame_offset = 0
# frame step of cur utterance
self.num_frames = 0
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
# partial/ending decoding results
self.result_transcripts = ['']
## conformer
self.model_reset()
# cache for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
# conformer decoding state
self.chunk_num = 0 # global decoding chunk num
self.offset = 0 # global offset in decoding frame unit
self.hyps = []
## outputs
# partial/ending decoding results
self.result_transcripts = ['']
# token timestamp result
self.word_time_stamp = []
## just for record
self.hyps = []
# one-best timestamps (viterbi path, whose prob is the largest).
self.time_stamp = []
def extract_feat(self, samples: ByteString):
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.remained_wav stores all the samples,
# including the original remained_wav and this package's samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = x_chunk.shape[1]
# global frame step
self.num_frames += num_frames
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
def decode(self, is_finished=False):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when the model type is not supported.
Returns:
None: nothing
None:
"""
if "deepspeech2online" in self.model_type:
if "deepspeech2" in self.model_type:
decoding_chunk_size = 1 # decoding chunk size = 1, in decoding frame unit
context = 7 # context=7, in audio frame unit
subsampling = 4 # subsampling=4, in audio frame unit
@ -332,9 +306,11 @@ class PaddleASRConnectionHanddler:
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
x_chunk = self.cached_feat[:, cur:end, :].numpy()
x_chunk_lens = np.array([x_chunk.shape[1]])
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
self.result_transcripts = [trans_best]
@ -409,31 +385,38 @@ class PaddleASRConnectionHanddler:
@paddle.no_grad()
def advance_decoding(self, is_finished=False):
if "deepspeech" in self.model_type:
return
logger.info(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg = self.ctc_decode_config
# cur chunk size, in decoding frame unit
# cur chunk size, in decoding frame unit, e.g. 16
decoding_chunk_size = cfg.decoding_chunk_size
# using num of history chunks
# using num of history chunks, e.g. -1
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
# e.g. 4
subsampling = self.model.encoder.embed.subsampling_rate
# e.g. 7
context = self.model.encoder.embed.right_context + 1
# processed chunk feature cached for next chunk
# processed chunk feature cached for next chunk, e.g. 3
cached_feature_num = context - subsampling
# decoding stride, in audio frame unit
stride = subsampling * decoding_chunk_size
# decoding window, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
# decoding stride, in audio frame unit
stride = subsampling * decoding_chunk_size
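# worked example for a typical streaming conformer (subsampling=4, context=7, decoding_chunk_size=16):
#   cached_feature_num = 7 - 4 = 3 audio frames carried over to the next chunk
#   decoding_window    = (16 - 1) * 4 + 7 = 67 audio frames fed to one forward pass
#   stride             = 4 * 16 = 64 audio frames between successive chunk starts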
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
# (B=1,T,D)
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
@ -454,9 +437,6 @@ class PaddleASRConnectionHanddler:
return None, None
logger.info("start to do model forward")
# hist of chunks, in decoding frame unit
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
@ -466,7 +446,11 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
# hist of chunks, in decoding frame unit
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
# record the end for removing the processed feat
outputs = []
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
@ -491,30 +475,28 @@ class PaddleASRConnectionHanddler:
self.encoder_out = ys
else:
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
logger.info(
f"This connection handler encoder out shape: {self.encoder_out.shape}"
)
# get the ctc probs
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
## decoding
# advance decoding
self.searcher.search(ctc_probs, self.cached_feat.place)
# get one best hyps
self.hyps = self.searcher.get_one_best_hyps()
assert self.cached_feat.shape[0] == 1
assert end >= cached_feature_num
# advance cache of feat
self.cached_feat = self.cached_feat[0, end -
cached_feature_num:, :].unsqueeze(0)
assert self.cached_feat.shape[0] == 1 #(B=1,T,D)
assert end >= cached_feature_num
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
assert len(
self.cached_feat.shape
) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
logger.info(
f"This connection handler encoder out shape: {self.encoder_out.shape}"
)
def update_result(self):
"""Conformer/Transformer hyps to result.
"""
@ -654,24 +636,28 @@ class PaddleASRConnectionHanddler:
# update each word start and end time stamp
# decoding frame to audio frame
frame_shift = self.model.encoder.embed.subsampling_rate
frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
logger.info(f"frame shift sec: {frame_shift_in_sec}")
decode_frame_shift = self.model.encoder.embed.subsampling_rate
decode_frame_shift_in_sec = decode_frame_shift * (self.n_shift /
self.sample_rate)
logger.info(f"decode frame shift in sec: {decode_frame_shift_in_sec}")
global_offset_in_sec = self.global_frame_offset * self.frame_shift_in_ms / 1000.0
logger.info(f"global offset: {global_offset_in_sec} sec.")
word_time_stamp = []
for idx, _ in enumerate(self.time_stamp):
start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
) / 2.0 if idx > 0 else 0
start = start * frame_shift_in_sec
start = start * decode_frame_shift_in_sec
end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
end = end * frame_shift_in_sec
end = end * decode_frame_shift_in_sec
word_time_stamp.append({
"w": self.result_transcripts[0][idx],
"bg": start,
"ed": end
"bg": global_offset_in_sec + start,
"ed": global_offset_in_sec + end
})
# logger.info(f"{word_time_stamp[-1]}")
@ -705,13 +691,14 @@ class ASRServerExecutor(ASRExecutor):
self.model_type = model_type
self.sample_rate = sample_rate
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
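# e.g. with model_type='conformer_online_wenetspeech', lang='zh' and 16 kHz audio (illustrative values),
# the tag would be 'conformer_online_wenetspeech-zh-16k'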
if cfg_path is None or am_model is None or am_params is None:
logger.info(f"Load the pretrained model, tag = {tag}")
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
self.res_path, self.task_resource.res_dict['cfg_path'])
@ -719,7 +706,6 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource.res_dict['model'])
self.am_params = os.path.join(self.res_path,
self.task_resource.res_dict['params'])
logger.info(self.res_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
@ -727,9 +713,12 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}")
logger.info(f" am_params path: {self.am_params}")
#Init body.
self.config = CfgNode(new_allowed=True)
@ -738,25 +727,39 @@ class ASRServerExecutor(ASRExecutor):
if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.vocab = self.config.vocab_filepath
with UpdateConfig(self.config):
if "deepspeech2" in model_type:
if "deepspeech2" in model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in model_type or "transformer" in model_type:
elif "conformer" in model_type or "transformer" in model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if decode_method:
@ -770,37 +773,24 @@ class ASRServerExecutor(ASRExecutor):
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else:
raise Exception("wrong type")
if "deepspeech2" in model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in model_type or "transformer" in model_type:
# load model
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config
model = model_class.from_config(model_conf)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
# load model
model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success")
else:
raise ValueError(f"Not support: {model_type}")
raise Exception(f"not support: {model_type}")
logger.info(f"create the {model_type} model success")
return True

@ -0,0 +1,108 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import List
from paddlespeech.cli.log import logger
@dataclass
class OnlineCTCEndpointRule:
must_contain_nonsilence: bool = True
min_trailing_silence: int = 1000
min_utterance_length: int = 0
@dataclass
class OnlineCTCEndpoingOpt:
frame_shift_in_ms: int = 10
blank: int = 0 # blank id, that we consider as silence for purposes of endpointing.
blank_threshold: float = 0.8 # above blank threshold is silence
# We support three rules. We terminate decoding if ANY of these rules
# evaluates to "true". If you want to add more rules, do it by changing this
# code. If you want to disable a rule, you can set the silence-timeout for
# that rule to a very large number.
# rule1 times out after 5 seconds of silence, even if we decoded nothing.
rule1: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 5000, 0)
# rule2 times out after 1.0 second of silence after decoding something,
# even if we did not reach a final-state at all.
rule2: OnlineCTCEndpointRule = OnlineCTCEndpointRule(True, 1000, 0)
# rule3 times out after the utterance is 20 seconds long, regardless of
# anything else.
rule3: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 0, 20000)
class OnlineCTCEndpoint:
"""
[END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)
"""
def __init__(self, opts: OnlineCTCEndpoingOpt):
self.opts = opts
logger.info(f"Endpont Opts: {opts}")
self.frame_shift_in_ms = opts.frame_shift_in_ms
self.num_frames_decoded = 0
self.trailing_silence_frames = 0
self.reset()
def reset(self):
self.num_frames_decoded = 0
self.trailing_silence_frames = 0
def rule_activated(self,
rule: OnlineCTCEndpointRule,
rule_name: str,
decoding_something: bool,
trailine_silence: int,
utterance_length: int) -> bool:
ans = (
decoding_something or (not rule.must_contain_nonsilence)
) and trailine_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
if (ans):
logger.info(
f"Endpoint Rule: {rule_name} activated: {decoding_something}, {trailine_silence}, {utterance_length}"
)
return ans
def endpoint_detected(self,
ctc_log_probs: List[List[float]],
decoding_something: bool) -> bool:
for logprob in ctc_log_probs:
blank_prob = math.exp(logprob[self.opts.blank])
self.num_frames_decoded += 1
if blank_prob > self.opts.blank_threshold:
self.trailing_silence_frames += 1
else:
self.trailing_silence_frames = 0
assert self.num_frames_decoded >= self.trailing_silence_frames
assert self.frame_shift_in_ms > 0
utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
trailing_silence, utterance_length):
return True
if self.rule_activated(self.opts.rule2, 'rule2', decoding_something,
trailing_silence, utterance_length):
return True
if self.rule_activated(self.opts.rule3, 'rule3', decoding_something,
trailing_silence, utterance_length):
return True
return False
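A rough usage sketch of the endpointer above (the numbers are illustrative: a 10 ms decoding frame shift and 150 blank-dominant frames, i.e. about 1.5 s of trailing silence, enough to trigger rule2):

from math import log

from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt
from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint

opts = OnlineCTCEndpoingOpt(frame_shift_in_ms=10, blank=0)
endpointer = OnlineCTCEndpoint(opts)

# fake CTC log-probs over a 2-token vocab where the blank id (0) dominates every frame
silence_chunk = [[log(0.99), log(0.01)]] * 150
if endpointer.endpoint_detected(silence_chunk, decoding_something=True):
    print("endpoint detected: trailing silence after decoded speech (rule2)")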

@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
config (yacs.config.CfgNode): the ctc prefix beam search configuration
"""
self.config = config
# beam size
self.first_beam_size = self.config.beam_size
# TODO(support second beam size)
self.second_beam_size = int(self.first_beam_size * 1.0)
logger.info(
f"first and second beam size: {self.first_beam_size}, {self.second_beam_size}"
)
# state
self.cur_hyps = None
self.hyps = None
self.abs_time_step = 0
self.reset()
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
self.abs_time_step = 0
@paddle.no_grad()
def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature
@ -47,12 +68,17 @@ class CTCPrefixBeamSearch:
"""
# decode
logger.info("start to ctc prefix search")
assert len(ctc_probs.shape) == 2
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
assert len(ctc_probs.shape) == 2
vocab_size = ctc_probs.shape[1]
first_beam_size = min(self.first_beam_size, vocab_size)
second_beam_size = min(self.second_beam_size, vocab_size)
logger.info(
f"effective first and second beam size: {first_beam_size}, {second_beam_size}"
)
maxlen = ctc_probs.shape[0]
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# 0. blank_ending_score,
@ -75,7 +101,8 @@ class CTCPrefixBeamSearch:
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
top_k_logp, top_k_index = logp.topk(
first_beam_size) # (first_beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
@ -148,7 +175,7 @@ class CTCPrefixBeamSearch:
next_hyps.items(),
key=lambda x: log_add([x[1][0], x[1][1]]),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.cur_hyps = next_hyps[:second_beam_size]
# 2.3 update the absolute time step
self.abs_time_step += 1
@ -163,7 +190,7 @@ class CTCPrefixBeamSearch:
"""Return the one best result
Returns:
list: the one best result
list: the one best result, List[str]
"""
return [self.hyps[0][0]]
@ -171,17 +198,10 @@ class CTCPrefixBeamSearch:
"""Return the search hyps
Returns:
list: return the search hyps
list: return the search hyps, List[Tuple[str, float, ...]]
"""
return self.hyps
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
self.abs_time_step = 0
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""
