commit
61941d14b0
@ -0,0 +1,51 @@
|
||||
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||
|
||||
#################################################################################
|
||||
# SERVER SETTING #
|
||||
#################################################################################
|
||||
host: 0.0.0.0
|
||||
port: 8091
|
||||
|
||||
# The task format in the engin_list is: <speech task>_<engine type>
|
||||
# task choices = ['asr_online', 'tts_online']
|
||||
# protocol = ['websocket', 'http'] (only one can be selected).
|
||||
# websocket only support online engine type.
|
||||
protocol: 'websocket'
|
||||
engine_list: ['asr_online']
|
||||
|
||||
|
||||
#################################################################################
|
||||
# ENGINE CONFIG #
|
||||
#################################################################################
|
||||
|
||||
################################### ASR #########################################
|
||||
################### speech task: asr; engine_type: online #######################
|
||||
asr_online:
|
||||
model_type: 'deepspeech2online_aishell'
|
||||
am_model: # the pdmodel file of am static model [optional]
|
||||
am_params: # the pdiparams file of am static model [optional]
|
||||
lang: 'zh'
|
||||
sample_rate: 16000
|
||||
cfg_path:
|
||||
decode_method:
|
||||
force_yes: True
|
||||
|
||||
am_predictor_conf:
|
||||
device: # set 'gpu:id' or 'cpu'
|
||||
switch_ir_optim: True
|
||||
glog_info: False # True -> print glog
|
||||
summary: True # False -> do not show predictor config
|
||||
|
||||
chunk_buffer_conf:
|
||||
frame_duration_ms: 80
|
||||
shift_ms: 40
|
||||
sample_rate: 16000
|
||||
sample_width: 2
|
||||
|
||||
vad_conf:
|
||||
aggressiveness: 2
|
||||
sample_rate: 16000
|
||||
frame_duration_ms: 20
|
||||
sample_width: 2
|
||||
padding_ms: 200
|
||||
padding_ratio: 0.9
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,355 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
import pickle
|
||||
import numpy as np
|
||||
from numpy import float32
|
||||
import soundfile
|
||||
|
||||
import paddle
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.s2t.frontend.speech import SpeechSegment
|
||||
from paddlespeech.cli.asr.infer import ASRExecutor
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.cli.utils import MODEL_HOME
|
||||
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from paddlespeech.s2t.modules.ctc import CTCDecoder
|
||||
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||
from paddlespeech.server.utils.config import get_config
|
||||
from paddlespeech.server.utils.paddle_predictor import init_predictor
|
||||
from paddlespeech.server.utils.paddle_predictor import run_model
|
||||
|
||||
__all__ = ['ASREngine']
|
||||
|
||||
pretrained_models = {
|
||||
"deepspeech2online_aishell-zh-16k": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
|
||||
'md5':
|
||||
'd5e076217cf60486519f72c217d21b9b',
|
||||
'cfg_path':
|
||||
'model.yaml',
|
||||
'ckpt_path':
|
||||
'exp/deepspeech2_online/checkpoints/avg_1',
|
||||
'model':
|
||||
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
|
||||
'params':
|
||||
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
|
||||
'lm_url':
|
||||
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
|
||||
'lm_md5':
|
||||
'29e02312deb2e59b3c8686c7966d4fe3'
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ASRServerExecutor(ASRExecutor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
def _init_from_path(self,
|
||||
model_type: str='wenetspeech',
|
||||
am_model: Optional[os.PathLike]=None,
|
||||
am_params: Optional[os.PathLike]=None,
|
||||
lang: str='zh',
|
||||
sample_rate: int=16000,
|
||||
cfg_path: Optional[os.PathLike]=None,
|
||||
decode_method: str='attention_rescoring',
|
||||
am_predictor_conf: dict=None):
|
||||
"""
|
||||
Init model and other resources from a specific path.
|
||||
"""
|
||||
|
||||
if cfg_path is None or am_model is None or am_params is None:
|
||||
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
||||
tag = model_type + '-' + lang + '-' + sample_rate_str
|
||||
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
|
||||
self.res_path = res_path
|
||||
self.cfg_path = os.path.join(res_path,
|
||||
pretrained_models[tag]['cfg_path'])
|
||||
|
||||
self.am_model = os.path.join(res_path,
|
||||
pretrained_models[tag]['model'])
|
||||
self.am_params = os.path.join(res_path,
|
||||
pretrained_models[tag]['params'])
|
||||
logger.info(res_path)
|
||||
logger.info(self.cfg_path)
|
||||
logger.info(self.am_model)
|
||||
logger.info(self.am_params)
|
||||
else:
|
||||
self.cfg_path = os.path.abspath(cfg_path)
|
||||
self.am_model = os.path.abspath(am_model)
|
||||
self.am_params = os.path.abspath(am_params)
|
||||
self.res_path = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(self.cfg_path)))
|
||||
|
||||
#Init body.
|
||||
self.config = CfgNode(new_allowed=True)
|
||||
self.config.merge_from_file(self.cfg_path)
|
||||
|
||||
with UpdateConfig(self.config):
|
||||
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
|
||||
from paddlespeech.s2t.io.collator import SpeechCollator
|
||||
self.vocab = self.config.vocab_filepath
|
||||
self.config.decode.lang_model_path = os.path.join(
|
||||
MODEL_HOME, 'language_model',
|
||||
self.config.decode.lang_model_path)
|
||||
self.collate_fn_test = SpeechCollator.from_config(self.config)
|
||||
self.text_feature = TextFeaturizer(
|
||||
unit_type=self.config.unit_type, vocab=self.vocab)
|
||||
|
||||
lm_url = pretrained_models[tag]['lm_url']
|
||||
lm_md5 = pretrained_models[tag]['lm_md5']
|
||||
self.download_lm(
|
||||
lm_url,
|
||||
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
|
||||
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
|
||||
raise Exception("wrong type")
|
||||
else:
|
||||
raise Exception("wrong type")
|
||||
|
||||
# AM predictor
|
||||
self.am_predictor_conf = am_predictor_conf
|
||||
self.am_predictor = init_predictor(
|
||||
model_file=self.am_model,
|
||||
params_file=self.am_params,
|
||||
predictor_conf=self.am_predictor_conf)
|
||||
|
||||
# decoder
|
||||
self.decoder = CTCDecoder(
|
||||
odim=self.config.output_dim, # <blank> is in vocab
|
||||
enc_n_units=self.config.rnn_layer_size * 2,
|
||||
blank_id=self.config.blank_id,
|
||||
dropout_rate=0.0,
|
||||
reduction=True, # sum
|
||||
batch_average=True, # sum / batch_size
|
||||
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
|
||||
|
||||
# init decoder
|
||||
cfg = self.config.decode
|
||||
decode_batch_size = 1 # for online
|
||||
self.decoder.init_decoder(
|
||||
decode_batch_size, self.text_feature.vocab_list,
|
||||
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
||||
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
||||
cfg.num_proc_bsearch)
|
||||
|
||||
# init state box
|
||||
self.chunk_state_h_box = np.zeros(
|
||||
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
self.chunk_state_c_box = np.zeros(
|
||||
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
|
||||
def reset_decoder_and_chunk(self):
|
||||
"""reset decoder and chunk state for an new audio
|
||||
"""
|
||||
self.decoder.reset_decoder(batch_size=1)
|
||||
# init state box, for new audio request
|
||||
self.chunk_state_h_box = np.zeros(
|
||||
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
self.chunk_state_c_box = np.zeros(
|
||||
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
|
||||
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
|
||||
"""decode one chunk
|
||||
|
||||
Args:
|
||||
x_chunk (numpy.array): shape[B, T, D]
|
||||
x_chunk_lens (numpy.array): shape[B]
|
||||
model_type (str): online model type
|
||||
|
||||
Returns:
|
||||
[type]: [description]
|
||||
"""
|
||||
if "deepspeech2online" in model_type :
|
||||
input_names = self.am_predictor.get_input_names()
|
||||
audio_handle = self.am_predictor.get_input_handle(input_names[0])
|
||||
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
|
||||
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
|
||||
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
|
||||
|
||||
audio_handle.reshape(x_chunk.shape)
|
||||
audio_handle.copy_from_cpu(x_chunk)
|
||||
|
||||
audio_len_handle.reshape(x_chunk_lens.shape)
|
||||
audio_len_handle.copy_from_cpu(x_chunk_lens)
|
||||
|
||||
h_box_handle.reshape(self.chunk_state_h_box.shape)
|
||||
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
|
||||
|
||||
c_box_handle.reshape(self.chunk_state_c_box.shape)
|
||||
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
|
||||
|
||||
output_names = self.am_predictor.get_output_names()
|
||||
output_handle = self.am_predictor.get_output_handle(output_names[0])
|
||||
output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
|
||||
output_state_h_handle = self.am_predictor.get_output_handle(
|
||||
output_names[2])
|
||||
output_state_c_handle = self.am_predictor.get_output_handle(
|
||||
output_names[3])
|
||||
|
||||
self.am_predictor.run()
|
||||
|
||||
output_chunk_probs = output_handle.copy_to_cpu()
|
||||
output_chunk_lens = output_lens_handle.copy_to_cpu()
|
||||
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
|
||||
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
|
||||
|
||||
self.decoder.next(output_chunk_probs, output_chunk_lens)
|
||||
trans_best, trans_beam = self.decoder.decode()
|
||||
|
||||
return trans_best[0]
|
||||
|
||||
elif "conformer" in model_type or "transformer" in model_type:
|
||||
raise Exception("invalid model name")
|
||||
else:
|
||||
raise Exception("invalid model name")
|
||||
|
||||
def _pcm16to32(self, audio):
|
||||
"""pcm int16 to float32
|
||||
|
||||
Args:
|
||||
audio(numpy.array): numpy.int16
|
||||
|
||||
Returns:
|
||||
audio(numpy.array): numpy.float32
|
||||
"""
|
||||
if audio.dtype == np.int16:
|
||||
audio = audio.astype("float32")
|
||||
bits = np.iinfo(np.int16).bits
|
||||
audio = audio / (2**(bits - 1))
|
||||
return audio
|
||||
|
||||
def extract_feat(self, samples, sample_rate):
|
||||
"""extract feat
|
||||
|
||||
Args:
|
||||
samples (numpy.array): numpy.float32
|
||||
sample_rate (int): sample rate
|
||||
|
||||
Returns:
|
||||
x_chunk (numpy.array): shape[B, T, D]
|
||||
x_chunk_lens (numpy.array): shape[B]
|
||||
"""
|
||||
# pcm16 -> pcm 32
|
||||
samples = self._pcm16to32(samples)
|
||||
|
||||
# read audio
|
||||
speech_segment = SpeechSegment.from_pcm(
|
||||
samples, sample_rate, transcript=" ")
|
||||
# audio augment
|
||||
self.collate_fn_test.augmentation.transform_audio(speech_segment)
|
||||
|
||||
# extract speech feature
|
||||
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
|
||||
speech_segment, self.collate_fn_test.keep_transcription_text)
|
||||
# CMVN spectrum
|
||||
if self.collate_fn_test._normalizer:
|
||||
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
|
||||
|
||||
# spectrum augment
|
||||
audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
|
||||
|
||||
audio_len = audio.shape[0]
|
||||
audio = paddle.to_tensor(audio, dtype='float32')
|
||||
# audio_len = paddle.to_tensor(audio_len)
|
||||
audio = paddle.unsqueeze(audio, axis=0)
|
||||
|
||||
x_chunk = audio.numpy()
|
||||
x_chunk_lens = np.array([audio_len])
|
||||
|
||||
return x_chunk, x_chunk_lens
|
||||
|
||||
|
||||
class ASREngine(BaseEngine):
|
||||
"""ASR server engine
|
||||
|
||||
Args:
|
||||
metaclass: Defaults to Singleton.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ASREngine, self).__init__()
|
||||
|
||||
def init(self, config: dict) -> bool:
|
||||
"""init engine resource
|
||||
|
||||
Args:
|
||||
config_file (str): config file
|
||||
|
||||
Returns:
|
||||
bool: init failed or success
|
||||
"""
|
||||
self.input = None
|
||||
self.output = ""
|
||||
self.executor = ASRServerExecutor()
|
||||
self.config = config
|
||||
|
||||
self.executor._init_from_path(
|
||||
model_type=self.config.model_type,
|
||||
am_model=self.config.am_model,
|
||||
am_params=self.config.am_params,
|
||||
lang=self.config.lang,
|
||||
sample_rate=self.config.sample_rate,
|
||||
cfg_path=self.config.cfg_path,
|
||||
decode_method=self.config.decode_method,
|
||||
am_predictor_conf=self.config.am_predictor_conf)
|
||||
|
||||
logger.info("Initialize ASR server engine successfully.")
|
||||
return True
|
||||
|
||||
def preprocess(self, samples, sample_rate):
|
||||
"""preprocess
|
||||
|
||||
Args:
|
||||
samples (numpy.array): numpy.float32
|
||||
sample_rate (int): sample rate
|
||||
|
||||
Returns:
|
||||
x_chunk (numpy.array): shape[B, T, D]
|
||||
x_chunk_lens (numpy.array): shape[B]
|
||||
"""
|
||||
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
|
||||
return x_chunk, x_chunk_lens
|
||||
|
||||
def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
|
||||
"""run online engine
|
||||
|
||||
Args:
|
||||
x_chunk (numpy.array): shape[B, T, D]
|
||||
x_chunk_lens (numpy.array): shape[B]
|
||||
decoder_chunk_size(int)
|
||||
"""
|
||||
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
|
||||
|
||||
def postprocess(self):
|
||||
"""postprocess
|
||||
"""
|
||||
return self.output
|
||||
|
||||
def reset(self):
|
||||
"""reset engine decoder and inference state
|
||||
"""
|
||||
self.executor.reset_decoder_and_chunk()
|
||||
self.output = ""
|
@ -0,0 +1,161 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
record wave from the mic
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import wave
|
||||
from signal import SIGINT
|
||||
from signal import SIGTERM
|
||||
|
||||
import pyaudio
|
||||
import websockets
|
||||
|
||||
|
||||
class ASRAudioHandler(threading.Thread):
|
||||
def __init__(self, url="127.0.0.1", port=8091):
|
||||
threading.Thread.__init__(self)
|
||||
self.url = url
|
||||
self.port = port
|
||||
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
|
||||
self.fileName = "./output.wav"
|
||||
self.chunk = 5120
|
||||
self.format = pyaudio.paInt16
|
||||
self.channels = 1
|
||||
self.rate = 16000
|
||||
self._running = True
|
||||
self._frames = []
|
||||
self.data_backup = []
|
||||
|
||||
def startrecord(self):
|
||||
"""
|
||||
start a new thread to record wave
|
||||
"""
|
||||
threading._start_new_thread(self.recording, ())
|
||||
|
||||
def recording(self):
|
||||
"""
|
||||
recording wave
|
||||
"""
|
||||
self._running = True
|
||||
self._frames = []
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(
|
||||
format=self.format,
|
||||
channels=self.channels,
|
||||
rate=self.rate,
|
||||
input=True,
|
||||
frames_per_buffer=self.chunk)
|
||||
while (self._running):
|
||||
data = stream.read(self.chunk)
|
||||
self._frames.append(data)
|
||||
self.data_backup.append(data)
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
def save(self):
|
||||
"""
|
||||
save wave data
|
||||
"""
|
||||
p = pyaudio.PyAudio()
|
||||
wf = wave.open(self.fileName, 'wb')
|
||||
wf.setnchannels(self.channels)
|
||||
wf.setsampwidth(p.get_sample_size(self.format))
|
||||
wf.setframerate(self.rate)
|
||||
wf.writeframes(b''.join(self.data_backup))
|
||||
wf.close()
|
||||
p.terminate()
|
||||
|
||||
def stoprecord(self):
|
||||
"""
|
||||
stop recording
|
||||
"""
|
||||
self._running = False
|
||||
|
||||
async def run(self):
|
||||
aa = input("是否开始录音? (y/n)")
|
||||
if aa.strip() == "y":
|
||||
self.startrecord()
|
||||
logging.info("*" * 10 + "开始录音,请输入语音")
|
||||
|
||||
async with websockets.connect(self.url) as ws:
|
||||
# 发送开始指令
|
||||
audio_info = json.dumps(
|
||||
{
|
||||
"name": "test.wav",
|
||||
"signal": "start",
|
||||
"nbest": 5
|
||||
},
|
||||
sort_keys=True,
|
||||
indent=4,
|
||||
separators=(',', ': '))
|
||||
await ws.send(audio_info)
|
||||
msg = await ws.recv()
|
||||
logging.info("receive msg={}".format(msg))
|
||||
|
||||
# send bytes data
|
||||
logging.info("结束录音请: Ctrl + c。继续请按回车。")
|
||||
try:
|
||||
while True:
|
||||
while len(self._frames) > 0:
|
||||
await ws.send(self._frames.pop(0))
|
||||
msg = await ws.recv()
|
||||
logging.info("receive msg={}".format(msg))
|
||||
except asyncio.CancelledError:
|
||||
# quit
|
||||
# send finished
|
||||
audio_info = json.dumps(
|
||||
{
|
||||
"name": "test.wav",
|
||||
"signal": "end",
|
||||
"nbest": 5
|
||||
},
|
||||
sort_keys=True,
|
||||
indent=4,
|
||||
separators=(',', ': '))
|
||||
await ws.send(audio_info)
|
||||
msg = await ws.recv()
|
||||
logging.info("receive msg={}".format(msg))
|
||||
|
||||
self.stoprecord()
|
||||
logging.info("*" * 10 + "录音结束")
|
||||
self.save()
|
||||
elif aa.strip() == "n":
|
||||
exit()
|
||||
else:
|
||||
print("无效输入!")
|
||||
exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info("asr websocket client start")
|
||||
|
||||
handler = ASRAudioHandler("127.0.0.1", 8091)
|
||||
loop = asyncio.get_event_loop()
|
||||
main_task = asyncio.ensure_future(handler.run())
|
||||
for signal in [SIGINT, SIGTERM]:
|
||||
loop.add_signal_handler(signal, main_task.cancel)
|
||||
try:
|
||||
loop.run_until_complete(main_task)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
logging.info("asr websocket client finished")
|
@ -0,0 +1,115 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import soundfile
|
||||
import websockets
|
||||
|
||||
|
||||
class ASRAudioHandler:
|
||||
def __init__(self, url="127.0.0.1", port=8090):
|
||||
self.url = url
|
||||
self.port = port
|
||||
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
|
||||
|
||||
def read_wave(self, wavfile_path: str):
|
||||
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
|
||||
x_len = len(samples)
|
||||
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
|
||||
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
|
||||
|
||||
if (x_len - chunk_size) % chunk_stride != 0:
|
||||
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
|
||||
else:
|
||||
padding_len_x = 0
|
||||
|
||||
padding = np.zeros((padding_len_x), dtype=samples.dtype)
|
||||
padded_x = np.concatenate([samples, padding], axis=0)
|
||||
|
||||
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
|
||||
num_chunk = int(num_chunk)
|
||||
|
||||
for i in range(0, num_chunk):
|
||||
start = i * chunk_stride
|
||||
end = start + chunk_size
|
||||
x_chunk = padded_x[start:end]
|
||||
yield x_chunk
|
||||
|
||||
async def run(self, wavfile_path: str):
|
||||
logging.info("send a message to the server")
|
||||
# 读取音频
|
||||
# self.read_wave()
|
||||
# 发送 websocket 的 handshake 协议头
|
||||
async with websockets.connect(self.url) as ws:
|
||||
# server 端已经接收到 handshake 协议头
|
||||
# 发送开始指令
|
||||
audio_info = json.dumps(
|
||||
{
|
||||
"name": "test.wav",
|
||||
"signal": "start",
|
||||
"nbest": 5
|
||||
},
|
||||
sort_keys=True,
|
||||
indent=4,
|
||||
separators=(',', ': '))
|
||||
await ws.send(audio_info)
|
||||
msg = await ws.recv()
|
||||
logging.info("receive msg={}".format(msg))
|
||||
|
||||
# send chunk audio data to engine
|
||||
for chunk_data in self.read_wave(wavfile_path):
|
||||
await ws.send(chunk_data.tobytes())
|
||||
msg = await ws.recv()
|
||||
logging.info("receive msg={}".format(msg))
|
||||
|
||||
# finished
|
||||
audio_info = json.dumps(
|
||||
{
|
||||
"name": "test.wav",
|
||||
"signal": "end",
|
||||
"nbest": 5
|
||||
},
|
||||
sort_keys=True,
|
||||
indent=4,
|
||||
separators=(',', ': '))
|
||||
await ws.send(audio_info)
|
||||
msg = await ws.recv()
|
||||
logging.info("receive msg={}".format(msg))
|
||||
|
||||
|
||||
def main(args):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info("asr websocket client start")
|
||||
handler = ASRAudioHandler("127.0.0.1", 8091)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(handler.run(args.wavfile))
|
||||
logging.info("asr websocket client finished")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--wavfile",
|
||||
action="store",
|
||||
help="wav file path ",
|
||||
default="./16_audio.wav")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
@ -0,0 +1,59 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class Frame(object):
|
||||
"""Represents a "frame" of audio data."""
|
||||
|
||||
def __init__(self, bytes, timestamp, duration):
|
||||
self.bytes = bytes
|
||||
self.timestamp = timestamp
|
||||
self.duration = duration
|
||||
|
||||
|
||||
class ChunkBuffer(object):
|
||||
def __init__(self,
|
||||
frame_duration_ms=80,
|
||||
shift_ms=40,
|
||||
sample_rate=16000,
|
||||
sample_width=2):
|
||||
self.sample_rate = sample_rate
|
||||
self.frame_duration_ms = frame_duration_ms
|
||||
self.shift_ms = shift_ms
|
||||
self.remained_audio = b''
|
||||
self.sample_width = sample_width # int16 = 2; float32 = 4
|
||||
|
||||
def frame_generator(self, audio):
|
||||
"""Generates audio frames from PCM audio data.
|
||||
Takes the desired frame duration in milliseconds, the PCM data, and
|
||||
the sample rate.
|
||||
Yields Frames of the requested duration.
|
||||
"""
|
||||
audio = self.remained_audio + audio
|
||||
self.remained_audio = b''
|
||||
|
||||
n = int(self.sample_rate *
|
||||
(self.frame_duration_ms / 1000.0) * self.sample_width)
|
||||
shift_n = int(self.sample_rate *
|
||||
(self.shift_ms / 1000.0) * self.sample_width)
|
||||
offset = 0
|
||||
timestamp = 0.0
|
||||
duration = (float(n) / self.sample_rate) / self.sample_width
|
||||
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
|
||||
while offset + n <= len(audio):
|
||||
yield Frame(audio[offset:offset + n], timestamp, duration)
|
||||
timestamp += shift_duration
|
||||
offset += shift_n
|
||||
|
||||
self.remained_audio += audio[offset:]
|
@ -0,0 +1,78 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
|
||||
import webrtcvad
|
||||
|
||||
|
||||
class VADAudio():
|
||||
def __init__(self,
|
||||
aggressiveness=2,
|
||||
rate=16000,
|
||||
frame_duration_ms=20,
|
||||
sample_width=2,
|
||||
padding_ms=200,
|
||||
padding_ratio=0.9):
|
||||
"""Initializes VAD with given aggressivenes and sets up internal queues"""
|
||||
self.vad = webrtcvad.Vad(aggressiveness)
|
||||
self.rate = rate
|
||||
self.sample_width = sample_width
|
||||
self.frame_duration_ms = frame_duration_ms
|
||||
self._frame_length = int(rate * (frame_duration_ms / 1000.0) *
|
||||
self.sample_width)
|
||||
self._buffer_queue = collections.deque()
|
||||
self.ring_buffer = collections.deque(maxlen=padding_ms //
|
||||
frame_duration_ms)
|
||||
self._ratio = padding_ratio
|
||||
self.triggered = False
|
||||
|
||||
def add_audio(self, audio):
|
||||
"""Adds new audio to internal queue"""
|
||||
for x in audio:
|
||||
self._buffer_queue.append(x)
|
||||
|
||||
def frame_generator(self):
|
||||
"""Generator that yields audio frames of frame_duration_ms"""
|
||||
while len(self._buffer_queue) > self._frame_length:
|
||||
frame = bytearray()
|
||||
for _ in range(self._frame_length):
|
||||
frame.append(self._buffer_queue.popleft())
|
||||
yield bytes(frame)
|
||||
|
||||
def vad_collector(self):
|
||||
"""Generator that yields series of consecutive audio frames comprising each utterence, separated by yielding a single None.
|
||||
Determines voice activity by ratio of frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered.
|
||||
Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
|
||||
|---utterence---| |---utterence---|
|
||||
"""
|
||||
for frame in self.frame_generator():
|
||||
is_speech = self.vad.is_speech(frame, self.rate)
|
||||
if not self.triggered:
|
||||
self.ring_buffer.append((frame, is_speech))
|
||||
num_voiced = len(
|
||||
[f for f, speech in self.ring_buffer if speech])
|
||||
if num_voiced > self._ratio * self.ring_buffer.maxlen:
|
||||
self.triggered = True
|
||||
for f, s in self.ring_buffer:
|
||||
yield f
|
||||
self.ring_buffer.clear()
|
||||
else:
|
||||
yield frame
|
||||
self.ring_buffer.append((frame, is_speech))
|
||||
num_unvoiced = len(
|
||||
[f for f, speech in self.ring_buffer if not speech])
|
||||
if num_unvoiced > self._ratio * self.ring_buffer.maxlen:
|
||||
self.triggered = False
|
||||
yield None
|
||||
self.ring_buffer.clear()
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from paddlespeech.server.ws.asr_socket import router as asr_router
|
||||
|
||||
_router = APIRouter()
|
||||
|
||||
|
||||
def setup_router(api_list: List):
|
||||
"""setup router for fastapi
|
||||
Args:
|
||||
api_list (List): [asr, tts]
|
||||
Returns:
|
||||
APIRouter
|
||||
"""
|
||||
for api_name in api_list:
|
||||
if api_name == 'asr':
|
||||
_router.include_router(asr_router)
|
||||
elif api_name == 'tts':
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return _router
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from starlette.websockets import WebSocketState as WebSocketState
|
||||
|
||||
from paddlespeech.server.engine.engine_pool import get_engine_pool
|
||||
from paddlespeech.server.utils.buffer import ChunkBuffer
|
||||
from paddlespeech.server.utils.vad import VADAudio
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket('/ws/asr')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
engine_pool = get_engine_pool()
|
||||
asr_engine = engine_pool['asr']
|
||||
# init buffer
|
||||
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
|
||||
chunk_buffer = ChunkBuffer(
|
||||
sample_rate=chunk_buffer_conf['sample_rate'],
|
||||
sample_width=chunk_buffer_conf['sample_width'])
|
||||
# init vad
|
||||
vad_conf = asr_engine.config.vad_conf
|
||||
vad = VADAudio(
|
||||
aggressiveness=vad_conf['aggressiveness'],
|
||||
rate=vad_conf['sample_rate'],
|
||||
frame_duration_ms=vad_conf['frame_duration_ms'])
|
||||
|
||||
try:
|
||||
while True:
|
||||
# careful here, changed the source code from starlette.websockets
|
||||
assert websocket.application_state == WebSocketState.CONNECTED
|
||||
message = await websocket.receive()
|
||||
websocket._raise_on_disconnect(message)
|
||||
if "text" in message:
|
||||
message = json.loads(message["text"])
|
||||
if 'signal' not in message:
|
||||
resp = {"status": "ok", "message": "no valid json data"}
|
||||
await websocket.send_json(resp)
|
||||
|
||||
if message['signal'] == 'start':
|
||||
resp = {"status": "ok", "signal": "server_ready"}
|
||||
# do something at begining here
|
||||
await websocket.send_json(resp)
|
||||
elif message['signal'] == 'end':
|
||||
engine_pool = get_engine_pool()
|
||||
asr_engine = engine_pool['asr']
|
||||
# reset single engine for an new connection
|
||||
asr_engine.reset()
|
||||
resp = {"status": "ok", "signal": "finished"}
|
||||
await websocket.send_json(resp)
|
||||
break
|
||||
else:
|
||||
resp = {"status": "ok", "message": "no valid json data"}
|
||||
await websocket.send_json(resp)
|
||||
elif "bytes" in message:
|
||||
message = message["bytes"]
|
||||
|
||||
# vad for input bytes audio
|
||||
vad.add_audio(message)
|
||||
message = b''.join(f for f in vad.vad_collector()
|
||||
if f is not None)
|
||||
|
||||
engine_pool = get_engine_pool()
|
||||
asr_engine = engine_pool['asr']
|
||||
asr_results = ""
|
||||
frames = chunk_buffer.frame_generator(message)
|
||||
for frame in frames:
|
||||
samples = np.frombuffer(frame.bytes, dtype=np.int16)
|
||||
sample_rate = asr_engine.config.sample_rate
|
||||
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
|
||||
sample_rate)
|
||||
asr_engine.run(x_chunk, x_chunk_lens)
|
||||
asr_results = asr_engine.postprocess()
|
||||
|
||||
asr_results = asr_engine.postprocess()
|
||||
resp = {'asr_results': asr_results}
|
||||
|
||||
await websocket.send_json(resp)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
Loading…
Reference in new issue