refactor asr online server

pull/2015/head
Hui Zhang 2 years ago
parent f3132ce2d2
commit 8f9b7bba48

@@ -51,12 +51,12 @@ repos:
         language: system
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
         exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
-      - id: copyright_checker
-        name: copyright_checker
-        entry: python .pre-commit-hooks/copyright-check.hook
-        language: system
-        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
-        exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
+      #- id: copyright_checker
+      #  name: copyright_checker
+      #  entry: python .pre-commit-hooks/copyright-check.hook
+      #  language: system
+      #  files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
+      #  exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
 - repo: https://github.com/asottile/reorder_python_imports
   rev: v2.4.0
   hooks:

@@ -6,3 +6,4 @@ paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log &
 # nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
 paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log &

@@ -10,3 +10,4 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa
 # If `127.0.0.1` is not accessible, you need to use the actual service IP address.
 # python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
 paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav

@@ -14,3 +14,7 @@
 import _locale
 _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import sys
+from typing import ByteString
 from typing import Optional

 import numpy as np
@@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
 from paddlespeech.s2t.utils.tensor_utils import pad_sequence
 from paddlespeech.s2t.utils.utility import UpdateConfig
+from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt
+from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint
 from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
 from paddlespeech.server.engine.base_engine import BaseEngine
+from paddlespeech.server.utils.audio_process import pcm2float
 from paddlespeech.server.utils.paddle_predictor import init_predictor

 __all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
@@ -54,24 +56,33 @@ class PaddleASRConnectionHanddler:
         self.model_config = asr_engine.executor.config
         self.asr_engine = asr_engine
+        self.init()
+        self.reset()

+    def init(self):
         # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
         self.model_type = self.asr_engine.executor.model_type
         self.sample_rate = self.asr_engine.executor.sample_rate
         # tokens to text
         self.text_feature = self.asr_engine.executor.text_feature

-        if "deepspeech2" in self.model_type:
-            self.am_predictor = self.asr_engine.executor.am_predictor
-
         # extract feat, new only fbank in conformer model
         self.preprocess_conf = self.model_config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
+        # frame window and frame shift, in samples unit
+        self.win_length = self.preprocess_conf.process[0]['win_length']
+        self.n_shift = self.preprocess_conf.process[0]['n_shift']
+        assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
+            self.sample_rate, self.preprocess_conf.process[0]['fs'])
+        self.frame_shift_in_ms = int(
+            self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)

+        self.init_decoder()
+        self.reset()

+    def init_decoder(self):
+        if "deepspeech2" in self.model_type:
+            self.am_predictor = self.asr_engine.executor.am_predictor
             self.decoder = CTCDecoder(
                 odim=self.model_config.output_dim,  # <blank> is in vocab
                 enc_n_units=self.model_config.rnn_layer_size * 2,
@@ -90,10 +101,6 @@ class PaddleASRConnectionHanddler:
                 cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                 cfg.num_proc_bsearch)
-            # frame window and frame shift, in samples unit
-            self.win_length = self.preprocess_conf.process[0]['win_length']
-            self.n_shift = self.preprocess_conf.process[0]['n_shift']
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             # acoustic model
             self.model = self.asr_engine.executor.model
@@ -102,68 +109,88 @@ class PaddleASRConnectionHanddler:
             self.ctc_decode_config = self.asr_engine.executor.config.decode
             self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
-            # extract feat, new only fbank in conformer model
-            self.preprocess_conf = self.model_config.preprocess_config
-            self.preprocess_args = {"train": False}
-            self.preprocessing = Transformation(self.preprocess_conf)
-            # frame window and frame shift, in samples unit
-            self.win_length = self.preprocess_conf.process[0]['win_length']
-            self.n_shift = self.preprocess_conf.process[0]['n_shift']
+            # ctc endpoint
+            self.endpoint_opt = OnlineCTCEndpoingOpt(
+                frame_shift_in_ms=self.frame_shift_in_ms, blank=0)
+            self.endpointer = OnlineCTCEndpoint(self.endpoint_opt)
         else:
             raise ValueError(f"Not supported: {self.model_type}")

-    def extract_feat(self, samples):
-        # we compute the elapsed time of first char occuring
-        # and we record the start time at the first pcm sample arraving
-        if "deepspeech2online" in self.model_type:
-            # self.reamined_wav stores all the samples,
-            # include the original remained_wav and this package samples
-            samples = np.frombuffer(samples, dtype=np.int16)
-            assert samples.ndim == 1
-
-            if self.remained_wav is None:
-                self.remained_wav = samples
-            else:
-                assert self.remained_wav.ndim == 1
-                self.remained_wav = np.concatenate([self.remained_wav, samples])
-            logger.info(
-                f"The connection remain the audio samples: {self.remained_wav.shape}"
-            )
-
-            # fbank
-            feat = self.preprocessing(self.remained_wav,
-                                      **self.preprocess_args)
-            feat = paddle.to_tensor(
-                feat, dtype="float32").unsqueeze(axis=0)
-
-            if self.cached_feat is None:
-                self.cached_feat = feat
-            else:
-                assert (len(feat.shape) == 3)
-                assert (len(self.cached_feat.shape) == 3)
-                self.cached_feat = paddle.concat(
-                    [self.cached_feat, feat], axis=1)
-
-            # set the feat device
-            if self.device is None:
-                self.device = self.cached_feat.place
-
-            # cur frame step
-            num_frames = feat.shape[1]
-
-            self.num_frames += num_frames
-            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
-
-            logger.info(
-                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
-            )
-            logger.info(
-                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
-            )
-        elif "conformer_online" in self.model_type:
+    def model_reset(self):
+        if "deepspeech2" in self.model_type:
+            return
+
+        # feature cache
+        self.cached_feat = None
+
+        ## conformer
+        # cache for conformer online
+        self.subsampling_cache = None
+        self.elayers_output_cache = None
+        self.conformer_cnn_cache = None
+        self.encoder_out = None
+        # conformer decoding state
+        self.offset = 0  # global offset in decoding frame unit
+
+        ## just for record info
+        self.chunk_num = 0  # global decoding chunk num, not used
+
+    def reset_continuous_decoding(self):
+        """
+        when in continous decoding, reset for next utterance.
+        """
+        self.global_frame_offset = self.num_frames
+        self.model_reset()
+        self.searcher.reset()
+        self.endpointer.reset()
+
+    def reset(self):
+        if "deepspeech2" in self.model_type:
+            # for deepspeech2
+            # init state
+            self.chunk_state_h_box = np.zeros(
+                (self.model_config.num_rnn_layers, 1,
+                 self.model_config.rnn_layer_size),
+                dtype=float32)
+            self.chunk_state_c_box = np.zeros(
+                (self.model_config.num_rnn_layers, 1,
+                 self.model_config.rnn_layer_size),
+                dtype=float32)
+            self.decoder.reset_decoder(batch_size=1)
+
+        if "conformer" in self.model_type or "transformer" in self.model_type:
+            self.searcher.reset()
+            self.endpointer.reset()
+
+        self.device = None
+
+        ## common
+        # global sample and frame step
+        self.num_samples = 0
+        self.global_frame_offset = 0
+        # frame step of cur utterance
+        self.num_frames = 0
+
+        # cache for audio and feat
+        self.remained_wav = None
+        self.cached_feat = None
+
+        ## conformer
+        self.model_reset()
+
+        ## outputs
+        # partial/ending decoding results
+        self.result_transcripts = ['']
+        # token timestamp result
+        self.word_time_stamp = []
+
+        ## just for record
+        self.hyps = []
+
+        # one best timestamp viterbi prob is large.
+        self.time_stamp = []
+
+    def extract_feat(self, samples: ByteString):
         logger.info("Online ASR extract the feat")
         samples = np.frombuffer(samples, dtype=np.int16)
         assert samples.ndim == 1
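The refactor above splits the per-connection state handling into init / init_decoder / reset / model_reset / reset_continuous_decoding. A rough, hypothetical driver sketch (not from this commit; `asr_engine` and the `pcm_chunks` iterable are assumed to exist) showing roughly how a websocket handler is expected to drive these methods:

    # Hypothetical per-connection driver, illustrating the intended call order.
    handler = PaddleASRConnectionHanddler(asr_engine)  # __init__ -> init() -> reset()

    for pcm_chunk in pcm_chunks:           # raw int16 PCM bytes from the client
        handler.extract_feat(pcm_chunk)    # buffer audio, append fbank frames to the feature cache
        handler.decode(is_finished=False)  # advance CTC prefix beam search over the cached frames
        # when the CTC endpointer fires, close this utterance but keep the connection:
        # handler.reset_continuous_decoding()

    handler.decode(is_finished=True)       # flush the remaining frames
    result = handler.result_transcripts[0]  # partial/final transcript kept by the handler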
@@ -189,10 +216,8 @@ class PaddleASRConnectionHanddler:
             return 0

         # fbank
-        x_chunk = self.preprocessing(self.remained_wav,
-                                     **self.preprocess_args)
-        x_chunk = paddle.to_tensor(
-            x_chunk, dtype="float32").unsqueeze(axis=0)
+        x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
+        x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)

         # feature cache
         if self.cached_feat is None:
@@ -224,55 +249,6 @@ class PaddleASRConnectionHanddler:
         )
         logger.info(f"global samples: {self.num_samples}")
         logger.info(f"global frames: {self.num_frames}")
-        else:
-            raise ValueError(f"not supported: {self.model_type}")
-
-    def reset(self):
-        if "deepspeech2" in self.model_type:
-            # for deepspeech2
-            # init state
-            self.chunk_state_h_box = np.zeros(
-                (self.model_config.num_rnn_layers, 1,
-                 self.model_config.rnn_layer_size),
-                dtype=float32)
-            self.chunk_state_c_box = np.zeros(
-                (self.model_config.num_rnn_layers, 1,
-                 self.model_config.rnn_layer_size),
-                dtype=float32)
-            self.decoder.reset_decoder(batch_size=1)
-
-        self.device = None
-
-        ## common
-        # global sample and frame step
-        self.num_samples = 0
-        self.num_frames = 0
-
-        # cache for audio and feat
-        self.remained_wav = None
-        self.cached_feat = None
-        # partial/ending decoding results
-        self.result_transcripts = ['']
-
-        ## conformer
-        # cache for conformer online
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
-        self.encoder_out = None
-        # conformer decoding state
-        self.chunk_num = 0  # globa decoding chunk num
-        self.offset = 0  # global offset in decoding frame unit
-        self.hyps = []
-
-        # token timestamp result
-        self.word_time_stamp = []
-
-        # one best timestamp viterbi prob is large.
-        self.time_stamp = []

     def decode(self, is_finished=False):
         """advance decoding
@@ -280,14 +256,12 @@ class PaddleASRConnectionHanddler:
         Args:
             is_finished (bool, optional): Is last frame or not. Defaults to False.

-        Raises:
-            Exception: when not support model.
-
         Returns:
-            None: nothing
+            None:
         """
-        if "deepspeech2online" in self.model_type:
+        if "deepspeech2" in self.model_type:
             decoding_chunk_size = 1  # decoding chunk size = 1. int decoding frame unit
             context = 7  # context=7, in audio frame unit
             subsampling = 4  # subsampling=4, in audio frame unit
@@ -332,9 +306,11 @@ class PaddleASRConnectionHanddler:
             end = None
             for cur in range(0, num_frames - left_frames + 1, stride):
                 end = min(cur + decoding_window, num_frames)

                 # extract the audio
                 x_chunk = self.cached_feat[:, cur:end, :].numpy()
                 x_chunk_lens = np.array([x_chunk.shape[1]])

                 trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
                 self.result_transcripts = [trans_best]
@@ -409,31 +385,38 @@ class PaddleASRConnectionHanddler:
     @paddle.no_grad()
     def advance_decoding(self, is_finished=False):
+        if "deepspeech" in self.model_type:
+            return
+
         logger.info(
             "Conformer/Transformer: start to decode with advanced_decoding method"
         )
         cfg = self.ctc_decode_config

-        # cur chunk size, in decoding frame unit
+        # cur chunk size, in decoding frame unit, e.g. 16
         decoding_chunk_size = cfg.decoding_chunk_size
-        # using num of history chunks
+        # using num of history chunks, e.g -1
         num_decoding_left_chunks = cfg.num_decoding_left_chunks
         assert decoding_chunk_size > 0

+        # e.g. 4
         subsampling = self.model.encoder.embed.subsampling_rate
+        # e.g. 7
         context = self.model.encoder.embed.right_context + 1

-        # processed chunk feature cached for next chunk
+        # processed chunk feature cached for next chunk, e.g. 3
         cached_feature_num = context - subsampling
+        # decoding stride, in audio frame unit
+        stride = subsampling * decoding_chunk_size
         # decoding window, in audio frame unit
         decoding_window = (decoding_chunk_size - 1) * subsampling + context
-        # decoding stride, in audio frame unit
-        stride = subsampling * decoding_chunk_size

         if self.cached_feat is None:
             logger.info("no audio feat, please input more pcm data")
             return

+        # (B=1,T,D)
         num_frames = self.cached_feat.shape[1]
         logger.info(
             f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
         )
@@ -454,9 +437,6 @@ class PaddleASRConnectionHanddler:
             return None, None

         logger.info("start to do model forward")
-        # hist of chunks, in deocding frame unit
-        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-        outputs = []

         # num_frames - context + 1 ensure that current frame can get context window
         if is_finished:
@@ -466,7 +446,11 @@
             # we only process decoding_window frames for one chunk
             left_frames = decoding_window

+        # hist of chunks, in decoding frame unit
+        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+
         # record the end for removing the processed feat
+        outputs = []
         end = None
         for cur in range(0, num_frames - left_frames + 1, stride):
             end = min(cur + decoding_window, num_frames)
@@ -491,30 +475,28 @@ class PaddleASRConnectionHanddler:
             self.encoder_out = ys
         else:
             self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
-        logger.info(
-            f"This connection handler encoder out shape: {self.encoder_out.shape}"
-        )

         # get the ctc probs
         ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)

+        ## decoding
         # advance decoding
         self.searcher.search(ctc_probs, self.cached_feat.place)
         # get one best hyps
         self.hyps = self.searcher.get_one_best_hyps()

-        assert self.cached_feat.shape[0] == 1
-        assert end >= cached_feature_num
-
         # advance cache of feat
-        self.cached_feat = self.cached_feat[0, end -
-                                            cached_feature_num:, :].unsqueeze(0)
+        assert self.cached_feat.shape[0] == 1  # (B=1,T,D)
+        assert end >= cached_feature_num
+        self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
         assert len(
             self.cached_feat.shape
         ) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
+
+        logger.info(
+            f"This connection handler encoder out shape: {self.encoder_out.shape}"
+        )

     def update_result(self):
         """Conformer/Transformer hyps to result.
         """
@@ -654,24 +636,28 @@ class PaddleASRConnectionHanddler:
         # update each word start and end time stamp
         # decoding frame to audio frame
-        frame_shift = self.model.encoder.embed.subsampling_rate
-        frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
-        logger.info(f"frame shift sec: {frame_shift_in_sec}")
+        decode_frame_shift = self.model.encoder.embed.subsampling_rate
+        decode_frame_shift_in_sec = decode_frame_shift * (self.n_shift /
+                                                          self.sample_rate)
+        logger.info(f"decode frame shift in sec: {decode_frame_shift_in_sec}")
+
+        global_offset_in_sec = self.global_frame_offset * self.frame_shift_in_ms / 1000.0
+        logger.info(f"global offset: {global_offset_in_sec} sec.")

         word_time_stamp = []
         for idx, _ in enumerate(self.time_stamp):
             start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
                      ) / 2.0 if idx > 0 else 0
-            start = start * frame_shift_in_sec
+            start = start * decode_frame_shift_in_sec

             end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
                    ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
-            end = end * frame_shift_in_sec
+            end = end * decode_frame_shift_in_sec
             word_time_stamp.append({
                 "w": self.result_transcripts[0][idx],
-                "bg": start,
-                "ed": end
+                "bg": global_offset_in_sec + start,
+                "ed": global_offset_in_sec + end
             })
             # logger.info(f"{word_time_stamp[-1]}")
@@ -705,13 +691,14 @@ class ASRServerExecutor(ASRExecutor):
         self.model_type = model_type
         self.sample_rate = sample_rate
+        logger.info(f"model_type: {self.model_type}")

         sample_rate_str = '16k' if sample_rate == 16000 else '8k'
         tag = model_type + '-' + lang + '-' + sample_rate_str
         self.task_resource.set_task_model(model_tag=tag)

         if cfg_path is None or am_model is None or am_params is None:
-            logger.info(f"Load the pretrained model, tag = {tag}")
             self.res_path = self.task_resource.res_dir
             self.cfg_path = os.path.join(
                 self.res_path, self.task_resource.res_dict['cfg_path'])
@@ -719,7 +706,6 @@ class ASRServerExecutor(ASRExecutor):
                 self.task_resource.res_dict['model'])
             self.am_params = os.path.join(self.res_path,
                                           self.task_resource.res_dict['params'])
-            logger.info(self.res_path)
         else:
             self.cfg_path = os.path.abspath(cfg_path)
             self.am_model = os.path.abspath(am_model)
@@ -727,9 +713,12 @@ class ASRServerExecutor(ASRExecutor):
             self.res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))

-        logger.info(self.cfg_path)
-        logger.info(self.am_model)
-        logger.info(self.am_params)
+        logger.info("Load the pretrained model:")
+        logger.info(f"  tag = {tag}")
+        logger.info(f"  res_path: {self.res_path}")
+        logger.info(f"  cfg path: {self.cfg_path}")
+        logger.info(f"  am_model path: {self.am_model}")
+        logger.info(f"  am_params path: {self.am_params}")

         #Init body.
         self.config = CfgNode(new_allowed=True)
@@ -738,13 +727,18 @@
         if self.config.spm_model_prefix:
             self.config.spm_model_prefix = os.path.join(
                 self.res_path, self.config.spm_model_prefix)
+            logger.info(f"spm model path: {self.config.spm_model_prefix}")
+
+        self.vocab = self.config.vocab_filepath

         self.text_feature = TextFeaturizer(
             unit_type=self.config.unit_type,
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
-        self.vocab = self.config.vocab_filepath

+        with UpdateConfig(self.config):
             if "deepspeech2" in model_type:
-            with UpdateConfig(self.config):
+                # download lm
                 self.config.decode.lang_model_path = os.path.join(
                     MODEL_HOME, 'language_model',
                     self.config.decode.lang_model_path)
@@ -756,7 +750,16 @@ class ASRServerExecutor(ASRExecutor):
                     lm_url,
                     os.path.dirname(self.config.decode.lang_model_path), lm_md5)

+                # AM predictor
+                logger.info("ASR engine start to init the am predictor")
+                self.am_predictor_conf = am_predictor_conf
+                self.am_predictor = init_predictor(
+                    model_file=self.am_model,
+                    params_file=self.am_params,
+                    predictor_conf=self.am_predictor_conf)
+
             elif "conformer" in model_type or "transformer" in model_type:
-            with UpdateConfig(self.config):
                 logger.info("start to create the stream conformer asr engine")
                 # update the decoding method
                 if decode_method:
@@ -770,37 +773,24 @@
                     logger.info(
                         "we set the decoding_method to attention_rescoring")
                     self.config.decode.decoding_method = "attention_rescoring"
                 assert self.config.decode.decoding_method in [
                     "ctc_prefix_beam_search", "attention_rescoring"
                 ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
-        else:
-            raise Exception("wrong type")
-
-        if "deepspeech2" in model_type:
-            # AM predictor
-            logger.info("ASR engine start to init the am predictor")
-            self.am_predictor_conf = am_predictor_conf
-            self.am_predictor = init_predictor(
-                model_file=self.am_model,
-                params_file=self.am_params,
-                predictor_conf=self.am_predictor_conf)
-        elif "conformer" in model_type or "transformer" in model_type:
+
+                # load model
                 model_name = model_type[:model_type.rindex(
                     '_')]  # model_type: {model_name}_{dataset}
                 logger.info(f"model name: {model_name}")
                 model_class = self.task_resource.get_model_class(model_name)
-            model_conf = self.config
-            model = model_class.from_config(model_conf)
+                model = model_class.from_config(self.config)
                 self.model = model
+                self.model.set_state_dict(paddle.load(self.am_model))
                 self.model.eval()
-
-            # load model
-            model_dict = paddle.load(self.am_model)
-            self.model.set_state_dict(model_dict)
-            logger.info("create the transformer like model success")
             else:
-                raise ValueError(f"Not support: {model_type}")
+                raise Exception(f"not support: {model_type}")
+
+        logger.info(f"create the {model_type} model success")

         return True

@@ -0,0 +1,108 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import exp
from typing import List

from paddlespeech.cli.log import logger


@dataclass
class OnlineCTCEndpointRule:
    must_contain_nonsilence: bool = True
    min_trailing_silence: int = 1000
    min_utterance_length: int = 0


@dataclass
class OnlineCTCEndpoingOpt:
    frame_shift_in_ms: int = 10
    blank: int = 0  # blank id, that we consider as silence for purposes of endpointing.
    blank_threshold: float = 0.8  # above blank threshold is silence

    # We support three rules. We terminate decoding if ANY of these rules
    # evaluates to "true". If you want to add more rules, do it by changing this
    # code. If you want to disable a rule, you can set the silence-timeout for
    # that rule to a very large number.

    # rule1 times out after 5 seconds of silence, even if we decoded nothing.
    rule1: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 5000, 0)
    # rule2 times out after 1.0 seconds of silence after decoding something,
    # even if we did not reach a final-state at all.
    rule2: OnlineCTCEndpointRule = OnlineCTCEndpointRule(True, 1000, 0)
    # rule3 times out after the utterance is 20 seconds long, regardless of
    # anything else.
    rule3: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 0, 20000)


class OnlineCTCEndpoint:
    """
    [END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)
    """

    def __init__(self, opts: OnlineCTCEndpoingOpt):
        self.opts = opts
        logger.info(f"Endpoint Opts: {opts}")
        self.frame_shift_in_ms = opts.frame_shift_in_ms

        self.num_frames_decoded = 0
        self.trailing_silence_frames = 0

        self.reset()

    def reset(self):
        self.num_frames_decoded = 0
        self.trailing_silence_frames = 0

    def rule_activated(self,
                       rule: OnlineCTCEndpointRule,
                       rule_name: str,
                       decoding_something: bool,
                       trailing_silence: int,
                       utterance_length: int) -> bool:
        ans = (
            decoding_something or (not rule.must_contain_nonsilence)
        ) and trailing_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
        if ans:
            logger.info(
                f"Endpoint Rule: {rule_name} activated: {decoding_something}, {trailing_silence}, {utterance_length}"
            )
        return ans

    def endpoint_detected(self,
                          ctc_log_probs: List[List[float]],
                          decoding_something: bool) -> bool:
        for logprob in ctc_log_probs:
            blank_prob = exp(logprob[self.opts.blank])
            self.num_frames_decoded += 1
            if blank_prob > self.opts.blank_threshold:
                self.trailing_silence_frames += 1
            else:
                self.trailing_silence_frames = 0

        assert self.num_frames_decoded >= self.trailing_silence_frames
        assert self.frame_shift_in_ms > 0

        utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
        trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms

        if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
                               trailing_silence, utterance_length):
            return True
        if self.rule_activated(self.opts.rule2, 'rule2', decoding_something,
                               trailing_silence, utterance_length):
            return True
        if self.rule_activated(self.opts.rule3, 'rule3', decoding_something,
                               trailing_silence, utterance_length):
            return True
        return False
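A rough usage sketch for the endpointer (not part of this commit): the engine feeds per-frame CTC log-probabilities, shaped (frames, vocab), plus a flag saying whether the searcher has already produced tokens. The 40 ms frame shift and the toy probabilities below are made-up values:

    from math import log

    # 40 ms per decoding frame is only an assumption for this sketch.
    opts = OnlineCTCEndpoingOpt(frame_shift_in_ms=40, blank=0)
    endpointer = OnlineCTCEndpoint(opts)

    # 30 frames dominated by the blank symbol (prob 0.92 > blank_threshold 0.8),
    # i.e. 1200 ms of trailing silence.
    ctc_log_probs = [[log(0.92)] + [log(0.02)] * 4 for _ in range(30)]
    decoding_something = True  # pretend the beam search already emitted tokens

    if endpointer.endpoint_detected(ctc_log_probs, decoding_something):
        # rule2 fires: >= 1000 ms of silence after decoding something
        endpointer.reset()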

@@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
             config (yacs.config.CfgNode): the ctc prefix beam search configuration
         """
         self.config = config
+
+        # beam size
+        self.first_beam_size = self.config.beam_size
+        # TODO(support second beam size)
+        self.second_beam_size = int(self.first_beam_size * 1.0)
+        logger.info(
+            f"first and second beam size: {self.first_beam_size}, {self.second_beam_size}"
+        )
+
+        # state
+        self.cur_hyps = None
+        self.hyps = None
+        self.abs_time_step = 0
+
         self.reset()

+    def reset(self):
+        """Reset the search cache value
+        """
+        self.cur_hyps = None
+        self.hyps = None
+        self.abs_time_step = 0
+
     @paddle.no_grad()
     def search(self, ctc_probs, device, blank_id=0):
         """ctc prefix beam search method decode a chunk feature
@@ -47,12 +68,17 @@
         """
         # decode
         logger.info("start to ctc prefix search")

+        assert len(ctc_probs.shape) == 2
         batch_size = 1
-        beam_size = self.config.beam_size
-        maxlen = ctc_probs.shape[0]
-        assert len(ctc_probs.shape) == 2
+
+        vocab_size = ctc_probs.shape[1]
+        first_beam_size = min(self.first_beam_size, vocab_size)
+        second_beam_size = min(self.second_beam_size, vocab_size)
+        logger.info(
+            f"effect first and second beam size: {self.first_beam_size}, {self.second_beam_size}"
+        )
+
+        maxlen = ctc_probs.shape[0]

         # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
         # 0. blank_ending_score,
@@ -75,7 +101,8 @@
                 # 2.1 First beam prune: select topk best
                 # do token passing process
-                top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
+                top_k_logp, top_k_index = logp.topk(
+                    first_beam_size)  # (first_beam_size,)
                 for s in top_k_index:
                     s = s.item()
                     ps = logp[s].item()
@@ -148,7 +175,7 @@
                 next_hyps.items(),
                 key=lambda x: log_add([x[1][0], x[1][1]]),
                 reverse=True)
-            self.cur_hyps = next_hyps[:beam_size]
+            self.cur_hyps = next_hyps[:second_beam_size]

             # 2.3 update the absolute time step
             self.abs_time_step += 1
@@ -163,7 +190,7 @@
         """Return the one best result

         Returns:
-            list: the one best result
+            list: the one best result, List[str]
         """
         return [self.hyps[0][0]]
@@ -171,17 +198,10 @@
         """Return the search hyps

         Returns:
-            list: return the search hyps
+            list: return the search hyps, List[Tuple[str, float, ...]]
         """
         return self.hyps

-    def reset(self):
-        """Rest the search cache value
-        """
-        self.cur_hyps = None
-        self.hyps = None
-        self.abs_time_step = 0
-
     def finalize_search(self):
         """do nothing in ctc_prefix_beam_search
         """
