add asr websocket server note, test=doc

pull/1710/head
xiongxinlei 3 years ago
parent efc269b75f
commit 3ce4301665

@ -0,0 +1,10 @@
#!/bin/bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# asr
paddlespeech asr --input ./zh.wav
# asr + punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc

@ -1,12 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved. # Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng) # Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse import argparse
from flask import Flask, render_template from flask import Flask
from flask import render_template
parser = argparse.ArgumentParser(description='training your network') parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--port', default=19999, type=int, help='port id') parser.add_argument('--port', default=19999, type=int, help='port id')
@ -14,9 +13,11 @@ args = parser.parse_args()
app = Flask(__name__) app = Flask(__name__)
@app.route('/') @app.route('/')
def index(): def index():
return render_template('index.html') return render_template('index.html')
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=args.port, debug=True) app.run(host='0.0.0.0', port=args.port, debug=True)

@ -15,10 +15,11 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
import argparse import argparse
import asyncio import asyncio
import codecs
import json import json
import logging import logging
import os import os
import codecs
import numpy as np import numpy as np
import soundfile import soundfile
import websockets import websockets
@ -37,15 +38,15 @@ class ASRAudioHandler:
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0: if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size padding_len_x = chunk_size - x_len % chunk_size
else: else:
padding_len_x = 0 padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype) padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0) padded_x = np.concatenate([samples, padding], axis=0)
assert ( x_len + padding_len_x ) % chunk_size == 0 assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x ) / chunk_size num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
for i in range(0, num_chunk): for i in range(0, num_chunk):
@ -56,12 +57,7 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str): async def run(self, wavfile_path: str):
logging.info("send a message to the server") logging.info("send a message to the server")
# 读取音频
# self.read_wave()
# 发送 websocket 的 handshake 协议头
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头
# 发送开始指令
audio_info = json.dumps( audio_info = json.dumps(
{ {
"name": "test.wav", "name": "test.wav",
@ -98,7 +94,6 @@ class ASRAudioHandler:
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
return result return result

@ -24,12 +24,22 @@ class Frame(object):
class ChunkBuffer(object): class ChunkBuffer(object):
def __init__(self, def __init__(self,
window_n=7, # frame window_n=7,
shift_n=4, # frame shift_n=4,
window_ms=20, # ms window_ms=20,
shift_ms=10, # ms shift_ms=10,
sample_rate=16000, sample_rate=16000,
sample_width=2): sample_width=2):
"""audio sample data point buffer
Args:
window_n (int, optional): decode window frame length. Defaults to 7 frame.
shift_n (int, optional): decode shift frame length. Defaults to 4 frame.
window_ms (int, optional): frame length, ms. Defaults to 20 ms.
shift_ms (int, optional): shift length, ms. Defaults to 10 ms.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
sample_width (int, optional): sample point bytes. Defaults to 2 bytes.
"""
self.window_n = window_n self.window_n = window_n
self.shift_n = shift_n self.shift_n = shift_n
self.window_ms = window_ms self.window_ms = window_ms
@ -38,11 +48,14 @@ class ChunkBuffer(object):
self.sample_width = sample_width # int16 = 2; float32 = 4 self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b'' self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms + self.window_ms) / 1000.0 self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0) self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
self.window_bytes = int(self.window_sec * self.sample_rate * self.sample_width) self.window_bytes = int(self.window_sec * self.sample_rate *
self.shift_bytes = int(self.shift_sec * self.sample_rate * self.sample_width) self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width)
def frame_generator(self, audio): def frame_generator(self, audio):
"""Generates audio frames from PCM audio data. """Generates audio frames from PCM audio data.
@ -57,7 +70,8 @@ class ChunkBuffer(object):
timestamp = 0.0 timestamp = 0.0
while offset + self.window_bytes <= len(audio): while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp, self.window_sec) yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)
timestamp += self.shift_sec timestamp += self.shift_sec
offset += self.shift_bytes offset += self.shift_bytes

@ -79,11 +79,6 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
# vad for input bytes audio
# vad.add_audio(message)
# message = b''.join(f for f in vad.vad_collector()
# if f is not None)
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
asr_results = "" asr_results = ""
@ -95,6 +90,7 @@ async def websocket_endpoint(websocket: WebSocket):
sample_rate) sample_rate)
asr_engine.run(x_chunk, x_chunk_lens) asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess() asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess() asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results} resp = {'asr_results': asr_results}

Loading…
Cancel
Save