diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 1dff6edb..37e19391 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -22,6 +22,8 @@ from typing import Union import paddle import soundfile +from paddleaudio.backends import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram from yacs.config import CfgNode from ..executor import BaseExecutor @@ -30,8 +32,6 @@ from ..utils import cli_register from ..utils import stats_wrapper from .pretrained_models import model_alias from .pretrained_models import pretrained_models -from paddleaudio.backends import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.modules.sid_model import SpeakerIdetification diff --git a/paddlespeech/kws/exps/mdtc/train.py b/paddlespeech/kws/exps/mdtc/train.py index 56082bd7..5a9ca92d 100644 --- a/paddlespeech/kws/exps/mdtc/train.py +++ b/paddlespeech/kws/exps/mdtc/train.py @@ -14,10 +14,10 @@ import os import paddle -from yacs.config import CfgNode - from paddleaudio.utils import logger from paddleaudio.utils import Timer +from yacs.config import CfgNode + from paddlespeech.kws.exps.mdtc.collate import collate_features from paddlespeech.kws.models.loss import max_pooling_loss from paddlespeech.kws.models.mdtc import KWSModel diff --git a/paddlespeech/server/util.py b/paddlespeech/server/util.py index 1f1b0be1..ae3e9c6a 100644 --- a/paddlespeech/server/util.py +++ b/paddlespeech/server/util.py @@ -24,11 +24,11 @@ from typing import Any from typing import Dict import paddle +import paddleaudio import requests import yaml from paddle.framework import load -import paddleaudio from . 
import download from .entry import client_commands from .entry import server_commands diff --git a/paddlespeech/server/utils/audio_handler.py b/paddlespeech/server/utils/audio_handler.py index c2863115..727b8f90 100644 --- a/paddlespeech/server/utils/audio_handler.py +++ b/paddlespeech/server/utils/audio_handler.py @@ -27,7 +27,10 @@ from paddlespeech.server.utils.audio_process import save_audio class ASRAudioHandler: - def __init__(self, url="127.0.0.1", port=8090): + def __init__(self, + url="127.0.0.1", + port=8090, + endopoint='/paddlespeech/asr/streaming'): """PaddleSpeech Online ASR Server Client audio handler Online asr server use the websocket protocal Args: @@ -36,7 +39,8 @@ class ASRAudioHandler: """ self.url = url self.port = port - self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr" + self.url = "ws://" + self.url + ":" + str(self.port) + endopoint + logger.info(f"endpoint: {self.url}") def read_wave(self, wavfile_path: str): """read the audio file from specific wavfile path @@ -95,14 +99,14 @@ class ASRAudioHandler: separators=(',', ': ')) await ws.send(audio_info) msg = await ws.recv() - logger.info("receive msg={}".format(msg)) + logger.info("client receive msg={}".format(msg)) # 3. send chunk audio data to engine for chunk_data in self.read_wave(wavfile_path): await ws.send(chunk_data.tobytes()) msg = await ws.recv() msg = json.loads(msg) - logger.info("receive msg={}".format(msg)) + logger.info("client receive msg={}".format(msg)) # 4. we must send finished signal to the server audio_info = json.dumps( @@ -119,7 +123,7 @@ class ASRAudioHandler: # 5. 
decode the bytes to str msg = json.loads(msg) - logger.info("final receive msg={}".format(msg)) + logger.info("client final receive msg={}".format(msg)) result = msg return result diff --git a/speechx/speechx/websocket/websocket_server.cc b/speechx/speechx/websocket/websocket_server.cc index 3f6da894..62d3d9e0 100644 --- a/speechx/speechx/websocket/websocket_server.cc +++ b/speechx/speechx/websocket/websocket_server.cc @@ -27,7 +27,7 @@ ConnectionHandler::ConnectionHandler( : ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {} void ConnectionHandler::OnSpeechStart() { - LOG(INFO) << "Recieved speech start signal, start reading speech"; + LOG(INFO) << "Server: Recieved speech start signal, start reading speech"; got_start_tag_ = true; json::value rv = {{"status", "ok"}, {"type", "server_ready"}}; ws_.text(true); @@ -39,14 +39,14 @@ void ConnectionHandler::OnSpeechStart() { } void ConnectionHandler::OnSpeechEnd() { - LOG(INFO) << "Recieved speech end signal"; + LOG(INFO) << "Server: Recieved speech end signal"; CHECK(recognizer_ != nullptr); recognizer_->SetFinished(); got_end_tag_ = true; } void ConnectionHandler::OnFinalResult(const std::string& result) { - LOG(INFO) << "Final result: " << result; + LOG(INFO) << "Server: Final result: " << result; json::value rv = { {"status", "ok"}, {"type", "final_result"}, {"result", result}}; ws_.text(true); @@ -69,10 +69,16 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) { pcm_data(i) = static_cast(*pdata); pdata++; } - VLOG(2) << "Recieved " << num_samples << " samples"; - LOG(INFO) << "Recieved " << num_samples << " samples"; + VLOG(2) << "Server: Recieved " << num_samples << " samples"; + LOG(INFO) << "Server: Recieved " << num_samples << " samples"; CHECK(recognizer_ != nullptr); recognizer_->Accept(pcm_data); + + // TODO: return partial result + json::value rv = { + {"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}}; + ws_.text(true); + 
ws_.write(asio::buffer(json::serialize(rv))); } void ConnectionHandler::DecodeThreadFunc() { @@ -80,9 +86,9 @@ void ConnectionHandler::DecodeThreadFunc() { while (true) { recognizer_->Decode(); if (recognizer_->IsFinished()) { - LOG(INFO) << "enter finish"; + LOG(INFO) << "Server: enter finish"; recognizer_->Decode(); - LOG(INFO) << "finish"; + LOG(INFO) << "Server: finish"; std::string result = recognizer_->GetFinalResult(); OnFinalResult(result); OnFinish(); @@ -135,7 +141,7 @@ void ConnectionHandler::operator()() { ws_.read(buffer); if (ws_.got_text()) { std::string message = beast::buffers_to_string(buffer.data()); - LOG(INFO) << message; + LOG(INFO) << "Server: Text: " << message; OnText(message); if (got_end_tag_) { break; @@ -152,7 +158,7 @@ void ConnectionHandler::operator()() { } } - LOG(INFO) << "Read all pcm data, wait for decoding thread"; + LOG(INFO) << "Server: Read all pcm data, wait for decoding thread"; if (decode_thread_ != nullptr) { decode_thread_->join(); }