Merge pull request #1710 from Honei/deepspeech_server

[asr][websocket] fix the ws send bug, cache buffer, text=doc
Hui Zhang authored 3 years ago, committed by GitHub
commit 0cde9f87ab

@@ -7,4 +7,4 @@ paddlespeech asr --input ./zh.wav
 # asr + punc
 paddlespeech asr --input ./zh.wav | paddlespeech text --task punc

@@ -79,7 +79,6 @@ class U2Infer():
         ilen = paddle.to_tensor(feat.shape[0])
         xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
         decode_config = self.config.decode
         result_transcripts = self.model.decode(
             xs,
@@ -129,6 +128,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     config = CfgNode(new_allowed=True)
     if args.config:
         config.merge_from_file(args.config)
     if args.decode_cfg:

@@ -4,7 +4,7 @@
 # SERVER SETTING                                                               #
 #################################################################################
 host: 0.0.0.0
-port: 8091
+port: 8090
 # The task format in the engin_list is: <speech task>_<engine type>
 # task choices = ['asr_online', 'tts_online']

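Note: the server port default changes from 8091 to 8090 here so this yaml agrees with the client default updated later in this commit. A quick sketch of reading this config with yacs, mirroring the CfgNode pattern used elsewhere in this diff; the file name "ws_application.yaml" is an assumption for illustration:

    from yacs.config import CfgNode

    # read the server yaml the same way the repo does elsewhere in this diff;
    # "ws_application.yaml" is a hypothetical file name for illustration
    config = CfgNode(new_allowed=True)
    config.merge_from_file("ws_application.yaml")
    assert config.port == 8090  # must match the client's ASRAudioHandler port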
@@ -15,8 +15,10 @@
 # -*- coding: UTF-8 -*-
 import argparse
 import asyncio
+import codecs
 import json
 import logging
+import os

 import numpy as np
 import soundfile
@@ -32,34 +34,30 @@ class ASRAudioHandler:
     def read_wave(self, wavfile_path: str):
         samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
         x_len = len(samples)
-        chunk_stride = 40 * 16  #40ms, sample_rate = 16kHz
+        # chunk_stride = 40 * 16  #40ms, sample_rate = 16kHz
         chunk_size = 80 * 16  #80ms, sample_rate = 16kHz
-        if (x_len - chunk_size) % chunk_stride != 0:
-            padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
+        if x_len % chunk_size != 0:
+            padding_len_x = chunk_size - x_len % chunk_size
         else:
             padding_len_x = 0
         padding = np.zeros((padding_len_x), dtype=samples.dtype)
         padded_x = np.concatenate([samples, padding], axis=0)
-        num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
+        assert (x_len + padding_len_x) % chunk_size == 0
+        num_chunk = (x_len + padding_len_x) / chunk_size
         num_chunk = int(num_chunk)
         for i in range(0, num_chunk):
-            start = i * chunk_stride
+            start = i * chunk_size
             end = start + chunk_size
             x_chunk = padded_x[start:end]
             yield x_chunk
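The rewritten read_wave drops the overlapping 40 ms stride and streams fixed, non-overlapping 80 ms chunks (1280 samples at 16 kHz), zero-padding the tail so every chunk is full. A self-contained check of that arithmetic, with a made-up audio length:

    import numpy as np

    chunk_size = 80 * 16                              # 80 ms at 16 kHz = 1280 samples
    x_len = 50000                                     # hypothetical audio length
    padding_len = (chunk_size - x_len % chunk_size) % chunk_size
    padded = np.zeros(x_len + padding_len, dtype='int16')
    num_chunk = (x_len + padding_len) // chunk_size   # 40 full chunks, no remainder
    assert num_chunk * chunk_size == len(padded)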
     async def run(self, wavfile_path: str):
         logging.info("send a message to the server")
-        # read the audio
-        # self.read_wave()
-        # send the websocket handshake request
         async with websockets.connect(self.url) as ws:
-            # the server has received the handshake header
-            # send the start instruction
             audio_info = json.dumps(
                 {
                     "name": "test.wav",
@@ -77,8 +75,10 @@ class ASRAudioHandler:
             for chunk_data in self.read_wave(wavfile_path):
                 await ws.send(chunk_data.tobytes())
                 msg = await ws.recv()
+                msg = json.loads(msg)
                 logging.info("receive msg={}".format(msg))
+            result = msg
             # finished
             audio_info = json.dumps(
                 {
@@ -91,16 +91,35 @@ class ASRAudioHandler:
                 separators=(',', ': '))
             await ws.send(audio_info)
             msg = await ws.recv()
+            msg = json.loads(msg)
             logging.info("receive msg={}".format(msg))
+            return result


 def main(args):
     logging.basicConfig(level=logging.INFO)
     logging.info("asr websocket client start")
-    handler = ASRAudioHandler("127.0.0.1", 8091)
+    handler = ASRAudioHandler("127.0.0.1", 8090)
     loop = asyncio.get_event_loop()
-    loop.run_until_complete(handler.run(args.wavfile))
-    logging.info("asr websocket client finished")
+
+    # support to process single audio file
+    if args.wavfile and os.path.exists(args.wavfile):
+        logging.info(f"start to process the wav file: {args.wavfile}")
+        result = loop.run_until_complete(handler.run(args.wavfile))
+        result = result["asr_results"]
+        logging.info(f"asr websocket client finished: {result}")
+
+    # support to process batch audios from wav.scp
+    if args.wavscp and os.path.exists(args.wavscp):
+        logging.info(f"start to process the wav.scp: {args.wavscp}")
+        with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
+                codecs.open("result.txt", 'w', encoding='utf-8') as w:
+            for line in f:
+                utt_name, utt_path = line.strip().split()
+                result = loop.run_until_complete(handler.run(utt_path))
+                result = result["asr_results"]
+                w.write(f"{utt_name} {result}\n")
 if __name__ == "__main__":
@@ -110,6 +129,8 @@ if __name__ == "__main__":
         action="store",
         help="wav file path",
         default="./16_audio.wav")
+    parser.add_argument(
+        "--wavscp", type=str, default=None, help="the path of wav.scp for batch decoding")
     args = parser.parse_args()
     main(args)
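Taken together, the client now speaks a three-step protocol: a JSON start message, a stream of raw int16 chunks (each answered with a JSON partial result), and a JSON end message answered with the final result. A sketch of one reply being parsed; only "name" and "asr_results" appear in this diff, so the reply content here is hypothetical:

    import json

    # hypothetical server reply for one audio chunk; "asr_results"
    # is the key main() reads above
    reply = '{"asr_results": "partial transcript"}'
    msg = json.loads(reply)
    print(msg["asr_results"])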

@@ -24,15 +24,38 @@ class Frame(object):

 class ChunkBuffer(object):
     def __init__(self,
-                 frame_duration_ms=80,
-                 shift_ms=40,
+                 window_n=7,
+                 shift_n=4,
+                 window_ms=20,
+                 shift_ms=10,
                  sample_rate=16000,
                  sample_width=2):
-        self.sample_rate = sample_rate
-        self.frame_duration_ms = frame_duration_ms
+        """Audio sample data point buffer.
+
+        Args:
+            window_n (int, optional): decode window length in frames. Defaults to 7 frames.
+            shift_n (int, optional): decode shift length in frames. Defaults to 4 frames.
+            window_ms (int, optional): frame length in ms. Defaults to 20 ms.
+            shift_ms (int, optional): frame shift in ms. Defaults to 10 ms.
+            sample_rate (int, optional): audio sample rate. Defaults to 16000.
+            sample_width (int, optional): bytes per sample point. Defaults to 2 bytes.
+        """
+        self.window_n = window_n
+        self.shift_n = shift_n
+        self.window_ms = window_ms
         self.shift_ms = shift_ms
-        self.remained_audio = b''
+        self.sample_rate = sample_rate
         self.sample_width = sample_width  # int16 = 2; float32 = 4
+        self.remained_audio = b''
+
+        self.window_sec = float((self.window_n - 1) * self.shift_ms +
+                                self.window_ms) / 1000.0
+        self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
+        self.window_bytes = int(self.window_sec * self.sample_rate *
+                                self.sample_width)
+        self.shift_bytes = int(self.shift_sec * self.sample_rate *
+                               self.sample_width)
     def frame_generator(self, audio):
         """Generates audio frames from PCM audio data.
@@ -43,17 +66,13 @@ class ChunkBuffer(object):
         audio = self.remained_audio + audio
         self.remained_audio = b''
-        n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
-                self.sample_width)
-        shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
-                      self.sample_width)
         offset = 0
         timestamp = 0.0
-        duration = (float(n) / self.sample_rate) / self.sample_width
-        shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
-        while offset + n <= len(audio):
-            yield Frame(audio[offset:offset + n], timestamp, duration)
-            timestamp += shift_duration
-            offset += shift_n
+        while offset + self.window_bytes <= len(audio):
+            yield Frame(audio[offset:offset + self.window_bytes], timestamp,
+                        self.window_sec)
+            timestamp += self.shift_sec
+            offset += self.shift_bytes
         self.remained_audio += audio[offset:]

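With the new defaults, the ChunkBuffer arithmetic gives an 80 ms decode window advancing by a 40 ms shift, i.e. 2560-byte windows stepping 1280 bytes for 16 kHz int16 audio, which lines up with the client's 80 ms chunks. A worked check of the formulas above:

    window_n, shift_n, window_ms, shift_ms = 7, 4, 20, 10
    sample_rate, sample_width = 16000, 2          # 16 kHz, int16

    window_sec = ((window_n - 1) * shift_ms + window_ms) / 1000.0   # 0.08 s
    shift_sec = shift_n * shift_ms / 1000.0                         # 0.04 s
    window_bytes = int(window_sec * sample_rate * sample_width)     # 2560
    shift_bytes = int(shift_sec * sample_rate * sample_width)       # 1280
    assert (window_bytes, shift_bytes) == (2560, 1280)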
@@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket):
             # init buffer
             chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
             chunk_buffer = ChunkBuffer(
+                window_n=7,
+                shift_n=4,
+                window_ms=20,
+                shift_ms=10,
                 sample_rate=chunk_buffer_conf['sample_rate'],
                 sample_width=chunk_buffer_conf['sample_width'])
             # init vad
@@ -75,11 +79,6 @@ async def websocket_endpoint(websocket: WebSocket):
         elif "bytes" in message:
             message = message["bytes"]
-            # vad for input bytes audio
-            vad.add_audio(message)
-            message = b''.join(f for f in vad.vad_collector()
-                               if f is not None)
             engine_pool = get_engine_pool()
             asr_engine = engine_pool['asr']
             asr_results = ""

@@ -19,11 +19,11 @@ A few sklearn functions are modified in this script as per requirement.
 """
 import argparse
 import warnings
+from distutils.util import strtobool

 import numpy as np
 import scipy
 import sklearn
-from distutils.util import strtobool
 from scipy import sparse
 from scipy.sparse.csgraph import connected_components
 from scipy.sparse.csgraph import laplacian as csgraph_laplacian

@@ -26,9 +26,9 @@ import argparse
 import os
 import re
 import subprocess
+from distutils.util import strtobool

 import numpy as np
-from distutils.util import strtobool

 FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
 SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
