Merge pull request #1815 from zh794390558/bugfix

[speechx] fix speechx core dump when stop immediately after start
pull/1823/head
Hui Zhang 2 years ago committed by GitHub
commit bb8785c6ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,6 +21,8 @@ from typing import Union
import numpy as np
import paddle
import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor
from ..log import logger
@ -28,8 +30,6 @@ from ..utils import cli_register
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor']

@ -325,7 +325,6 @@ if not hasattr(paddle.Tensor, 'type_as'):
setattr(paddle.static.Variable, 'type_as', type_as)
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
assert len(args) == 1
if isinstance(args[0], str): # dtype
@ -372,7 +371,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
setattr(paddle.static.Variable, 'tolist', tolist)
########### hack paddle.nn #############
from paddle.nn import Layer

@ -521,4 +521,4 @@ class TextClientExecutor(BaseExecutor):
res = requests.post(url=url, data=json.dumps(request))
response_dict = res.json()
punc_text = response_dict["result"]["punc_text"]
return punc_text
return punc_text

@ -91,8 +91,7 @@ class ASRWsAudioHandler:
if url is None or port is None or endpoint is None:
self.url = None
else:
self.url = "ws://" + self.url + ":" + str(
self.port) + endpoint
self.url = "ws://" + self.url + ":" + str(self.port) + endpoint
self.punc_server = TextHttpHandler(punc_server_ip, punc_server_port)
logger.info(f"endpoint: {self.url}")
@ -139,8 +138,7 @@ class ASRWsAudioHandler:
logging.info("send a message to the server")
if self.url is None:
logger.error(
"No asr server, please input valid ip and port")
logger.error("No asr server, please input valid ip and port")
return ""
# 1. send websocket handshake protocal
@ -167,8 +165,7 @@ class ASRWsAudioHandler:
msg = json.loads(msg)
if self.punc_server and len(msg["result"]) > 0:
msg["result"] = self.punc_server.run(
msg["result"])
msg["result"] = self.punc_server.run(msg["result"])
logger.info("client receive msg={}".format(msg))
# 4. we must send finished signal to the server
@ -189,7 +186,7 @@ class ASRWsAudioHandler:
if self.punc_server:
msg["result"] = self.punc_server.run(msg["result"])
logger.info("client final receive msg={}".format(msg))
result = msg

@ -48,6 +48,12 @@ void TLGDecoder::Reset() {
}
std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;

@ -27,21 +27,22 @@ ConnectionHandler::ConnectionHandler(
: ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {}
void ConnectionHandler::OnSpeechStart() {
LOG(INFO) << "Server: Recieved speech start signal, start reading speech";
got_start_tag_ = true;
json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
recognizer_ = std::make_shared<Recognizer>(recognizer_resource_);
// Start decoder thread
decode_thread_ = std::make_shared<std::thread>(
&ConnectionHandler::DecodeThreadFunc, this);
got_start_tag_ = true;
LOG(INFO) << "Server: Recieved speech start signal, start reading speech";
json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
}
void ConnectionHandler::OnSpeechEnd() {
LOG(INFO) << "Server: Recieved speech end signal";
CHECK(recognizer_ != nullptr);
recognizer_->SetFinished();
if (recognizer_ != nullptr) {
recognizer_->SetFinished();
}
got_end_tag_ = true;
}
@ -158,7 +159,7 @@ void ConnectionHandler::operator()() {
}
}
LOG(INFO) << "Server: Read all pcm data, wait for decoding thread";
LOG(INFO) << "Server: finished to wait for decoding thread join.";
if (decode_thread_ != nullptr) {
decode_thread_->join();
}

Loading…
Cancel
Save