fix speechx ws server to return dummpy partial result, fix hang for ws client

pull/1787/head
Hui Zhang 3 years ago
parent d7c8c1779f
commit 5e23025c31

@ -22,6 +22,8 @@ from typing import Union
import paddle
import soundfile
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from ..executor import BaseExecutor
@ -30,8 +32,6 @@ from ..utils import cli_register
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification

@ -14,10 +14,10 @@
import os
import paddle
from yacs.config import CfgNode
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from yacs.config import CfgNode
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel

@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import paddleaudio
import requests
import yaml
from paddle.framework import load
import paddleaudio
from . import download
from .entry import client_commands
from .entry import server_commands

@ -27,7 +27,10 @@ from paddlespeech.server.utils.audio_process import save_audio
class ASRAudioHandler:
def __init__(self, url="127.0.0.1", port=8090):
def __init__(self,
url="127.0.0.1",
port=8090,
endopoint='/paddlespeech/asr/streaming'):
"""PaddleSpeech Online ASR Server Client audio handler
Online asr server use the websocket protocal
Args:
@ -36,7 +39,8 @@ class ASRAudioHandler:
"""
self.url = url
self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
self.url = "ws://" + self.url + ":" + str(self.port) + endopoint
logger.info(f"endpoint: {self.url}")
def read_wave(self, wavfile_path: str):
"""read the audio file from specific wavfile path
@ -95,14 +99,14 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logger.info("receive msg={}".format(msg))
logger.info("client receive msg={}".format(msg))
# 3. send chunk audio data to engine
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
logger.info("receive msg={}".format(msg))
logger.info("client receive msg={}".format(msg))
# 4. we must send finished signal to the server
audio_info = json.dumps(
@ -119,7 +123,7 @@ class ASRAudioHandler:
# 5. decode the bytes to str
msg = json.loads(msg)
logger.info("final receive msg={}".format(msg))
logger.info("client final receive msg={}".format(msg))
result = msg
return result

@ -27,7 +27,7 @@ ConnectionHandler::ConnectionHandler(
: ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {}
void ConnectionHandler::OnSpeechStart() {
LOG(INFO) << "Recieved speech start signal, start reading speech";
LOG(INFO) << "Server: Recieved speech start signal, start reading speech";
got_start_tag_ = true;
json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
ws_.text(true);
@ -39,14 +39,14 @@ void ConnectionHandler::OnSpeechStart() {
}
void ConnectionHandler::OnSpeechEnd() {
LOG(INFO) << "Recieved speech end signal";
LOG(INFO) << "Server: Recieved speech end signal";
CHECK(recognizer_ != nullptr);
recognizer_->SetFinished();
got_end_tag_ = true;
}
void ConnectionHandler::OnFinalResult(const std::string& result) {
LOG(INFO) << "Final result: " << result;
LOG(INFO) << "Server: Final result: " << result;
json::value rv = {
{"status", "ok"}, {"type", "final_result"}, {"result", result}};
ws_.text(true);
@ -69,10 +69,16 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
pcm_data(i) = static_cast<float>(*pdata);
pdata++;
}
VLOG(2) << "Recieved " << num_samples << " samples";
LOG(INFO) << "Recieved " << num_samples << " samples";
VLOG(2) << "Server: Recieved " << num_samples << " samples";
LOG(INFO) << "Server: Recieved " << num_samples << " samples";
CHECK(recognizer_ != nullptr);
recognizer_->Accept(pcm_data);
// TODO: return lpartial result
json::value rv = {
{"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
}
void ConnectionHandler::DecodeThreadFunc() {
@ -80,9 +86,9 @@ void ConnectionHandler::DecodeThreadFunc() {
while (true) {
recognizer_->Decode();
if (recognizer_->IsFinished()) {
LOG(INFO) << "enter finish";
LOG(INFO) << "Server: enter finish";
recognizer_->Decode();
LOG(INFO) << "finish";
LOG(INFO) << "Server: finish";
std::string result = recognizer_->GetFinalResult();
OnFinalResult(result);
OnFinish();
@ -135,7 +141,7 @@ void ConnectionHandler::operator()() {
ws_.read(buffer);
if (ws_.got_text()) {
std::string message = beast::buffers_to_string(buffer.data());
LOG(INFO) << message;
LOG(INFO) << "Server: Text: " << message;
OnText(message);
if (got_end_tag_) {
break;
@ -152,7 +158,7 @@ void ConnectionHandler::operator()() {
}
}
LOG(INFO) << "Read all pcm data, wait for decoding thread";
LOG(INFO) << "Server: Read all pcm data, wait for decoding thread";
if (decode_thread_ != nullptr) {
decode_thread_->join();
}

Loading…
Cancel
Save