Merge pull request #1627 from WilliamZhang06/ws-develop

[websocket] added online asr engine
r0.2
Hui Zhang 2 years ago committed by GitHub
commit 61941d14b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -208,6 +208,18 @@ class AudioSegment():
io.BytesIO(bytes), dtype='float32')
return cls(samples, sample_rate)
@classmethod
def from_pcm(cls, samples, sample_rate):
"""Create audio segment from a byte string containing audio samples.
:param samples: Audio samples [num_samples x num_channels].
:type samples: numpy.ndarray
:param sample_rate: Audio sample rate.
:type sample_rate: int
:return: Audio segment instance.
:rtype: AudioSegment
"""
return cls(samples, sample_rate)
@classmethod
def concatenate(cls, *segments):
"""Concatenate an arbitrary number of audio segments together.

@ -107,6 +107,22 @@ class SpeechSegment(AudioSegment):
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
@classmethod
def from_pcm(cls, samples, sample_rate, transcript, tokens=None, token_ids=None):
"""Create speech segment from pcm on online mode
Args:
samples (numpy.ndarray): Audio samples [num_samples x num_channels].
sample_rate (int): Audio sample rate.
transcript (str): Transcript text for the speech.
tokens (List[str], optional): text tokens. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
Returns:
SpeechSegment: Speech segment instance.
"""
audio = AudioSegment.from_pcm(samples, sample_rate)
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
@classmethod
def concatenate(cls, *segments):
"""Concatenate an arbitrary number of speech segments together, both

