Merge pull request #1710 from Honei/deepspeech_server

[asr][websocket]fix the ws send bug, cache buffer,  text=doc
pull/1716/head
Hui Zhang 2 years ago committed by GitHub
commit 0cde9f87ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,4 +7,4 @@ paddlespeech asr --input ./zh.wav
# asr + punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc

@ -79,7 +79,6 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode
result_transcripts = self.model.decode(
xs,
@ -129,6 +128,7 @@ if __name__ == "__main__":
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:

@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8091
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']

@ -15,8 +15,10 @@
# -*- coding: UTF-8 -*-
import argparse
import asyncio
import codecs
import json
import logging
import os
import numpy as np
import soundfile
@ -32,34 +34,30 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_stride
start = i * chunk_size
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# 读取音频
# self.read_wave()
# 发送 websocket 的 handshake 协议头
async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头
# 发送开始指令
audio_info = json.dumps(
{
"name": "test.wav",
@ -77,8 +75,10 @@ class ASRAudioHandler:
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
result = msg
# finished
audio_info = json.dumps(
{
@ -91,16 +91,35 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
return result
def main(args):
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished")
# support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__":
@ -110,6 +129,8 @@ if __name__ == "__main__":
action="store",
help="wav file path ",
default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args()
main(args)

@ -24,15 +24,38 @@ class Frame(object):
class ChunkBuffer(object):
def __init__(self,
frame_duration_ms=80,
shift_ms=40,
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=16000,
sample_width=2):
self.sample_rate = sample_rate
self.frame_duration_ms = frame_duration_ms
"""audio sample data point buffer
Args:
window_n (int, optional): decode window frame length. Defaults to 7 frame.
shift_n (int, optional): decode shift frame length. Defaults to 4 frame.
window_ms (int, optional): frame length, ms. Defaults to 20 ms.
shift_ms (int, optional): shift length, ms. Defaults to 10 ms.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
sample_width (int, optional): sample point bytes. Defaults to 2 bytes.
"""
self.window_n = window_n
self.shift_n = shift_n
self.window_ms = window_ms
self.shift_ms = shift_ms
self.remained_audio = b''
self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
self.window_bytes = int(self.window_sec * self.sample_rate *
self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width)
def frame_generator(self, audio):
"""Generates audio frames from PCM audio data.
@ -43,17 +66,13 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio
self.remained_audio = b''
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
while offset + n <= len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += shift_duration
offset += shift_n
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)
timestamp += self.shift_sec
offset += self.shift_bytes
self.remained_audio += audio[offset:]

@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
# init vad
@ -75,11 +79,6 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message:
message = message["bytes"]
# vad for input bytes audio
vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector()
if f is not None)
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""

@ -19,11 +19,11 @@ A few sklearn functions are modified in this script as per requirement.
"""
import argparse
import warnings
from distutils.util import strtobool
import numpy as np
import scipy
import sklearn
from distutils.util import strtobool
from scipy import sparse
from scipy.sparse.csgraph import connected_components
from scipy.sparse.csgraph import laplacian as csgraph_laplacian

@ -26,9 +26,9 @@ import argparse
import os
import re
import subprocess
from distutils.util import strtobool
import numpy as np
from distutils.util import strtobool
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")

Loading…
Cancel
Save