add connection stability, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent 68731c61f4
commit 05a8a4b5fc

@ -83,8 +83,10 @@ pretrained_models = {
class PaddleASRConnectionHanddler: class PaddleASRConnectionHanddler:
def __init__(self, asr_engine): def __init__(self, asr_engine):
super().__init__() super().__init__()
logger.info("create an paddle asr connection handler to process the websocket connection")
self.config = asr_engine.config self.config = asr_engine.config
self.model_config = asr_engine.executor.config self.model_config = asr_engine.executor.config
self.model = asr_engine.executor.model
self.asr_engine = asr_engine self.asr_engine = asr_engine
self.init() self.init()
@ -149,6 +151,10 @@ class PaddleASRConnectionHanddler:
assert(len(self.cached_feat.shape) == 3) assert(len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
num_frames = x_chunk.shape[1] num_frames = x_chunk.shape[1]
self.num_frames += num_frames self.num_frames += num_frames
self.remained_wav = self.remained_wav[self.n_shift * num_frames:] self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
@ -165,16 +171,17 @@ class PaddleASRConnectionHanddler:
self.subsampling_cache = None self.subsampling_cache = None
self.elayers_output_cache = None self.elayers_output_cache = None
self.conformer_cnn_cache = None self.conformer_cnn_cache = None
self.encoder_outs_ = None self.encoder_out = None
self.cached_feat = None self.cached_feat = None
self.remained_wav = None self.remained_wav = None
self.offset = 0 self.offset = 0
self.num_samples = 0 self.num_samples = 0
self.device = None
self.hyps = []
self.num_frames = 0 self.num_frames = 0
self.chunk_num = 0 self.chunk_num = 0
self.global_frame_offset = 0 self.global_frame_offset = 0
self.result = [] self.result_transcripts = ['']
def decode(self, is_finished=False): def decode(self, is_finished=False):
if "deepspeech2online" in self.model_type: if "deepspeech2online" in self.model_type:
@ -187,7 +194,6 @@ class PaddleASRConnectionHanddler:
self.advance_decoding(is_finished) self.advance_decoding(is_finished)
self.update_result() self.update_result()
return self.result_transcripts[0]
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
else: else:
@ -203,14 +209,24 @@ class PaddleASRConnectionHanddler:
subsampling = self.model.encoder.embed.subsampling_rate subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1 context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size stride = subsampling * decoding_chunk_size
cached_feature_num = context - subsampling # processed chunk feature cached for next chunk
# decoding window for model # decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context decoding_window = (decoding_chunk_size - 1) * subsampling + context
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1] num_frames = self.cached_feat.shape[1]
logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")
# the cached feat must be larger decoding_window # the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished: if num_frames < decoding_window and not is_finished:
logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data")
return None, None
if num_frames < context:
logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward")
return None, None return None, None
logger.info("start to do model forward") logger.info("start to do model forward")
@ -242,14 +258,18 @@ class PaddleASRConnectionHanddler:
# update the offset # update the offset
self.offset += y.shape[1] self.offset += y.shape[1]
logger.info(f"output size: {len(outputs)}")
ys = paddle.cat(outputs, 1) ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) if self.encoder_out is None:
masks = masks.unsqueeze(1) self.encoder_out = ys
else:
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
# masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
# masks = masks.unsqueeze(1)
# get the ctc probs # get the ctc probs
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0) ctc_probs = ctc_probs.squeeze(0)
# self.searcher.search(xs, ctc_probs, xs.place)
self.searcher.search(None, ctc_probs, self.cached_feat.place) self.searcher.search(None, ctc_probs, self.cached_feat.place)
@ -260,7 +280,8 @@ class PaddleASRConnectionHanddler:
self.cached_feat = None self.cached_feat = None
else: else:
assert self.cached_feat.shape[0] == 1 assert self.cached_feat.shape[0] == 1
self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) assert end >= cached_feature_num
self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0)
assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
# ys for rescoring # ys for rescoring
@ -274,9 +295,75 @@ class PaddleASRConnectionHanddler:
] ]
self.result_tokenids = [hyp for hyp in hyps] self.result_tokenids = [hyp for hyp in hyps]
def get_result(self):
if len(self.result_transcripts) > 0:
return self.result_transcripts[0]
else:
return ''
def rescoring(self): def rescoring(self):
pass logger.info("rescoring the final result")
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
return
self.searcher.finalize_search()
self.update_result()
beam_size = self.ctc_decode_config.beam_size
hyps = self.searcher.get_hyps()
if hyps is None or len(hyps) == 0:
return
# assert len(hyps) == beam_size
paddle.save(self.encoder_out, "encoder.out")
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp is empty
if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor(
hyp_content, place=self.device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain)
score += hyp[1] * self.ctc_decode_config.ctc_weight
if score > best_score:
best_score = score
best_index = i
# update the one best result
logger.info(f"best index: {best_index}")
self.hyps = [hyps[best_index][0]]
self.update_result()
# return hyps[best_index][0]

@ -110,6 +110,11 @@ class CTCPrefixBeamSearch:
return [self.hyps[0][0]] return [self.hyps[0][0]]
def get_hyps(self): def get_hyps(self):
"""Return the search hyps
Returns:
list: return the search hyps
"""
return self.hyps return self.hyps
def reset(self): def reset(self):
@ -117,3 +122,8 @@ class CTCPrefixBeamSearch:
""" """
self.cur_hyps = None self.cur_hyps = None
self.hyps = None self.hyps = None
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""
pass

@ -13,16 +13,15 @@
# limitations under the License. # limitations under the License.
import json import json
import numpy as np
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import WebSocket from fastapi import WebSocket
from fastapi import WebSocketDisconnect from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio from paddlespeech.server.utils.vad import VADAudio
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
router = APIRouter() router = APIRouter()
@ -73,13 +72,17 @@ async def websocket_endpoint(websocket: WebSocket):
connection_handler = PaddleASRConnectionHanddler(asr_engine) connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp) await websocket.send_json(resp)
elif message['signal'] == 'end': elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset single engine for an new connection # reset single engine for an new connection
asr_results = connection_handler.decode(is_finished=True) connection_handler.decode(is_finished=True)
connection_handler.rescoring()
asr_results = connection_handler.get_result()
connection_handler.reset() connection_handler.reset()
asr_engine.reset()
resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results} resp = {
"status": "ok",
"signal": "finished",
'asr_results': asr_results
}
await websocket.send_json(resp) await websocket.send_json(resp)
break break
else: else:
@ -87,25 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_json(resp) await websocket.send_json(resp)
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
connection_handler.extract_feat(message)
asr_results = connection_handler.decode(is_finished=False)
# connection_handler.
# frames = chunk_buffer.frame_generator(message)
# for frame in frames:
# # get the pcm data from the bytes
# samples = np.frombuffer(frame.bytes, dtype=np.int16)
# sample_rate = asr_engine.config.sample_rate
# x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
# sample_rate)
# asr_engine.run(x_chunk, x_chunk_lens)
# asr_results = asr_engine.postprocess()
# # connection accept the sample data frame by frame connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
asr_results = connection_handler.get_result()
# asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results} resp = {'asr_results': asr_results}
print("\n") print("\n")
await websocket.send_json(resp) await websocket.send_json(resp)

Loading…
Cancel
Save