add multi session result, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent 10e825d9b2
commit 68731c61f4

@ -185,9 +185,9 @@ class PaddleASRConnectionHanddler:
f"we will use the transformer like model : {self.model_type}" f"we will use the transformer like model : {self.model_type}"
) )
self.advance_decoding(is_finished) self.advance_decoding(is_finished)
# self.update_result() self.update_result()
# return self.result_transcripts[0] return self.result_transcripts[0]
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
else: else:
@ -225,22 +225,36 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk # we only process decoding_window frames for one chunk
left_frames = decoding_window left_frames = decoding_window
# logger.info(f"") # record the end for removing the processed feat
end = None end = None
for cur in range(0, num_frames - left_frames + 1, stride): for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames) end = min(cur + decoding_window, num_frames)
print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}")
self.chunk_num += 1 self.chunk_num += 1
# chunk_xs = self.cached_feat[:, cur:end, :] chunk_xs = self.cached_feat[:, cur:end, :]
# (y, self.subsampling_cache, self.elayers_output_cache, (y, self.subsampling_cache, self.elayers_output_cache,
# self.conformer_cnn_cache) = self.model.encoder.forward_chunk( self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
# chunk_xs, self.offset, required_cache_size, chunk_xs, self.offset, required_cache_size,
# self.subsampling_cache, self.elayers_output_cache, self.subsampling_cache, self.elayers_output_cache,
# self.conformer_cnn_cache) self.conformer_cnn_cache)
# outputs.append(y) outputs.append(y)
# update the offset # update the offset
# self.offset += y.shape[1] self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
# get the ctc probs
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# self.searcher.search(xs, ctc_probs, xs.place)
self.searcher.search(None, ctc_probs, self.cached_feat.place)
self.hyps = self.searcher.get_one_best_hyps()
# remove the processed feat # remove the processed feat
if end == num_frames: if end == num_frames:
self.cached_feat = None self.cached_feat = None
@ -248,19 +262,7 @@ class PaddleASRConnectionHanddler:
assert self.cached_feat.shape[0] == 1 assert self.cached_feat.shape[0] == 1
self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0)
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
# ys = paddle.cat(outputs, 1)
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
# masks = masks.unsqueeze(1)
# # get the ctc probs
# ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
# ctc_probs = ctc_probs.squeeze(0)
# # self.searcher.search(xs, ctc_probs, xs.place)
# self.searcher.search(None, ctc_probs, self.cached_feat.place)
# self.hyps = self.searcher.get_one_best_hyps()
# ys for rescoring # ys for rescoring
# return ys, masks # return ys, masks

@ -46,7 +46,7 @@ class CTCPrefixBeamSearch:
logger.info("start to ctc prefix search") logger.info("start to ctc prefix search")
# device = xs.place # device = xs.place
batch_size = xs.shape[0] batch_size = 1
beam_size = self.config.beam_size beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0] maxlen = ctc_probs.shape[0]

@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate. the sample rate.
Yields Frames of the requested duration. Yields Frames of the requested duration.
""" """
audio = self.remained_audio + audio audio = self.remained_audio + audio
self.remained_audio = b'' self.remained_audio = b''
offset = 0 offset = 0
timestamp = 0.0 timestamp = 0.0
while offset + self.window_bytes <= len(audio): while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp, yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec) self.window_sec)

@ -22,6 +22,7 @@ from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio from paddlespeech.server.utils.vad import VADAudio
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
router = APIRouter() router = APIRouter()
@ -33,6 +34,7 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
connection_handler = None
# init buffer # init buffer
# each websocket connection has its own chunk buffer # each websocket connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
@ -67,13 +69,17 @@ async def websocket_endpoint(websocket: WebSocket):
if message['signal'] == 'start': if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"} resp = {"status": "ok", "signal": "server_ready"}
# do something at beginning here # do something at beginning here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp) await websocket.send_json(resp)
elif message['signal'] == 'end': elif message['signal'] == 'end':
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
# reset single engine for a new connection # reset single engine for a new connection
asr_results = connection_handler.decode(is_finished=True)
connection_handler.reset()
asr_engine.reset() asr_engine.reset()
resp = {"status": "ok", "signal": "finished"} resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results}
await websocket.send_json(resp) await websocket.send_json(resp)
break break
else: else:
@ -81,23 +87,27 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_json(resp) await websocket.send_json(resp)
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
asr_results = "" asr_results = ""
frames = chunk_buffer.frame_generator(message) connection_handler.extract_feat(message)
for frame in frames: asr_results = connection_handler.decode(is_finished=False)
# get the pcm data from the bytes # connection_handler.
samples = np.frombuffer(frame.bytes, dtype=np.int16) # frames = chunk_buffer.frame_generator(message)
sample_rate = asr_engine.config.sample_rate # for frame in frames:
x_chunk, x_chunk_lens = asr_engine.preprocess(samples, # # get the pcm data from the bytes
sample_rate) # samples = np.frombuffer(frame.bytes, dtype=np.int16)
asr_engine.run(x_chunk, x_chunk_lens) # sample_rate = asr_engine.config.sample_rate
asr_results = asr_engine.postprocess() # x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
# sample_rate)
# asr_engine.run(x_chunk, x_chunk_lens)
# asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess() # # connection accept the sample data frame by frame
# asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results} resp = {'asr_results': asr_results}
print("\n")
await websocket.send_json(resp) await websocket.send_json(resp)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass

Loading…
Cancel
Save