add conformer online server, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent af484fc980
commit d21ccd0287

@ -91,6 +91,20 @@ pretrained_models = {
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5': 'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3' '29e02312deb2e59b3c8686c7966d4fe3'
},
"conformer2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz',
'md5':
'4814e52e0fc2fd48899373f95c84b0c9',
'cfg_path':
'config.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_30',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
}, },
"deepspeech2offline_librispeech-en-16k": { "deepspeech2offline_librispeech-en-16k": {
'url': 'url':
@ -115,6 +129,8 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer": "conformer":
"paddlespeech.s2t.models.u2:U2Model", "paddlespeech.s2t.models.u2:U2Model",
"conformer2online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer": "transformer":
"paddlespeech.s2t.models.u2:U2Model", "paddlespeech.s2t.models.u2:U2Model",
"wenetspeech": "wenetspeech":
@ -219,6 +235,7 @@ class ASRExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
logger.info("start to init the model")
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.info('Model had been initialized.')
return return
@ -233,14 +250,15 @@ class ASRExecutor(BaseExecutor):
self.ckpt_path = os.path.join( self.ckpt_path = os.path.join(
res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams") res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
@ -269,7 +287,6 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
else: else:
raise Exception("wrong type") raise Exception("wrong type")
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
@ -347,12 +364,14 @@ class ASRExecutor(BaseExecutor):
else: else:
raise Exception("wrong type") raise Exception("wrong type")
logger.info("audio feat process success")
@paddle.no_grad() @paddle.no_grad()
def infer(self, model_type: str): def infer(self, model_type: str):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
logger.info("start to infer the model to get the output")
cfg = self.config.decode cfg = self.config.decode
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
@ -369,17 +388,22 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0] self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
result_transcripts = self.model.decode( logger.info(f"we will use the transformer like model : {model_type}")
audio, try:
audio_len, result_transcripts = self.model.decode(
text_feature=self.text_feature, audio,
decoding_method=cfg.decoding_method, audio_len,
beam_size=cfg.beam_size, text_feature=self.text_feature,
ctc_weight=cfg.ctc_weight, decoding_method=cfg.decoding_method,
decoding_chunk_size=cfg.decoding_chunk_size, beam_size=cfg.beam_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, ctc_weight=cfg.ctc_weight,
simulate_streaming=cfg.simulate_streaming) decoding_chunk_size=cfg.decoding_chunk_size,
self._outputs["result"] = result_transcripts[0][0] num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
else: else:
raise Exception("invalid model name") raise Exception("invalid model name")

@ -213,12 +213,14 @@ class U2BaseModel(ASRInterface, nn.Layer):
num_decoding_left_chunks=num_decoding_left_chunks num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim) ) # (B, maxlen, encoder_dim)
else: else:
print("offline decode from the asr")
encoder_out, encoder_mask = self.encoder( encoder_out, encoder_mask = self.encoder(
speech, speech,
speech_lengths, speech_lengths,
decoding_chunk_size=decoding_chunk_size, decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim) ) # (B, maxlen, encoder_dim)
print("offline decode success")
return encoder_out, encoder_mask return encoder_out, encoder_mask
def recognize( def recognize(
@ -706,13 +708,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
List[List[int]]: transcripts. List[List[int]]: transcripts.
""" """
batch_size = feats.shape[0] batch_size = feats.shape[0]
print("start to decode the audio feat")
if decoding_method in ['ctc_prefix_beam_search', if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1: 'attention_rescoring'] and batch_size > 1:
logger.fatal( logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1' f'decoding mode {decoding_method} must be running with batch_size == 1'
) )
logger.error(f"current batch_size is {batch_size}")
sys.exit(1) sys.exit(1)
print(f"use the {decoding_method} to decode the audio feat")
if decoding_method == 'attention': if decoding_method == 'attention':
hyps = self.recognize( hyps = self.recognize(
feats, feats,

@ -180,7 +180,8 @@ class CTCDecoder(CTCDecoderBase):
# init once # init once
if self._ext_scorer is not None: if self._ext_scorer is not None:
return return
from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401
if language_model_path != '': if language_model_path != '':
logger.info("begin to initialize the external scorer " logger.info("begin to initialize the external scorer "
"for decoding") "for decoding")

@ -317,6 +317,8 @@ class BaseEncoder(nn.Layer):
outputs = [] outputs = []
offset = 0 offset = 0
# Feed forward overlap input step by step # Feed forward overlap input step by step
print(f"context: {context}")
print(f"stride: {stride}")
for cur in range(0, num_frames - context + 1, stride): for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames) end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :] chunk_xs = xs[:, cur:end, :]

@ -4,7 +4,7 @@
# SERVER SETTING # # SERVER SETTING #
################################################################################# #################################################################################
host: 0.0.0.0 host: 0.0.0.0
port: 8091 port: 8096
# The task format in the engin_list is: <speech task>_<engine type> # The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online'] # task choices = ['asr_online', 'tts_online']
@ -18,10 +18,44 @@ engine_list: ['asr_online']
# ENGINE CONFIG # # ENGINE CONFIG #
################################################################################# #################################################################################
# ################################### ASR #########################################
# ################### speech task: asr; engine_type: online #######################
# asr_online:
# model_type: 'deepspeech2online_aishell'
# am_model: # the pdmodel file of am static model [optional]
# am_params: # the pdiparams file of am static model [optional]
# lang: 'zh'
# sample_rate: 16000
# cfg_path:
# decode_method:
# force_yes: True
# am_predictor_conf:
# device: # set 'gpu:id' or 'cpu'
# switch_ir_optim: True
# glog_info: False # True -> print glog
# summary: True # False -> do not show predictor config
# chunk_buffer_conf:
# frame_duration_ms: 80
# shift_ms: 40
# sample_rate: 16000
# sample_width: 2
# vad_conf:
# aggressiveness: 2
# sample_rate: 16000
# frame_duration_ms: 20
# sample_width: 2
# padding_ms: 200
# padding_ratio: 0.9
################################### ASR ######################################### ################################### ASR #########################################
################### speech task: asr; engine_type: online ####################### ################### speech task: asr; engine_type: online #######################
asr_online: asr_online:
model_type: 'deepspeech2online_aishell' model_type: 'conformer2online_aishell'
am_model: # the pdmodel file of am static model [optional] am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional] am_params: # the pdiparams file of am static model [optional]
lang: 'zh' lang: 'zh'
@ -37,15 +71,15 @@ asr_online:
summary: True # False -> do not show predictor config summary: True # False -> do not show predictor config
chunk_buffer_conf: chunk_buffer_conf:
frame_duration_ms: 80 frame_duration_ms: 85
shift_ms: 40 shift_ms: 40
sample_rate: 16000 sample_rate: 16000
sample_width: 2 sample_width: 2
vad_conf: # vad_conf:
aggressiveness: 2 # aggressiveness: 2
sample_rate: 16000 # sample_rate: 16000
frame_duration_ms: 20 # frame_duration_ms: 20
sample_width: 2 # sample_width: 2
padding_ms: 200 # padding_ms: 200
padding_ratio: 0.9 # padding_ratio: 0.9

@ -20,11 +20,15 @@ from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.asr.infer import model_alias
from paddlespeech.cli.asr.infer import pretrained_models
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.audio_process import pcm2float
@ -51,6 +55,24 @@ pretrained_models = {
'lm_md5': 'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3' '29e02312deb2e59b3c8686c7966d4fe3'
}, },
"conformer2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz',
'md5':
'4814e52e0fc2fd48899373f95c84b0c9',
'cfg_path':
'exp/chunk_conformer//conf/config.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30/',
'model':
'exp/chunk_conformer/checkpoints/avg_30.pdparams',
'params':
'exp/chunk_conformer/checkpoints/avg_30.pdparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
} }
@ -71,15 +93,17 @@ class ASRServerExecutor(ASRExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
self.model_type = model_type
self.sample_rate = sample_rate
if cfg_path is None or am_model is None or am_params is None: if cfg_path is None or am_model is None or am_params is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
logger.info(f"Load the pretrained model, tag = {tag}") logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path self.res_path = res_path
self.cfg_path = os.path.join(res_path, self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml"
pretrained_models[tag]['cfg_path']) # self.cfg_path = os.path.join(res_path,
# pretrained_models[tag]['cfg_path'])
self.am_model = os.path.join(res_path, self.am_model = os.path.join(res_path,
pretrained_models[tag]['model']) pretrained_models[tag]['model'])
@ -119,49 +143,67 @@ class ASRServerExecutor(ASRExecutor):
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
# 开发 conformer 的流式模型
logger.info("start to create the stream conformer asr engine") logger.info("start to create the stream conformer asr engine")
# 复用cli里面的代码 if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
self.config.vocab_filepath = os.path.join(
self.res_path, self.config.vocab_filepath)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
# update the decoding method
if decode_method:
self.config.decode.decoding_method = decode_method
else: else:
raise Exception("wrong type") raise Exception("wrong type")
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor( self.am_predictor = init_predictor(
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
# decoder # decoder
logger.info("ASR engine start to create the ctc decoder instance") logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder( self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2, enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id, blank_id=self.config.blank_id,
dropout_rate=0.0, dropout_rate=0.0,
reduction=True, # sum reduction=True, # sum
batch_average=True, # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder # init decoder
logger.info("ASR engine start to init the ctc decoder") logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode cfg = self.config.decode
decode_batch_size = 1 # for online decode_batch_size = 1 # for online
self.decoder.init_decoder( self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list, decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch) cfg.num_proc_bsearch)
# init state box # init state box
self.chunk_state_h_box = np.zeros( self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size), (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32) dtype=float32)
self.chunk_state_c_box = np.zeros( self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size), (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32) dtype=float32)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = dynamic_import(model_name, model_alias)
model_conf = self.config
model = model_class.from_config(model_conf)
self.model = model
logger.info("create the transformer like model success")
def reset_decoder_and_chunk(self): def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio """reset decoder and chunk state for an new audio
@ -186,6 +228,7 @@ class ASRServerExecutor(ASRExecutor):
Returns: Returns:
[type]: [description] [type]: [description]
""" """
logger.info("start to decoce chunk by chunk")
if "deepspeech2online" in model_type: if "deepspeech2online" in model_type:
input_names = self.am_predictor.get_input_names() input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_handle = self.am_predictor.get_input_handle(input_names[0])
@ -224,10 +267,29 @@ class ASRServerExecutor(ASRExecutor):
self.decoder.next(output_chunk_probs, output_chunk_lens) self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode() trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one one best result: {trans_best[0]}")
return trans_best[0] return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
raise Exception("invalid model name") try:
logger.info(
f"we will use the transformer like model : {self.model_type}"
)
cfg = self.config.decode
result_transcripts = self.model.decode(
x_chunk,
x_chunk_lens,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
return result_transcripts[0][0]
except Exception as e:
logger.exception(e)
else: else:
raise Exception("invalid model name") raise Exception("invalid model name")
@ -244,32 +306,55 @@ class ASRServerExecutor(ASRExecutor):
""" """
# pcm16 -> pcm 32 # pcm16 -> pcm 32
samples = pcm2float(samples) samples = pcm2float(samples)
if "deepspeech2online" in self.model_type:
# read audio # read audio
speech_segment = SpeechSegment.from_pcm( speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ") samples, sample_rate, transcript=" ")
# audio augment # audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment) self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature # extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text) speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum # CMVN spectrum
if self.collate_fn_test._normalizer: if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum) spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment # spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(spectrum) audio = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32') audio_len = audio.shape[0]
# audio_len = paddle.to_tensor(audio_len) audio = paddle.to_tensor(audio, dtype='float32')
audio = paddle.unsqueeze(audio, axis=0) # audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
x_chunk = audio.numpy()
x_chunk_lens = np.array([audio_len]) x_chunk = audio.numpy()
x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens
return x_chunk, x_chunk_lens
elif "conformer2online" in self.model_type:
if sample_rate != self.sample_rate:
logger.info(f"audio sample rate {sample_rate} is not match," \
"the model sample_rate is {self.sample_rate}")
logger.info(f"ASR Engine use the {self.model_type} to process")
logger.info("Create the preprocess instance")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("Read the audio file")
logger.info(f"audio shape: {samples.shape}")
# fbank
x_chunk = preprocessing(samples, **preprocess_args)
x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
logger.info(
f"process the audio feature success, feat shape: {x_chunk.shape}"
)
return x_chunk, x_chunk_lens
class ASREngine(BaseEngine): class ASREngine(BaseEngine):
@ -310,7 +395,10 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True
def preprocess(self, samples, sample_rate): def preprocess(self,
samples,
sample_rate,
model_type="deepspeech2online_aishell-zh-16k"):
"""preprocess """preprocess
Args: Args:
@ -321,6 +409,7 @@ class ASREngine(BaseEngine):
x_chunk (numpy.array): shape[B, T, D] x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B] x_chunk_lens (numpy.array): shape[B]
""" """
# if "deepspeech" in model_type:
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
return x_chunk, x_chunk_lens return x_chunk, x_chunk_lens

@ -103,7 +103,7 @@ class ASRAudioHandler:
def main(args): def main(args):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8090) handler = ASRAudioHandler("127.0.0.1", 8096)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# support to process single audio file # support to process single audio file

@ -14,6 +14,7 @@
import json import json
import numpy as np import numpy as np
import json
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import WebSocket from fastapi import WebSocket
from fastapi import WebSocketDisconnect from fastapi import WebSocketDisconnect
@ -28,7 +29,7 @@ router = APIRouter()
@router.websocket('/ws/asr') @router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
print("websocket protocal receive the dataset")
await websocket.accept() await websocket.accept()
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
@ -36,14 +37,18 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer # init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer( chunk_buffer = ChunkBuffer(
frame_duration_ms=chunk_buffer_conf['frame_duration_ms'],
sample_rate=chunk_buffer_conf['sample_rate'], sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width']) sample_width=chunk_buffer_conf['sample_width'])
# init vad # init vad
vad_conf = asr_engine.config.vad_conf # print(asr_engine.config)
vad = VADAudio( # print(type(asr_engine.config))
aggressiveness=vad_conf['aggressiveness'], vad_conf = asr_engine.config.get('vad_conf', None)
rate=vad_conf['sample_rate'], if vad_conf:
frame_duration_ms=vad_conf['frame_duration_ms']) vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
try: try:
while True: while True:
@ -65,7 +70,7 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
# reset single engine for an new connection # reset single engine for an new connection
asr_engine.reset() # asr_engine.reset()
resp = {"status": "ok", "signal": "finished"} resp = {"status": "ok", "signal": "finished"}
await websocket.send_json(resp) await websocket.send_json(resp)
break break
@ -75,16 +80,16 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
# vad for input bytes audio # # vad for input bytes audio
vad.add_audio(message) # vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector() # message = b''.join(f for f in vad.vad_collector()
if f is not None) # if f is not None)
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
asr_results = "" asr_results = ""
frames = chunk_buffer.frame_generator(message) frames = chunk_buffer.frame_generator(message)
for frame in frames: for frame in frames:
# get the pcm data from the bytes
samples = np.frombuffer(frame.bytes, dtype=np.int16) samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples, x_chunk, x_chunk_lens = asr_engine.preprocess(samples,

Loading…
Cancel
Save