add ds2 model multi session, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent 5acb0b5252
commit 380afbbc5d

@@ -18,44 +18,10 @@ engine_list: ['asr_online']
 #                                ENGINE CONFIG                                  #
 #################################################################################
-# ################################### ASR #########################################
-# ################### speech task: asr; engine_type: online #######################
-# asr_online:
-#     model_type: 'deepspeech2online_aishell'
-#     am_model: # the pdmodel file of am static model [optional]
-#     am_params: # the pdiparams file of am static model [optional]
-#     lang: 'zh'
-#     sample_rate: 16000
-#     cfg_path:
-#     decode_method:
-#     force_yes: True
-#     am_predictor_conf:
-#         device: # set 'gpu:id' or 'cpu'
-#         switch_ir_optim: True
-#         glog_info: False # True -> print glog
-#         summary: True # False -> do not show predictor config
-#     chunk_buffer_conf:
-#         frame_duration_ms: 80
-#         shift_ms: 40
-#         sample_rate: 16000
-#         sample_width: 2
-#     vad_conf:
-#         aggressiveness: 2
-#         sample_rate: 16000
-#         frame_duration_ms: 20
-#         sample_width: 2
-#         padding_ms: 200
-#         padding_ratio: 0.9
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'conformer2online_aishell'
+    model_type: 'deepspeech2online_aishell'
     am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
@@ -71,9 +37,19 @@ asr_online:
         summary: True # False -> do not show predictor config

     chunk_buffer_conf:
+        frame_duration_ms: 80
+        shift_ms: 40
+        sample_rate: 16000
+        sample_width: 2
         window_n: 7 # frame
         shift_n: 4 # frame
-        window_ms: 25 # ms
+        window_ms: 20 # ms
         shift_ms: 10 # ms
+
+    vad_conf:
+        aggressiveness: 2
         sample_rate: 16000
-        sample_width: 2
+        frame_duration_ms: 20
+        sample_width: 2
+        padding_ms: 200
+        padding_ratio: 0.9
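Note: a quick back-of-the-envelope check (not part of the diff) of what the chunk_buffer_conf values above mean in raw PCM terms, assuming 16 kHz, 16-bit mono audio as configured:

    sample_rate = 16000        # Hz, from sample_rate
    sample_width = 2           # bytes per sample, from sample_width
    frame_duration_ms = 80     # from frame_duration_ms
    shift_ms = 40              # from shift_ms

    frame_bytes = sample_rate * sample_width * frame_duration_ms // 1000   # 2560 bytes per buffered frame
    shift_bytes = sample_rate * sample_width * shift_ms // 1000            # 1280 bytes per buffer shift

So each buffered chunk holds 80 ms (2560 bytes) of audio and the buffer advances by 40 ms (1280 bytes) at a time.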

@@ -0,0 +1,45 @@
+# This is the parameter configuration file for PaddleSpeech Serving.
+
+#################################################################################
+#                                SERVER SETTING                                 #
+#################################################################################
+host: 0.0.0.0
+port: 8090
+
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['asr_online', 'tts_online']
+# protocol = ['websocket', 'http'] (only one can be selected).
+# websocket only supports online engine types.
+protocol: 'websocket'
+engine_list: ['asr_online']
+
+#################################################################################
+#                                ENGINE CONFIG                                  #
+#################################################################################
+
+################################### ASR #########################################
+################### speech task: asr; engine_type: online #######################
+asr_online:
+    model_type: 'conformer2online_aishell'
+    am_model: # the pdmodel file of am static model [optional]
+    am_params: # the pdiparams file of am static model [optional]
+    lang: 'zh'
+    sample_rate: 16000
+    cfg_path:
+    decode_method:
+    force_yes: True
+
+    am_predictor_conf:
+        device: # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False # True -> print glog
+        summary: True # False -> do not show predictor config
+
+    chunk_buffer_conf:
+        window_n: 7 # frame
+        shift_n: 4 # frame
+        window_ms: 25 # ms
+        shift_ms: 10 # ms
+        sample_rate: 16000
+        sample_width: 2
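Note: a minimal sketch (not part of the diff) of loading this new config and checking the constraints stated in its own comments; the file name and the PyYAML usage are illustrative, not mandated by the commit:

    import yaml

    # hypothetical local copy of the new conformer2online config file
    with open("ws_conformer_application.yaml") as f:
        conf = yaml.safe_load(f)

    assert conf["protocol"] in ("websocket", "http")   # only one protocol at a time
    if conf["protocol"] == "websocket":
        # websocket only supports online engine types
        assert all(name.endswith("_online") for name in conf["engine_list"])

    print(conf["asr_online"]["model_type"])            # conformer2online_aishell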

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from typing import Optional
-
+import copy
 import numpy as np
 import paddle
 from numpy import float32
@@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler:
         )
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
-        self.model = asr_engine.executor.model
+        # self.model = asr_engine.executor.model
         self.asr_engine = asr_engine

         self.init()
@@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler:
     def init(self):
         self.model_type = self.asr_engine.executor.model_type
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
-            pass
+            from paddlespeech.s2t.io.collator import SpeechCollator
+            self.sample_rate = self.asr_engine.executor.sample_rate
+            self.am_predictor = self.asr_engine.executor.am_predictor
+
+            self.text_feature = self.asr_engine.executor.text_feature
+            self.collate_fn_test = SpeechCollator.from_config(self.model_config)
+            self.decoder = CTCDecoder(
+                odim=self.model_config.output_dim,  # <blank> is in vocab
+                enc_n_units=self.model_config.rnn_layer_size * 2,
+                blank_id=self.model_config.blank_id,
+                dropout_rate=0.0,
+                reduction=True,  # sum
+                batch_average=True,  # sum / batch_size
+                grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
+
+            cfg = self.model_config.decode
+            decode_batch_size = 1  # for online
+            self.decoder.init_decoder(
+                decode_batch_size, self.text_feature.vocab_list,
+                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
+                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
+                cfg.num_proc_bsearch)
+
+            # frame window samples length and frame shift samples length
+            self.win_length = int(self.model_config.window_ms / 1000 * self.sample_rate)
+            self.n_shift = int(self.model_config.stride_ms / 1000 * self.sample_rate)
+
         elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
             self.sample_rate = self.asr_engine.executor.sample_rate
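Note: a worked example of the window/shift computation added above, assuming the deepspeech2online_aishell feature config uses a 20 ms window and a 10 ms stride (these concrete values are illustrative):

    sample_rate = 16000   # Hz
    window_ms = 20        # model_config.window_ms
    stride_ms = 10        # model_config.stride_ms

    win_length = int(window_ms / 1000 * sample_rate)   # 320 samples per frame window
    n_shift = int(stride_ms / 1000 * sample_rate)      # 160 samples per frame shift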
@@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler:

     def extract_feat(self, samples):
         if "deepspeech2online" in self.model_type:
-            pass
+            # self.remained_wav stores all the samples, including the original
+            # remained_wav and the samples from this packet
+            samples = np.frombuffer(samples, dtype=np.int16)
+            assert samples.ndim == 1
+
+            if self.remained_wav is None:
+                self.remained_wav = samples
+            else:
+                assert self.remained_wav.ndim == 1
+                self.remained_wav = np.concatenate([self.remained_wav, samples])
+            logger.info(
+                f"The connection has remaining audio samples: {self.remained_wav.shape}")
+
+            # pcm16 -> pcm32
+            samples = pcm2float(self.remained_wav)
+
+            # read audio
+            speech_segment = SpeechSegment.from_pcm(
+                samples, self.sample_rate, transcript=" ")
+            # audio augment
+            self.collate_fn_test.augmentation.transform_audio(speech_segment)
+
+            # extract speech feature
+            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
+                speech_segment, self.collate_fn_test.keep_transcription_text)
+            # CMVN spectrum
+            if self.collate_fn_test._normalizer:
+                spectrum = self.collate_fn_test._normalizer.apply(spectrum)
+
+            # spectrum augment
+            audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
+
+            audio_len = audio.shape[0]
+            audio = paddle.to_tensor(audio, dtype='float32')
+            # audio_len = paddle.to_tensor(audio_len)
+            audio = paddle.unsqueeze(audio, axis=0)
+
+            if self.cached_feat is None:
+                self.cached_feat = audio
+            else:
+                assert (len(audio.shape) == 3)
+                assert (len(self.cached_feat.shape) == 3)
+                self.cached_feat = paddle.concat(
+                    [self.cached_feat, audio], axis=1)
+
+            # set the feat device
+            if self.device is None:
+                self.device = self.cached_feat.place
+
+            self.num_frames += audio_len
+            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
+
+            logger.info(
+                f"audio feature extracted successfully, the connection feat shape: {self.cached_feat.shape}")
+            logger.info(
+                f"After feature extraction, the connection has remaining audio samples: {self.remained_wav.shape}")
+
         elif "conformer2online" in self.model_type:
             logger.info("Online ASR extract the feat")
             samples = np.frombuffer(samples, dtype=np.int16)
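Note: pcm2float above turns the buffered int16 PCM into float32 before feature extraction. A minimal sketch of what such a conversion typically does; the helper actually shipped with PaddleSpeech may differ in detail:

    import numpy as np

    def pcm2float_sketch(sig: np.ndarray) -> np.ndarray:
        """Map int16 PCM samples onto float32 in [-1.0, 1.0)."""
        abs_max = 2 ** (np.iinfo(np.int16).bits - 1)   # 32768 for int16
        return sig.astype('float32') / abs_max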
@@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler:
             # logger.info(f"accumulate samples: {self.num_samples}")

     def reset(self):
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
-        self.encoder_out = None
-        self.cached_feat = None
-        self.remained_wav = None
-        self.offset = 0
-        self.num_samples = 0
-        self.device = None
-        self.hyps = []
-        self.num_frames = 0
-        self.chunk_num = 0
-        self.global_frame_offset = 0
-        self.result_transcripts = ['']
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            # for deepspeech2
+            self.chunk_state_h_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_h_box)
+            self.chunk_state_c_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_c_box)
+            self.decoder.reset_decoder(batch_size=1)
+        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
+            # for conformer online
+            self.subsampling_cache = None
+            self.elayers_output_cache = None
+            self.conformer_cnn_cache = None
+            self.encoder_out = None
+
+        self.cached_feat = None
+        self.remained_wav = None
+        self.offset = 0
+        self.num_samples = 0
+        self.device = None
+        self.hyps = []
+        self.num_frames = 0
+        self.chunk_num = 0
+        self.global_frame_offset = 0
+        self.result_transcripts = ['']

     def decode(self, is_finished=False):
         if "deepspeech2online" in self.model_type:
-            pass
+            # x_chunk holds the feature data
+            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
+            context = 7  # context=7 in deepspeech2 model
+            subsampling = 4  # subsampling=4 in deepspeech2 model
+            stride = subsampling * decoding_chunk_size
+            cached_feature_num = context - subsampling
+            # decoding window for model
+            decoding_window = (decoding_chunk_size - 1) * subsampling + context
+
+            if self.cached_feat is None:
+                logger.info("no audio feat, please input more pcm data")
+                return
+
+            num_frames = self.cached_feat.shape[1]
+            logger.info(
+                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
+            )
+
+            # the cached feat must be at least decoding_window frames
+            if num_frames < decoding_window and not is_finished:
+                logger.info(
+                    f"frame feat num is less than {decoding_window}, please input more pcm data"
+                )
+                return None, None
+
+            # if is_finished=True, we need at least context frames
+            if num_frames < context:
+                logger.info(
+                    f"the last {num_frames} frames are less than context {context} frames, so we cannot do model forward"
+                )
+                return None, None
+
+            logger.info("start to do model forward")
+            # num_frames - context + 1 ensures the current frame can get a full context window
+            if is_finished:
+                # on the finished chunk, we still need to process the last context frames
+                left_frames = context
+            else:
+                # we only process decoding_window frames for one chunk
+                left_frames = decoding_window
+
+            for cur in range(0, num_frames - left_frames + 1, stride):
+                end = min(cur + decoding_window, num_frames)
+                # extract the audio
+                x_chunk = self.cached_feat[:, cur:end, :].numpy()
+                x_chunk_lens = np.array([x_chunk.shape[1]])
+                trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
+
+                self.result_transcripts = [trans_best]
+
+            self.cached_feat = self.cached_feat[:, end -
+                                                cached_feature_num:, :]
+            # return trans_best[0]
+
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             try:
                 logger.info(
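Note: a worked example of the sliding-window bookkeeping in the deepspeech2online branch of decode() above. With decoding_chunk_size=1, context=7 and subsampling=4, each forward pass needs 7 feature frames, advances by 4, and keeps the last 3 frames for overlap:

    decoding_chunk_size = 1
    context = 7
    subsampling = 4

    stride = subsampling * decoding_chunk_size                            # 4
    cached_feature_num = context - subsampling                            # 3
    decoding_window = (decoding_chunk_size - 1) * subsampling + context   # 7

    num_frames = 16   # pretend the connection has buffered 16 feature frames
    for cur in range(0, num_frames - decoding_window + 1, stride):
        end = min(cur + decoding_window, num_frames)
        print(f"forward frames [{cur}, {end})")   # [0, 7), [4, 11), [8, 15)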
@@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler:
         else:
             raise Exception("invalid model name")

+    def decode_one_chunk(self, x_chunk, x_chunk_lens):
+        logger.info("start to decode one chunk with deepspeech2 model")
+        input_names = self.am_predictor.get_input_names()
+        audio_handle = self.am_predictor.get_input_handle(input_names[0])
+        audio_len_handle = self.am_predictor.get_input_handle(
+            input_names[1])
+        h_box_handle = self.am_predictor.get_input_handle(input_names[2])
+        c_box_handle = self.am_predictor.get_input_handle(input_names[3])
+
+        audio_handle.reshape(x_chunk.shape)
+        audio_handle.copy_from_cpu(x_chunk)
+
+        audio_len_handle.reshape(x_chunk_lens.shape)
+        audio_len_handle.copy_from_cpu(x_chunk_lens)
+
+        h_box_handle.reshape(self.chunk_state_h_box.shape)
+        h_box_handle.copy_from_cpu(self.chunk_state_h_box)
+
+        c_box_handle.reshape(self.chunk_state_c_box.shape)
+        c_box_handle.copy_from_cpu(self.chunk_state_c_box)
+
+        output_names = self.am_predictor.get_output_names()
+        output_handle = self.am_predictor.get_output_handle(output_names[0])
+        output_lens_handle = self.am_predictor.get_output_handle(
+            output_names[1])
+        output_state_h_handle = self.am_predictor.get_output_handle(
+            output_names[2])
+        output_state_c_handle = self.am_predictor.get_output_handle(
+            output_names[3])
+
+        self.am_predictor.run()
+
+        output_chunk_probs = output_handle.copy_to_cpu()
+        output_chunk_lens = output_lens_handle.copy_to_cpu()
+        self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
+        self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
+
+        self.decoder.next(output_chunk_probs, output_chunk_lens)
+        trans_best, trans_beam = self.decoder.decode()
+        logger.info(f"decode one best result: {trans_best[0]}")
+        return trans_best[0]
+
     def advance_decoding(self, is_finished=False):
         logger.info("start to decode with advanced_decoding method")
         cfg = self.ctc_decode_config
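Note: a hypothetical driver loop (not part of the diff) showing how a websocket handler might push PCM packets through the connection-handler methods added above; `handler` is assumed to be a PaddleASRConnectionHanddler bound to a deepspeech2online engine:

    def stream_chunks(handler, pcm_chunks):
        handler.reset()                             # fresh h/c state and CTC decoder state
        for i, chunk in enumerate(pcm_chunks):      # chunk: bytes of int16 PCM
            last = (i == len(pcm_chunks) - 1)
            handler.extract_feat(chunk)             # buffer wav, append new feature frames
            handler.decode(is_finished=last)        # ds2 forward pass + streaming CTC decode
        return handler.result_transcripts[0]        # latest best transcript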
@@ -240,6 +422,7 @@ class PaddleASRConnectionHanddler:
                 )
                 return None, None

+            # if is_finished=True, we need at least context frames
             if num_frames < context:
                 logger.info(
                     "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
@@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler:
             return ''

     def rescoring(self):
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            return
+
         logger.info("rescoring the final result")
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
             return
