fix the ws send bug, cache buffer, text=doc

pull/1710/head
xiongxinlei 3 years ago
parent 23a6534119
commit 89b102a7dd

@ -1,10 +1,11 @@
#!/bin/bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# asr
paddlespeech asr --input ./zh.wav
export CUDA_VISIBLE_DEVICES=0
paddlespeech asr --input audio/119994.wav -v
# asr + punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
# paddlespeech asr --input ./zh.wav | paddlespeech text --task punc

@ -5,7 +5,7 @@ process:
n_mels: 80
n_shift: 160
win_length: 400
dither: 0.1
dither: 0.0
- type: cmvn_json
cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument

@ -3,9 +3,9 @@ decode_batch_size: 128
error_rate_type: cer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
decoding_chunk_size: 1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
simulate_streaming: True # simulate streaming inference. Defaults to False.

@ -3,12 +3,12 @@ source path.sh
set -e
gpus=0,1,2,3
stage=0
stop_stage=50
conf_path=conf/conformer.yaml
stage=5
stop_stage=5
conf_path=conf/chunk_conformer.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=20
audio_file=data/demo_01_03.wav
audio_file=audio/zh.wav
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -44,7 +44,7 @@ fi
# Optionally, you can add LM and test it with runtime.
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/chunk_conformer/checkpoints/multi_cn ${audio_file} || exit -1
fi
# Not supported at now!!!

@ -14,3 +14,6 @@
import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -79,7 +79,6 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode
result_transcripts = self.model.decode(
xs,
@ -129,9 +128,12 @@ if __name__ == "__main__":
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
print(f"load config: {args.config}")
config.merge_from_file(args.config)
if args.decode_cfg:
print(f"load decode cfg: {args.decode_cfg}")
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs

@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8091
port: 8096
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']

@ -17,7 +17,8 @@ import argparse
import asyncio
import json
import logging
import os
import codecs
import numpy as np
import soundfile
import websockets
@ -32,22 +33,23 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
assert ( x_len + padding_len_x ) % chunk_size == 0
num_chunk = (x_len + padding_len_x ) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_stride
start = i * chunk_size
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk
@ -77,8 +79,10 @@ class ASRAudioHandler:
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
result = msg
# finished
audio_info = json.dumps(
{
@ -91,16 +95,36 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
return result
def main(args):
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
handler = ASRAudioHandler("127.0.0.1", 8096)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished")
# support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__":
@ -110,6 +134,8 @@ if __name__ == "__main__":
action="store",
help="wav file path ",
default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args()
main(args)

@ -24,15 +24,25 @@ class Frame(object):
class ChunkBuffer(object):
def __init__(self,
frame_duration_ms=80,
shift_ms=40,
window_n=7, # frame
shift_n=4, # frame
window_ms=20, # ms
shift_ms=10, # ms
sample_rate=16000,
sample_width=2):
self.sample_rate = sample_rate
self.frame_duration_ms = frame_duration_ms
self.window_n = window_n
self.shift_n = shift_n
self.window_ms = window_ms
self.shift_ms = shift_ms
self.remained_audio = b''
self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms + self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
self.window_bytes = int(self.window_sec * self.sample_rate * self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate * self.sample_width)
def frame_generator(self, audio):
"""Generates audio frames from PCM audio data.
@ -43,17 +53,12 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio
self.remained_audio = b''
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
while offset + n <= len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += shift_duration
offset += shift_n
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp, self.window_sec)
timestamp += self.shift_sec
offset += self.shift_bytes
self.remained_audio += audio[offset:]

@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
# init vad
@ -76,9 +80,9 @@ async def websocket_endpoint(websocket: WebSocket):
message = message["bytes"]
# vad for input bytes audio
vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector()
if f is not None)
# vad.add_audio(message)
# message = b''.join(f for f in vad.vad_collector()
# if f is not None)
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
@ -89,9 +93,9 @@ async def websocket_endpoint(websocket: WebSocket):
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
print(x_chunk_lens)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}

@ -168,7 +168,7 @@ class DevelopCommand(develop):
def run(self):
develop.run(self)
# must after develop.run, or pkg install by shell will not see
self.execute(_post_install, (self.install_lib, ), msg="Post Install...")
# self.execute(_post_install, (self.install_lib, ), msg="Post Install...")
class InstallCommand(install):

Loading…
Cancel
Save