fix the ws send bug, cache buffer, text=doc

pull/1710/head
xiongxinlei 3 years ago
parent 23a6534119
commit 89b102a7dd

@ -1,10 +1,11 @@
#!/bin/bash #!/bin/bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav # wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# asr # asr
paddlespeech asr --input ./zh.wav export CUDA_VISIBLE_DEVICES=0
paddlespeech asr --input audio/119994.wav -v
# asr + punc # asr + punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc # paddlespeech asr --input ./zh.wav | paddlespeech text --task punc

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: 0.1 dither: 0.0
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugment # these three processes are a.k.a. SpecAugment

@ -3,9 +3,9 @@ decode_batch_size: 128
error_rate_type: cer error_rate_type: cer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. decoding_chunk_size: 1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk. # <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set. # >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here. # 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False. simulate_streaming: True # simulate streaming inference. Defaults to False.

@ -3,12 +3,12 @@ source path.sh
set -e set -e
gpus=0,1,2,3 gpus=0,1,2,3
stage=0 stage=5
stop_stage=50 stop_stage=5
conf_path=conf/conformer.yaml conf_path=conf/chunk_conformer.yaml
decode_conf_path=conf/tuning/decode.yaml decode_conf_path=conf/tuning/decode.yaml
avg_num=20 avg_num=20
audio_file=data/demo_01_03.wav audio_file=audio/zh.wav
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -44,7 +44,7 @@ fi
# Optionally, you can add LM and test it with runtime. # Optionally, you can add LM and test it with runtime.
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# test a single .wav file # test a single .wav file
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/chunk_conformer/checkpoints/multi_cn ${audio_file} || exit -1
fi fi
# Not supported yet!!! # Not supported yet!!!

@ -14,3 +14,6 @@
import _locale import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -79,7 +79,6 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode decode_config = self.config.decode
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
xs, xs,
@ -129,9 +128,12 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)
if args.config: if args.config:
print(f"load config: {args.config}")
config.merge_from_file(args.config) config.merge_from_file(args.config)
if args.decode_cfg: if args.decode_cfg:
print(f"load decode cfg: {args.decode_cfg}")
decode_confs = CfgNode(new_allowed=True) decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg) decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs config.decode = decode_confs

@ -4,7 +4,7 @@
# SERVER SETTING # # SERVER SETTING #
################################################################################# #################################################################################
host: 0.0.0.0 host: 0.0.0.0
port: 8091 port: 8096
# The task format in the engine_list is: <speech task>_<engine type> # The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online'] # task choices = ['asr_online', 'tts_online']

@ -17,7 +17,8 @@ import argparse
import asyncio import asyncio
import json import json
import logging import logging
import os
import codecs
import numpy as np import numpy as np
import soundfile import soundfile
import websockets import websockets
@ -32,22 +33,23 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples) x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0: if x_len % chunk_size != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride padding_len_x = chunk_size - x_len % chunk_size
else: else:
padding_len_x = 0 padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype) padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0) padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1 assert ( x_len + padding_len_x ) % chunk_size == 0
num_chunk = (x_len + padding_len_x ) / chunk_size
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_stride start = i * chunk_size
end = start + chunk_size end = start + chunk_size
x_chunk = padded_x[start:end] x_chunk = padded_x[start:end]
yield x_chunk yield x_chunk
@ -77,8 +79,10 @@ class ASRAudioHandler:
for chunk_data in self.read_wave(wavfile_path): for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes()) await ws.send(chunk_data.tobytes())
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
result = msg
# finished # finished
audio_info = json.dumps( audio_info = json.dumps(
{ {
@ -91,16 +95,36 @@ class ASRAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
return result
def main(args): def main(args):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091) handler = ASRAudioHandler("127.0.0.1", 8096)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished") # support processing a single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")
# support processing a batch of audio files from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__": if __name__ == "__main__":
@ -110,6 +134,8 @@ if __name__ == "__main__":
action="store", action="store",
help="wav file path ", help="wav file path ",
default="./16_audio.wav") default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

@ -24,15 +24,25 @@ class Frame(object):
class ChunkBuffer(object): class ChunkBuffer(object):
def __init__(self, def __init__(self,
frame_duration_ms=80, window_n=7, # frame
shift_ms=40, shift_n=4, # frame
window_ms=20, # ms
shift_ms=10, # ms
sample_rate=16000, sample_rate=16000,
sample_width=2): sample_width=2):
self.sample_rate = sample_rate self.window_n = window_n
self.frame_duration_ms = frame_duration_ms self.shift_n = shift_n
self.window_ms = window_ms
self.shift_ms = shift_ms self.shift_ms = shift_ms
self.remained_audio = b'' self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4 self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms + self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
self.window_bytes = int(self.window_sec * self.sample_rate * self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate * self.sample_width)
def frame_generator(self, audio): def frame_generator(self, audio):
"""Generates audio frames from PCM audio data. """Generates audio frames from PCM audio data.
@ -43,17 +53,12 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio audio = self.remained_audio + audio
self.remained_audio = b'' self.remained_audio = b''
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0 offset = 0
timestamp = 0.0 timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width while offset + self.window_bytes <= len(audio):
while offset + n <= len(audio): yield Frame(audio[offset:offset + self.window_bytes], timestamp, self.window_sec)
yield Frame(audio[offset:offset + n], timestamp, duration) timestamp += self.shift_sec
timestamp += shift_duration offset += self.shift_bytes
offset += shift_n
self.remained_audio += audio[offset:] self.remained_audio += audio[offset:]

@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer # init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer( chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'], sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width']) sample_width=chunk_buffer_conf['sample_width'])
# init vad # init vad
@ -76,9 +80,9 @@ async def websocket_endpoint(websocket: WebSocket):
message = message["bytes"] message = message["bytes"]
# vad for input bytes audio # vad for input bytes audio
vad.add_audio(message) # vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector() # message = b''.join(f for f in vad.vad_collector()
if f is not None) # if f is not None)
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
@ -89,9 +93,9 @@ async def websocket_endpoint(websocket: WebSocket):
sample_rate = asr_engine.config.sample_rate sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples, x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate) sample_rate)
print(x_chunk_lens)
asr_engine.run(x_chunk, x_chunk_lens) asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess() asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess() asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results} resp = {'asr_results': asr_results}

@ -168,7 +168,7 @@ class DevelopCommand(develop):
def run(self): def run(self):
develop.run(self) develop.run(self)
# must after develop.run, or pkg install by shell will not see # must after develop.run, or pkg install by shell will not see
self.execute(_post_install, (self.install_lib, ), msg="Post Install...") # self.execute(_post_install, (self.install_lib, ), msg="Post Install...")
class InstallCommand(install): class InstallCommand(install):

Loading…
Cancel
Save