fix speechx core dump when stop immediately after start

pull/1815/head
Hui Zhang 2 years ago
parent 2a661fcdb4
commit fc96130fdc

@ -21,6 +21,8 @@ from typing import Union
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
@ -28,8 +30,6 @@ from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor'] __all__ = ['CLSExecutor']

@ -325,7 +325,6 @@ if not hasattr(paddle.Tensor, 'type_as'):
setattr(paddle.static.Variable, 'type_as', type_as) setattr(paddle.static.Variable, 'type_as', type_as)
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
assert len(args) == 1 assert len(args) == 1
if isinstance(args[0], str): # dtype if isinstance(args[0], str): # dtype
@ -372,7 +371,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!") "register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist) setattr(paddle.Tensor, 'tolist', tolist)
setattr(paddle.static.Variable, 'tolist', tolist) setattr(paddle.static.Variable, 'tolist', tolist)
########### hack paddle.nn ############# ########### hack paddle.nn #############
from paddle.nn import Layer from paddle.nn import Layer

@ -521,4 +521,4 @@ class TextClientExecutor(BaseExecutor):
res = requests.post(url=url, data=json.dumps(request)) res = requests.post(url=url, data=json.dumps(request))
response_dict = res.json() response_dict = res.json()
punc_text = response_dict["result"]["punc_text"] punc_text = response_dict["result"]["punc_text"]
return punc_text return punc_text

@ -91,8 +91,7 @@ class ASRWsAudioHandler:
if url is None or port is None or endpoint is None: if url is None or port is None or endpoint is None:
self.url = None self.url = None
else: else:
self.url = "ws://" + self.url + ":" + str( self.url = "ws://" + self.url + ":" + str(self.port) + endpoint
self.port) + endpoint
self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port) self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port)
logger.info(f"endpoint: {self.url}") logger.info(f"endpoint: {self.url}")
@ -139,8 +138,7 @@ class ASRWsAudioHandler:
logging.info("send a message to the server") logging.info("send a message to the server")
if self.url is None: if self.url is None:
logger.error( logger.error("No asr server, please input valid ip and port")
"No asr server, please input valid ip and port")
return "" return ""
# 1. send websocket handshake protocal # 1. send websocket handshake protocal
@ -167,8 +165,7 @@ class ASRWsAudioHandler:
msg = json.loads(msg) msg = json.loads(msg)
if self.punc_server and len(msg["result"]) > 0: if self.punc_server and len(msg["result"]) > 0:
msg["result"] = self.punc_server.run( msg["result"] = self.punc_server.run(msg["result"])
msg["result"])
logger.info("client receive msg={}".format(msg)) logger.info("client receive msg={}".format(msg))
# 4. we must send finished signal to the server # 4. we must send finished signal to the server
@ -189,7 +186,7 @@ class ASRWsAudioHandler:
if self.punc_server: if self.punc_server:
msg["result"] = self.punc_server.run(msg["result"]) msg["result"] = self.punc_server.run(msg["result"])
logger.info("client final receive msg={}".format(msg)) logger.info("client final receive msg={}".format(msg))
result = msg result = msg

@ -48,6 +48,12 @@ void TLGDecoder::Reset() {
} }
std::string TLGDecoder::GetFinalBestPath() { std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}
decoder_->FinalizeDecoding(); decoder_->FinalizeDecoding();
kaldi::Lattice lat; kaldi::Lattice lat;
kaldi::LatticeWeight weight; kaldi::LatticeWeight weight;

@ -27,21 +27,22 @@ ConnectionHandler::ConnectionHandler(
: ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {} : ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {}
void ConnectionHandler::OnSpeechStart() { void ConnectionHandler::OnSpeechStart() {
LOG(INFO) << "Server: Recieved speech start signal, start reading speech";
got_start_tag_ = true;
json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
recognizer_ = std::make_shared<Recognizer>(recognizer_resource_); recognizer_ = std::make_shared<Recognizer>(recognizer_resource_);
// Start decoder thread // Start decoder thread
decode_thread_ = std::make_shared<std::thread>( decode_thread_ = std::make_shared<std::thread>(
&ConnectionHandler::DecodeThreadFunc, this); &ConnectionHandler::DecodeThreadFunc, this);
got_start_tag_ = true;
LOG(INFO) << "Server: Recieved speech start signal, start reading speech";
json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
} }
void ConnectionHandler::OnSpeechEnd() { void ConnectionHandler::OnSpeechEnd() {
LOG(INFO) << "Server: Recieved speech end signal"; LOG(INFO) << "Server: Recieved speech end signal";
CHECK(recognizer_ != nullptr); if (recognizer_ != nullptr) {
recognizer_->SetFinished(); recognizer_->SetFinished();
}
got_end_tag_ = true; got_end_tag_ = true;
} }
@ -158,7 +159,7 @@ void ConnectionHandler::operator()() {
} }
} }
LOG(INFO) << "Server: Read all pcm data, wait for decoding thread"; LOG(INFO) << "Server: finished to wait for decoding thread join.";
if (decode_thread_ != nullptr) { if (decode_thread_ != nullptr) {
decode_thread_->join(); decode_thread_->join();
} }

Loading…
Cancel
Save