Merge pull request #1787 from zh794390558/spx_ws

[speechx] fix speechx ws server to return dummy partial result
pull/1788/head
Hui Zhang 3 years ago committed by GitHub
commit cea7b5eb65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -22,6 +22,8 @@ from typing import Union
import paddle import paddle
import soundfile import soundfile
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
@ -30,8 +32,6 @@ from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from .pretrained_models import model_alias
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.modules.sid_model import SpeakerIdetification

@ -14,10 +14,10 @@
import os import os
import paddle import paddle
from yacs.config import CfgNode
from paddleaudio.utils import logger from paddleaudio.utils import logger
from paddleaudio.utils import Timer from paddleaudio.utils import Timer
from yacs.config import CfgNode
from paddlespeech.kws.exps.mdtc.collate import collate_features from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel from paddlespeech.kws.models.mdtc import KWSModel

@ -24,11 +24,11 @@ from typing import Any
from typing import Dict from typing import Dict
import paddle import paddle
import paddleaudio
import requests import requests
import yaml import yaml
from paddle.framework import load from paddle.framework import load
import paddleaudio
from . import download from . import download
from .entry import client_commands from .entry import client_commands
from .entry import server_commands from .entry import server_commands

@ -27,7 +27,10 @@ from paddlespeech.server.utils.audio_process import save_audio
class ASRAudioHandler: class ASRAudioHandler:
def __init__(self, url="127.0.0.1", port=8090): def __init__(self,
url="127.0.0.1",
port=8090,
endopoint='/paddlespeech/asr/streaming'):
"""PaddleSpeech Online ASR Server Client audio handler """PaddleSpeech Online ASR Server Client audio handler
Online asr server use the websocket protocal Online asr server use the websocket protocal
Args: Args:
@ -36,7 +39,8 @@ class ASRAudioHandler:
""" """
self.url = url self.url = url
self.port = port self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr" self.url = "ws://" + self.url + ":" + str(self.port) + endopoint
logger.info(f"endpoint: {self.url}")
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
"""read the audio file from specific wavfile path """read the audio file from specific wavfile path
@ -95,14 +99,14 @@ class ASRAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
logger.info("receive msg={}".format(msg)) logger.info("client receive msg={}".format(msg))
# 3. send chunk audio data to engine # 3. send chunk audio data to engine
for chunk_data in self.read_wave(wavfile_path): for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes()) await ws.send(chunk_data.tobytes())
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg) msg = json.loads(msg)
logger.info("receive msg={}".format(msg)) logger.info("client receive msg={}".format(msg))
# 4. we must send finished signal to the server # 4. we must send finished signal to the server
audio_info = json.dumps( audio_info = json.dumps(
@ -119,7 +123,7 @@ class ASRAudioHandler:
# 5. decode the bytes to str # 5. decode the bytes to str
msg = json.loads(msg) msg = json.loads(msg)
logger.info("final receive msg={}".format(msg)) logger.info("client final receive msg={}".format(msg))
result = msg result = msg
return result return result

@ -27,7 +27,7 @@ ConnectionHandler::ConnectionHandler(
: ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {} : ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {}
void ConnectionHandler::OnSpeechStart() { void ConnectionHandler::OnSpeechStart() {
LOG(INFO) << "Recieved speech start signal, start reading speech"; LOG(INFO) << "Server: Recieved speech start signal, start reading speech";
got_start_tag_ = true; got_start_tag_ = true;
json::value rv = {{"status", "ok"}, {"type", "server_ready"}}; json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
ws_.text(true); ws_.text(true);
@ -39,14 +39,14 @@ void ConnectionHandler::OnSpeechStart() {
} }
void ConnectionHandler::OnSpeechEnd() { void ConnectionHandler::OnSpeechEnd() {
LOG(INFO) << "Recieved speech end signal"; LOG(INFO) << "Server: Recieved speech end signal";
CHECK(recognizer_ != nullptr); CHECK(recognizer_ != nullptr);
recognizer_->SetFinished(); recognizer_->SetFinished();
got_end_tag_ = true; got_end_tag_ = true;
} }
void ConnectionHandler::OnFinalResult(const std::string& result) { void ConnectionHandler::OnFinalResult(const std::string& result) {
LOG(INFO) << "Final result: " << result; LOG(INFO) << "Server: Final result: " << result;
json::value rv = { json::value rv = {
{"status", "ok"}, {"type", "final_result"}, {"result", result}}; {"status", "ok"}, {"type", "final_result"}, {"result", result}};
ws_.text(true); ws_.text(true);
@ -69,10 +69,16 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
pcm_data(i) = static_cast<float>(*pdata); pcm_data(i) = static_cast<float>(*pdata);
pdata++; pdata++;
} }
VLOG(2) << "Recieved " << num_samples << " samples"; VLOG(2) << "Server: Recieved " << num_samples << " samples";
LOG(INFO) << "Recieved " << num_samples << " samples"; LOG(INFO) << "Server: Recieved " << num_samples << " samples";
CHECK(recognizer_ != nullptr); CHECK(recognizer_ != nullptr);
recognizer_->Accept(pcm_data); recognizer_->Accept(pcm_data);
// TODO: return partial result
json::value rv = {
{"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
} }
void ConnectionHandler::DecodeThreadFunc() { void ConnectionHandler::DecodeThreadFunc() {
@ -80,9 +86,9 @@ void ConnectionHandler::DecodeThreadFunc() {
while (true) { while (true) {
recognizer_->Decode(); recognizer_->Decode();
if (recognizer_->IsFinished()) { if (recognizer_->IsFinished()) {
LOG(INFO) << "enter finish"; LOG(INFO) << "Server: enter finish";
recognizer_->Decode(); recognizer_->Decode();
LOG(INFO) << "finish"; LOG(INFO) << "Server: finish";
std::string result = recognizer_->GetFinalResult(); std::string result = recognizer_->GetFinalResult();
OnFinalResult(result); OnFinalResult(result);
OnFinish(); OnFinish();
@ -135,7 +141,7 @@ void ConnectionHandler::operator()() {
ws_.read(buffer); ws_.read(buffer);
if (ws_.got_text()) { if (ws_.got_text()) {
std::string message = beast::buffers_to_string(buffer.data()); std::string message = beast::buffers_to_string(buffer.data());
LOG(INFO) << message; LOG(INFO) << "Server: Text: " << message;
OnText(message); OnText(message);
if (got_end_tag_) { if (got_end_tag_) {
break; break;
@ -152,7 +158,7 @@ void ConnectionHandler::operator()() {
} }
} }
LOG(INFO) << "Read all pcm data, wait for decoding thread"; LOG(INFO) << "Server: Read all pcm data, wait for decoding thread";
if (decode_thread_ != nullptr) { if (decode_thread_ != nullptr) {
decode_thread_->join(); decode_thread_->join();
} }

Loading…
Cancel
Save