add multi session result, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent 10e825d9b2
commit 68731c61f4

@@ -185,9 +185,9 @@ class PaddleASRConnectionHanddler:
f"we will use the transformer like model : {self.model_type}"
)
self.advance_decoding(is_finished)
# self.update_result()
self.update_result()
# return self.result_transcripts[0]
return self.result_transcripts[0]
except Exception as e:
logger.exception(e)
else:
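For context, a minimal sketch of what update_result conceptually produces at this point: the searcher's one-best hypothesis (a sequence of token ids) turned into the transcript string returned above. The names hyps_to_transcripts and vocab_list are illustrative, not the engine's real attributes.

def hyps_to_transcripts(hyps, vocab_list):
    """Map token-id hypotheses to strings, one transcript per hypothesis."""
    return ["".join(vocab_list[token_id] for token_id in hyp) for hyp in hyps]

# with a toy vocabulary, the one-best hypothesis [2, 3] becomes "你好"
print(hyps_to_transcripts([[2, 3]], ["<blank>", "<unk>", "你", "好"]))  # ['你好']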
@@ -225,21 +225,35 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
# logger.info(f"")
# record the end for removing the processed feat
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}")
self.chunk_num += 1
# chunk_xs = self.cached_feat[:, cur:end, :]
# (y, self.subsampling_cache, self.elayers_output_cache,
# self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
# chunk_xs, self.offset, required_cache_size,
# self.subsampling_cache, self.elayers_output_cache,
# self.conformer_cnn_cache)
# outputs.append(y)
chunk_xs = self.cached_feat[:, cur:end, :]
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
outputs.append(y)
# update the offset
# self.offset += y.shape[1]
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
# get the ctc probs
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# self.searcher.search(xs, ctc_probs, xs.place)
self.searcher.search(None, ctc_probs, self.cached_feat.place)
self.hyps = self.searcher.get_one_best_hyps()
# remove the processed feat
if end == num_frames:
@@ -248,18 +262,6 @@ class PaddleASRConnectionHanddler:
assert self.cached_feat.shape[0] == 1
self.cached_feat = self.cached_feat[0, end:, :].unsqueeze(0)
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
# ys = paddle.cat(outputs, 1)
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
# masks = masks.unsqueeze(1)
# # get the ctc probs
# ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
# ctc_probs = ctc_probs.squeeze(0)
# # self.searcher.search(xs, ctc_probs, xs.place)
# self.searcher.search(None, ctc_probs, self.cached_feat.place)
# self.hyps = self.searcher.get_one_best_hyps()
# ys for rescoring
# return ys, masks
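The chunked decoding loop above walks cached_feat with a sliding decoding window, feeding each slice to forward_chunk together with the caches from earlier slices, then trims the processed frames. A small, self-contained sketch of just the windowing arithmetic (the frame counts below are illustrative, not the engine's real configuration):

def chunk_windows(num_frames, decoding_window, stride):
    """Yield the (cur, end) slices the loop above takes out of cached_feat."""
    left_frames = decoding_window
    for cur in range(0, num_frames - left_frames + 1, stride):
        yield cur, min(cur + decoding_window, num_frames)

# e.g. 67 cached frames, a 16-frame window advanced 4 frames per step
print(list(chunk_windows(67, 16, 4)))
# first/last pairs: (0, 16), (4, 20), ..., (44, 60), (48, 64)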

@@ -46,7 +46,7 @@ class CTCPrefixBeamSearch:
logger.info("start to ctc prefix search")
# device = xs.place
batch_size = xs.shape[0]
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
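batch_size is pinned to 1 here because the streaming path now calls search(None, ctc_probs, ...): there is no xs tensor to read a batch dimension from, and the online handler decodes one utterance at a time. As a shape reference only (a greedy baseline over the same (maxlen, vocab_size) log-probs, not the prefix beam search this class implements):

import numpy as np

def ctc_greedy(ctc_log_probs, blank_id=0):
    """Argmax per frame, merge repeats, drop blanks."""
    best = np.argmax(ctc_log_probs, axis=1)  # (maxlen,)
    tokens, prev = [], None
    for t in best:
        if t != blank_id and t != prev:
            tokens.append(int(t))
        prev = t
    return tokens

print(ctc_greedy(np.log([[0.7, 0.2, 0.1],
                         [0.1, 0.8, 0.1],
                         [0.1, 0.8, 0.1],
                         [0.6, 0.2, 0.2],
                         [0.1, 0.1, 0.8]])))  # [1, 2]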

@@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate.
Yields Frames of the requested duration.
"""
audio = self.remained_audio + audio
self.remained_audio = b''
offset = 0
timestamp = 0.0
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)
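frame_generator now prepends whatever was left over from the previous call (remained_audio) before cutting fixed-size windows, so a partial frame at the end of one websocket message carries over to the next. A hedged, self-contained sketch of the same buffering idea with plain bytes (non-overlapping windows for simplicity; the real ChunkBuffer also tracks timestamps and a separate shift):

class TinyChunker:
    def __init__(self, window_bytes=8):
        self.window_bytes = window_bytes
        self.remained_audio = b''

    def frames(self, audio: bytes):
        # prepend the tail kept from the previous call, as the diff above does
        audio = self.remained_audio + audio
        self.remained_audio = b''
        offset = 0
        while offset + self.window_bytes <= len(audio):
            yield audio[offset:offset + self.window_bytes]
            offset += self.window_bytes
        self.remained_audio = audio[offset:]  # keep the partial frame for later

chunker = TinyChunker()
print(list(chunker.frames(b'0123456789')))  # [b'01234567'], tail b'89' kept
print(list(chunker.frames(b'abcdef')))      # [b'89abcdef']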

@@ -22,6 +22,7 @@ from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
router = APIRouter()
@@ -33,6 +34,7 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
connection_handler = None
# init buffer
# each websocket connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
@@ -67,13 +69,17 @@ async def websocket_endpoint(websocket: WebSocket):
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at the beginning here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp)
elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset the single engine for a new connection
asr_results = connection_handler.decode(is_finished=True)
connection_handler.reset()
asr_engine.reset()
resp = {"status": "ok", "signal": "finished"}
resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results}
await websocket.send_json(resp)
break
else:
@@ -81,23 +87,27 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_json(resp)
elif "bytes" in message:
message = message["bytes"]
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
frames = chunk_buffer.frame_generator(message)
for frame in frames:
# get the pcm data from the bytes
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
connection_handler.extract_feat(message)
asr_results = connection_handler.decode(is_finished=False)
# connection_handler.
# frames = chunk_buffer.frame_generator(message)
# for frame in frames:
# # get the pcm data from the bytes
# samples = np.frombuffer(frame.bytes, dtype=np.int16)
# sample_rate = asr_engine.config.sample_rate
# x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
# sample_rate)
# asr_engine.run(x_chunk, x_chunk_lens)
# asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}
# # connection accept the sample data frame by frame
# asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}
print("\n")
await websocket.send_json(resp)
except WebSocketDisconnect:
pass
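Taken together, the endpoint now expects a {"signal": "start"} handshake, then raw pcm bytes (each reply carrying the partial asr_results from connection_handler.decode), then a {"signal": "end"} message whose reply also carries the final asr_results. A hedged client sketch of that exchange (the URL, path, and chunk size are placeholders, not values confirmed by this diff):

import asyncio
import json

import websockets  # pip install websockets


async def stream_pcm(uri="ws://127.0.0.1:8090/ws/asr", pcm=b"\x00\x00" * 16000):
    async with websockets.connect(uri) as ws:
        # handshake: server answers {"status": "ok", "signal": "server_ready"}
        await ws.send(json.dumps({"signal": "start"}))
        print(await ws.recv())

        # stream raw pcm; every reply carries the partial asr_results so far
        chunk = 8000  # bytes per message, illustrative
        for i in range(0, len(pcm), chunk):
            await ws.send(pcm[i:i + chunk])
            print(await ws.recv())

        # finish: the reply now also carries the final asr_results
        await ws.send(json.dumps({"signal": "end"}))
        print(await ws.recv())


asyncio.run(stream_pcm())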
