parent
e089268642
commit
516eaa05f0
@ -0,0 +1,7 @@
|
||||
project(websocket)
|
||||
|
||||
add_library(websocket STATIC
|
||||
websocket_server.cc
|
||||
websocket_client.cc
|
||||
)
|
||||
target_link_libraries(websocket PUBLIC frontend decoder nnet)
|
@ -0,0 +1,105 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "websocket/websocket_client.h"
|
||||
|
||||
#include "boost/json/src.hpp"
|
||||
|
||||
namespace json = boost::json;
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
WebSocketClient::WebSocketClient(const std::string& host, int port)
|
||||
: host_(host), port_(port) {
|
||||
Connect();
|
||||
t_.reset(new std::thread(&WebSocketClient::ReadLoopFunc, this));
|
||||
}
|
||||
|
||||
void WebSocketClient::Connect() {
|
||||
tcp::resolver resolver{ioc_};
|
||||
// Look up the domain name
|
||||
auto const results = resolver.resolve(host_, std::to_string(port_));
|
||||
// Make the connection on the IP address we get from a lookup
|
||||
auto ep = asio::connect(ws_.next_layer(), results);
|
||||
// Update the host_ string. This will provide the value of the
|
||||
// Host HTTP header during the WebSocket handshake.
|
||||
// See https://tools.ietf.org/html/rfc7230#section-5.4
|
||||
std::string host = host_ + ":" + std::to_string(ep.port());
|
||||
// Perform the websocket handshake
|
||||
ws_.handshake(host, "/");
|
||||
}
|
||||
|
||||
void WebSocketClient::SendTextData(const std::string& data) {
|
||||
ws_.text(true);
|
||||
ws_.write(asio::buffer(data));
|
||||
}
|
||||
|
||||
void WebSocketClient::SendBinaryData(const void* data, size_t size) {
|
||||
ws_.binary(true);
|
||||
ws_.write(asio::buffer(data, size));
|
||||
}
|
||||
|
||||
void WebSocketClient::Close() { ws_.close(websocket::close_code::normal); }
|
||||
|
||||
void WebSocketClient::ReadLoopFunc() {
|
||||
try {
|
||||
while (true) {
|
||||
beast::flat_buffer buffer;
|
||||
ws_.read(buffer);
|
||||
std::string message = beast::buffers_to_string(buffer.data());
|
||||
LOG(INFO) << message;
|
||||
CHECK(ws_.got_text());
|
||||
json::object obj = json::parse(message).as_object();
|
||||
if (obj["status"] != "ok") {
|
||||
break;
|
||||
}
|
||||
if (obj["type"] == "final_result") {
|
||||
result_ = obj["result"].as_string().c_str();
|
||||
}
|
||||
if (obj["type"] == "speech_end") {
|
||||
done_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (beast::system_error const& se) {
|
||||
// This indicates that the session was closed
|
||||
if (se.code() != websocket::error::closed) {
|
||||
LOG(ERROR) << se.code().message();
|
||||
}
|
||||
} catch (std::exception const& e) {
|
||||
LOG(ERROR) << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
void WebSocketClient::Join() { t_->join(); }
|
||||
|
||||
void WebSocketClient::SendStartSignal() {
|
||||
json::value start_tag = {{"signal", "start"}};
|
||||
std::string start_message = json::serialize(start_tag);
|
||||
this->SendTextData(start_message);
|
||||
}
|
||||
|
||||
void WebSocketClient::SendDataEnd() {
|
||||
json::value end_tag = {{"data", "end"}};
|
||||
std::string end_message = json::serialize(end_tag);
|
||||
this->SendTextData(end_message);
|
||||
}
|
||||
|
||||
void WebSocketClient::SendEndSignal() {
|
||||
json::value end_tag = {{"signal", "end"}};
|
||||
std::string end_message = json::serialize(end_tag);
|
||||
this->SendTextData(end_message);
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,55 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/common.h"
|
||||
|
||||
#include "boost/asio/connect.hpp"
|
||||
#include "boost/asio/ip/tcp.hpp"
|
||||
#include "boost/beast/core.hpp"
|
||||
#include "boost/beast/websocket.hpp"
|
||||
|
||||
namespace beast = boost::beast; // from <boost/beast.hpp>
|
||||
namespace http = beast::http; // from <boost/beast/http.hpp>
|
||||
namespace websocket = beast::websocket; // from <boost/beast/websocket.hpp>
|
||||
namespace asio = boost::asio; // from <boost/asio.hpp>
|
||||
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class WebSocketClient {
|
||||
public:
|
||||
WebSocketClient(const std::string& host, int port);
|
||||
|
||||
void SendTextData(const std::string& data);
|
||||
void SendBinaryData(const void* data, size_t size);
|
||||
void ReadLoopFunc();
|
||||
void Close();
|
||||
void Join();
|
||||
void SendStartSignal();
|
||||
void SendEndSignal();
|
||||
void SendDataEnd();
|
||||
bool Done() const { return done_; }
|
||||
std::string GetResult() { return result_; }
|
||||
|
||||
private:
|
||||
void Connect();
|
||||
std::string host_;
|
||||
std::string result_;
|
||||
int port_;
|
||||
bool done_ = false;
|
||||
asio::io_context ioc_;
|
||||
websocket::stream<tcp::socket> ws_{ioc_};
|
||||
std::unique_ptr<std::thread> t_{nullptr};
|
||||
};
|
||||
}
|
@ -0,0 +1,192 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "websocket/websocket_server.h"
|
||||
|
||||
#include "base/common.h"
|
||||
#include "boost/json/src.hpp"
|
||||
|
||||
namespace json = boost::json;
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
ConnectionHandler::ConnectionHandler(
|
||||
tcp::socket&& socket, const RecognizerResource& recognizer_resource)
|
||||
: ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {}
|
||||
|
||||
void ConnectionHandler::OnSpeechStart() {
|
||||
LOG(INFO) << "Recieved speech start signal, start reading speech";
|
||||
got_start_tag_ = true;
|
||||
json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
|
||||
ws_.text(true);
|
||||
ws_.write(asio::buffer(json::serialize(rv)));
|
||||
recognizer_ = std::make_shared<Recognizer>(recognizer_resource_);
|
||||
// Start decoder thread
|
||||
decode_thread_ = std::make_shared<std::thread>(
|
||||
&ConnectionHandler::DecodeThreadFunc, this);
|
||||
}
|
||||
|
||||
void ConnectionHandler::OnSpeechEnd() {
|
||||
LOG(INFO) << "Recieved speech end signal";
|
||||
CHECK(recognizer_ != nullptr);
|
||||
recognizer_->SetFinished();
|
||||
got_end_tag_ = true;
|
||||
}
|
||||
|
||||
void ConnectionHandler::OnFinalResult(const std::string& result) {
|
||||
LOG(INFO) << "Final result: " << result;
|
||||
json::value rv = {
|
||||
{"status", "ok"}, {"type", "final_result"}, {"result", result}};
|
||||
ws_.text(true);
|
||||
ws_.write(asio::buffer(json::serialize(rv)));
|
||||
}
|
||||
|
||||
void ConnectionHandler::OnFinish() {
|
||||
// Send finish tag
|
||||
json::value rv = {{"status", "ok"}, {"type", "speech_end"}};
|
||||
ws_.text(true);
|
||||
ws_.write(asio::buffer(json::serialize(rv)));
|
||||
}
|
||||
|
||||
void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
|
||||
// Read binary PCM data
|
||||
int num_samples = buffer.size() / sizeof(int16_t);
|
||||
kaldi::Vector<kaldi::BaseFloat> pcm_data(num_samples);
|
||||
const int16_t* pdata = static_cast<const int16_t*>(buffer.data().data());
|
||||
for (int i = 0; i < num_samples; i++) {
|
||||
pcm_data(i) = static_cast<float>(*pdata);
|
||||
pdata++;
|
||||
}
|
||||
VLOG(2) << "Recieved " << num_samples << " samples";
|
||||
LOG(INFO) << "Recieved " << num_samples << " samples";
|
||||
CHECK(recognizer_ != nullptr);
|
||||
recognizer_->Accept(pcm_data);
|
||||
}
|
||||
|
||||
void ConnectionHandler::DecodeThreadFunc() {
|
||||
try {
|
||||
while (true) {
|
||||
recognizer_->Decode();
|
||||
if (recognizer_->IsFinished()) {
|
||||
LOG(INFO) << "enter finish";
|
||||
recognizer_->Decode();
|
||||
LOG(INFO) << "finish";
|
||||
std::string result = recognizer_->GetFinalResult();
|
||||
OnFinalResult(result);
|
||||
OnFinish();
|
||||
stop_recognition_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (std::exception const& e) {
|
||||
LOG(ERROR) << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
void ConnectionHandler::OnError(const std::string& message) {
|
||||
json::value rv = {{"status", "failed"}, {"message", message}};
|
||||
ws_.text(true);
|
||||
ws_.write(asio::buffer(json::serialize(rv)));
|
||||
// Close websocket
|
||||
ws_.close(websocket::close_code::normal);
|
||||
}
|
||||
|
||||
void ConnectionHandler::OnText(const std::string& message) {
|
||||
json::value v = json::parse(message);
|
||||
if (v.is_object()) {
|
||||
json::object obj = v.get_object();
|
||||
if (obj.find("signal") != obj.end()) {
|
||||
json::string signal = obj["signal"].as_string();
|
||||
if (signal == "start") {
|
||||
OnSpeechStart();
|
||||
} else if (signal == "end") {
|
||||
OnSpeechEnd();
|
||||
} else {
|
||||
OnError("Unexpected signal type");
|
||||
}
|
||||
} else {
|
||||
OnError("Wrong message header");
|
||||
}
|
||||
} else {
|
||||
OnError("Wrong protocol");
|
||||
}
|
||||
}
|
||||
|
||||
void ConnectionHandler::operator()() {
|
||||
try {
|
||||
// Accept the websocket handshake
|
||||
ws_.accept();
|
||||
for (;;) {
|
||||
// This buffer will hold the incoming message
|
||||
beast::flat_buffer buffer;
|
||||
// Read a message
|
||||
ws_.read(buffer);
|
||||
if (ws_.got_text()) {
|
||||
std::string message = beast::buffers_to_string(buffer.data());
|
||||
LOG(INFO) << message;
|
||||
OnText(message);
|
||||
if (got_end_tag_) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
if (!got_start_tag_) {
|
||||
OnError("Start signal is expected before binary data");
|
||||
} else {
|
||||
if (stop_recognition_) {
|
||||
break;
|
||||
}
|
||||
OnSpeechData(buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LOG(INFO) << "Read all pcm data, wait for decoding thread";
|
||||
if (decode_thread_ != nullptr) {
|
||||
decode_thread_->join();
|
||||
}
|
||||
} catch (beast::system_error const& se) {
|
||||
// This indicates that the session was closed
|
||||
if (se.code() != websocket::error::closed) {
|
||||
if (decode_thread_ != nullptr) {
|
||||
decode_thread_->join();
|
||||
}
|
||||
OnSpeechEnd();
|
||||
LOG(ERROR) << se.code().message();
|
||||
}
|
||||
} catch (std::exception const& e) {
|
||||
LOG(ERROR) << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
void WebSocketServer::Start() {
|
||||
try {
|
||||
auto const address = asio::ip::make_address("0.0.0.0");
|
||||
tcp::acceptor acceptor{ioc_, {address, static_cast<uint16_t>(port_)}};
|
||||
for (;;) {
|
||||
// This will receive the new connection
|
||||
tcp::socket socket{ioc_};
|
||||
// Block until we get a connection
|
||||
acceptor.accept(socket);
|
||||
// Launch the session, transferring ownership of the socket
|
||||
ConnectionHandler handler(std::move(socket), recognizer_resource_);
|
||||
std::thread t(std::move(handler));
|
||||
t.detach();
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
LOG(FATAL) << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,80 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
|
||||
#include "boost/asio/connect.hpp"
|
||||
#include "boost/asio/ip/tcp.hpp"
|
||||
#include "boost/beast/core.hpp"
|
||||
#include "boost/beast/websocket.hpp"
|
||||
|
||||
#include "decoder/recognizer.h"
|
||||
#include "frontend/audio/feature_pipeline.h"
|
||||
|
||||
namespace beast = boost::beast; // from <boost/beast.hpp>
|
||||
namespace http = beast::http; // from <boost/beast/http.hpp>
|
||||
namespace websocket = beast::websocket; // from <boost/beast/websocket.hpp>
|
||||
namespace asio = boost::asio; // from <boost/asio.hpp>
|
||||
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
|
||||
|
||||
namespace ppspeech {
|
||||
class ConnectionHandler {
|
||||
public:
|
||||
ConnectionHandler(tcp::socket&& socket,
|
||||
const RecognizerResource& recognizer_resource_);
|
||||
void operator()();
|
||||
|
||||
private:
|
||||
void OnSpeechStart();
|
||||
void OnSpeechEnd();
|
||||
void OnText(const std::string& message);
|
||||
void OnFinish();
|
||||
void OnSpeechData(const beast::flat_buffer& buffer);
|
||||
void OnError(const std::string& message);
|
||||
void OnPartialResult(const std::string& result);
|
||||
void OnFinalResult(const std::string& result);
|
||||
void DecodeThreadFunc();
|
||||
std::string SerializeResult(bool finish);
|
||||
|
||||
bool continuous_decoding_ = false;
|
||||
int nbest_ = 1;
|
||||
websocket::stream<tcp::socket> ws_;
|
||||
RecognizerResource recognizer_resource_;
|
||||
|
||||
bool got_start_tag_ = false;
|
||||
bool got_end_tag_ = false;
|
||||
// When endpoint is detected, stop recognition, and stop receiving data.
|
||||
bool stop_recognition_ = false;
|
||||
std::shared_ptr<ppspeech::Recognizer> recognizer_ = nullptr;
|
||||
std::shared_ptr<std::thread> decode_thread_ = nullptr;
|
||||
};
|
||||
|
||||
class WebSocketServer {
|
||||
public:
|
||||
WebSocketServer(int port, const RecognizerResource& recognizer_resource)
|
||||
: port_(port), recognizer_resource_(recognizer_resource) {}
|
||||
|
||||
void Start();
|
||||
|
||||
private:
|
||||
int port_;
|
||||
RecognizerResource recognizer_resource_;
|
||||
// The io_context is required for all I/O
|
||||
asio::io_context ioc_{1};
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
Loading…
Reference in new issue