@ -17,7 +17,8 @@ import uvicorn
from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.ws.api import setup_router as setup_ws_router
from paddlespeech.server.utils.config import get_config
app = FastAPI(
@ -35,7 +36,12 @@ def init(config):
"""
# init api
api_list = list(engine.split("_")[0] for engine in config.engine_list)
api_router = setup_router(api_list)
if config.protocol == "websocket":
api_router = setup_ws_router(api_list)
elif config.protocol == "http":
api_router = setup_http_router(api_list)
else:
raise Exception("unsupported protocol")
app.include_router(api_router)
if not init_engine_pool(config):

@ -8,7 +8,9 @@ port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
# protocol = ['websocket', 'http'] (only one can be selected).
# http only support offline engine type.
protocol: 'http'
engine_list: ['asr_python', 'tts_python', 'cls_python']
@ -48,6 +50,24 @@ asr_inference:
summary: True # False -> do not show predictor config
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
################################### TTS #########################################
################### speech task: tts; engine_type: python #######################
tts_python:

@ -0,0 +1,51 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8091
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
frame_duration_ms: 80
shift_ms: 40
sample_rate: 16000
sample_width: 2
vad_conf:
aggressiveness: 2
sample_rate: 16000
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,355 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import time
from typing import Optional
import pickle
import numpy as np
from numpy import float32
import soundfile
import paddle
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['ASREngine']
pretrained_models = {
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'd5e076217cf60486519f72c217d21b9b',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
}
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
pass
def _init_from_path(self,
model_type: str='wenetspeech',
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
"""
if cfg_path is None or am_model is None or am_params is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.am_model = os.path.join(res_path,
pretrained_models[tag]['model'])
self.am_params = os.path.join(res_path,
pretrained_models[tag]['params'])
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
self.am_params = os.path.abspath(am_params)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
with UpdateConfig(self.config):
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.vocab = self.config.vocab_filepath
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5']
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
raise Exception("wrong type")
else:
raise Exception("wrong type")
# AM predictor
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
# decoder
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
cfg = self.config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# init state box
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio
"""
self.decoder.reset_decoder(batch_size=1)
# init state box, for new audio request
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
"""decode one chunk
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
model_type (str): online model type
Returns:
[type]: [description]
"""
if "deepspeech2online" in model_type :
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
raise Exception("invalid model name")
else:
raise Exception("invalid model name")
def _pcm16to32(self, audio):
"""pcm int16 to float32
Args:
audio(numpy.array): numpy.int16
Returns:
audio(numpy.array): numpy.float32
"""
if audio.dtype == np.int16:
audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
return audio
def extract_feat(self, samples, sample_rate):
"""extract feat
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
# pcm16 -> pcm 32
samples = self._pcm16to32(samples)
# read audio
speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum
if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
x_chunk = audio.numpy()
x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens
class ASREngine(BaseEngine):
"""ASR server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(ASREngine, self).__init__()
def init(self, config: dict) -> bool:
"""init engine resource
Args:
config_file (str): config file
Returns:
bool: init failed or success
"""
self.input = None
self.output = ""
self.executor = ASRServerExecutor()
self.config = config
self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
logger.info("Initialize ASR server engine successfully.")
return True
def preprocess(self, samples, sample_rate):
"""preprocess
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
return x_chunk, x_chunk_lens
def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
"""run online engine
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
decoder_chunk_size(int)
"""
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
def postprocess(self):
"""postprocess
"""
return self.output
def reset(self):
"""reset engine decoder and inference state
"""
self.executor.reset_decoder_and_chunk()
self.output = ""

@ -25,6 +25,9 @@ class EngineFactory(object):
elif engine_name == 'asr' and engine_type == 'python':
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online':
from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'tts' and engine_type == 'inference':
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
return TTSEngine()

@ -0,0 +1,161 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
record wave from the mic
"""
import asyncio
import json
import logging
import threading
import wave
from signal import SIGINT
from signal import SIGTERM
import pyaudio
import websockets
class ASRAudioHandler(threading.Thread):
def __init__(self, url="127.0.0.1", port=8091):
threading.Thread.__init__(self)
self.url = url
self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
self.fileName = "./output.wav"
self.chunk = 5120
self.format = pyaudio.paInt16
self.channels = 1
self.rate = 16000
self._running = True
self._frames = []
self.data_backup = []
def startrecord(self):
"""
start a new thread to record wave
"""
threading._start_new_thread(self.recording, ())
def recording(self):
"""
recording wave
"""
self._running = True
self._frames = []
p = pyaudio.PyAudio()
stream = p.open(
format=self.format,
channels=self.channels,
rate=self.rate,
input=True,
frames_per_buffer=self.chunk)
while (self._running):
data = stream.read(self.chunk)
self._frames.append(data)
self.data_backup.append(data)
stream.stop_stream()
stream.close()
p.terminate()
def save(self):
"""
save wave data
"""
p = pyaudio.PyAudio()
wf = wave.open(self.fileName, 'wb')
wf.setnchannels(self.channels)
wf.setsampwidth(p.get_sample_size(self.format))
wf.setframerate(self.rate)
wf.writeframes(b''.join(self.data_backup))
wf.close()
p.terminate()
def stoprecord(self):
"""
stop recording
"""
self._running = False
async def run(self):
aa = input("是否开始录音? (y/n)")
if aa.strip() == "y":
self.startrecord()
logging.info("*" * 10 + "开始录音,请输入语音")
async with websockets.connect(self.url) as ws:
# 发送开始指令
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "start",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
# send bytes data
logging.info("结束录音请: Ctrl + c。继续请按回车。")
try:
while True:
while len(self._frames) > 0:
await ws.send(self._frames.pop(0))
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
except asyncio.CancelledError:
# quit
# send finished
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "end",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
self.stoprecord()
logging.info("*" * 10 + "录音结束")
self.save()
elif aa.strip() == "n":
exit()
else:
print("无效输入!")
exit()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(handler.run())
for signal in [SIGINT, SIGTERM]:
loop.add_signal_handler(signal, main_task.cancel)
try:
loop.run_until_complete(main_task)
finally:
loop.close()
logging.info("asr websocket client finished")

@ -0,0 +1,115 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import argparse
import asyncio
import json
import logging
import numpy as np
import soundfile
import websockets
class ASRAudioHandler:
def __init__(self, url="127.0.0.1", port=8090):
self.url = url
self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
else:
padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# 读取音频
# self.read_wave()
# 发送 websocket 的 handshake 协议头
async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头
# 发送开始指令
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "start",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
# send chunk audio data to engine
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
# finished
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "end",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logging.info("receive msg={}".format(msg))
def main(args):
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--wavfile",
action="store",
help="wav file path ",
default="./16_audio.wav")
args = parser.parse_args()
main(args)

@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, bytes, timestamp, duration):
self.bytes = bytes
self.timestamp = timestamp
self.duration = duration
class ChunkBuffer(object):
def __init__(self,
frame_duration_ms=80,
shift_ms=40,
sample_rate=16000,
sample_width=2):
self.sample_rate = sample_rate
self.frame_duration_ms = frame_duration_ms
self.shift_ms = shift_ms
self.remained_audio = b''
self.sample_width = sample_width # int16 = 2; float32 = 4
def frame_generator(self, audio):
"""Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and
the sample rate.
Yields Frames of the requested duration.
"""
audio = self.remained_audio + audio
self.remained_audio = b''
n = int(self.sample_rate *
(self.frame_duration_ms / 1000.0) * self.sample_width)
shift_n = int(self.sample_rate *
(self.shift_ms / 1000.0) * self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
while offset + n <= len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += shift_duration
offset += shift_n
self.remained_audio += audio[offset:]

@ -0,0 +1,78 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import webrtcvad
class VADAudio():
def __init__(self,
aggressiveness=2,
rate=16000,
frame_duration_ms=20,
sample_width=2,
padding_ms=200,
padding_ratio=0.9):
"""Initializes VAD with given aggressivenes and sets up internal queues"""
self.vad = webrtcvad.Vad(aggressiveness)
self.rate = rate
self.sample_width = sample_width
self.frame_duration_ms = frame_duration_ms
self._frame_length = int(rate * (frame_duration_ms / 1000.0) *
self.sample_width)
self._buffer_queue = collections.deque()
self.ring_buffer = collections.deque(maxlen=padding_ms //
frame_duration_ms)
self._ratio = padding_ratio
self.triggered = False
def add_audio(self, audio):
"""Adds new audio to internal queue"""
for x in audio:
self._buffer_queue.append(x)
def frame_generator(self):
"""Generator that yields audio frames of frame_duration_ms"""
while len(self._buffer_queue) > self._frame_length:
frame = bytearray()
for _ in range(self._frame_length):
frame.append(self._buffer_queue.popleft())
yield bytes(frame)
def vad_collector(self):
"""Generator that yields series of consecutive audio frames comprising each utterence, separated by yielding a single None.
Determines voice activity by ratio of frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered.
Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
|---utterence---| |---utterence---|
"""
for frame in self.frame_generator():
is_speech = self.vad.is_speech(frame, self.rate)
if not self.triggered:
self.ring_buffer.append((frame, is_speech))
num_voiced = len(
[f for f, speech in self.ring_buffer if speech])
if num_voiced > self._ratio * self.ring_buffer.maxlen:
self.triggered = True
for f, s in self.ring_buffer:
yield f
self.ring_buffer.clear()
else:
yield frame
self.ring_buffer.append((frame, is_speech))
num_unvoiced = len(
[f for f, speech in self.ring_buffer if not speech])
if num_unvoiced > self._ratio * self.ring_buffer.maxlen:
self.triggered = False
yield None
self.ring_buffer.clear()

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,38 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from fastapi import APIRouter
from paddlespeech.server.ws.asr_socket import router as asr_router
_router = APIRouter()
def setup_router(api_list: List):
"""setup router for fastapi
Args:
api_list (List): [asr, tts]
Returns:
APIRouter
"""
for api_name in api_list:
if api_name == 'asr':
_router.include_router(asr_router)
elif api_name == 'tts':
pass
else:
pass
return _router

@ -0,0 +1,100 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio
router = APIRouter()
@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
# init vad
vad_conf = asr_engine.config.vad_conf
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
try:
while True:
# careful here, changed the source code from starlette.websockets
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)
if "text" in message:
message = json.loads(message["text"])
if 'signal' not in message:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at begining here
await websocket.send_json(resp)
elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset single engine for an new connection
asr_engine.reset()
resp = {"status": "ok", "signal": "finished"}
await websocket.send_json(resp)
break
else:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
elif "bytes" in message:
message = message["bytes"]
# vad for input bytes audio
vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector()
if f is not None)
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
frames = chunk_buffer.frame_generator(message)
for frame in frames:
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
pass
Loading…
Cancel
Save