diff --git a/speechx/speechx/websocket/CMakeLists.txt b/speechx/speechx/websocket/CMakeLists.txt new file mode 100644 index 000000000..582a38031 --- /dev/null +++ b/speechx/speechx/websocket/CMakeLists.txt @@ -0,0 +1,7 @@ +project(websocket) + +add_library(websocket STATIC + websocket_server.cc + websocket_client.cc +) +target_link_libraries(websocket PUBLIC frontend decoder nnet) diff --git a/speechx/speechx/websocket/websocket_client.cc b/speechx/speechx/websocket/websocket_client.cc new file mode 100644 index 000000000..bf3bbef26 --- /dev/null +++ b/speechx/speechx/websocket/websocket_client.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "websocket/websocket_client.h" + +#include "boost/json/src.hpp" + +namespace json = boost::json; + +namespace ppspeech { + +WebSocketClient::WebSocketClient(const std::string& host, int port) + : host_(host), port_(port) { + Connect(); + t_.reset(new std::thread(&WebSocketClient::ReadLoopFunc, this)); +} + +void WebSocketClient::Connect() { + tcp::resolver resolver{ioc_}; + // Look up the domain name + auto const results = resolver.resolve(host_, std::to_string(port_)); + // Make the connection on the IP address we get from a lookup + auto ep = asio::connect(ws_.next_layer(), results); + // Update the host_ string. This will provide the value of the + // Host HTTP header during the WebSocket handshake. + // See https://tools.ietf.org/html/rfc7230#section-5.4 + std::string host = host_ + ":" + std::to_string(ep.port()); + // Perform the websocket handshake + ws_.handshake(host, "/"); +} + +void WebSocketClient::SendTextData(const std::string& data) { + ws_.text(true); + ws_.write(asio::buffer(data)); +} + +void WebSocketClient::SendBinaryData(const void* data, size_t size) { + ws_.binary(true); + ws_.write(asio::buffer(data, size)); +} + +void WebSocketClient::Close() { ws_.close(websocket::close_code::normal); } + +void WebSocketClient::ReadLoopFunc() { + try { + while (true) { + beast::flat_buffer buffer; + ws_.read(buffer); + std::string message = beast::buffers_to_string(buffer.data()); + LOG(INFO) << message; + CHECK(ws_.got_text()); + json::object obj = json::parse(message).as_object(); + if (obj["status"] != "ok") { + break; + } + if (obj["type"] == "final_result") { + result_ = obj["result"].as_string().c_str(); + } + if (obj["type"] == "speech_end") { + done_ = true; + break; + } + } + } catch (beast::system_error const& se) { + // This indicates that the session was closed + if (se.code() != websocket::error::closed) { + LOG(ERROR) << se.code().message(); + } + } catch (std::exception const& e) { + LOG(ERROR) << e.what(); + } +} + +void WebSocketClient::Join() { t_->join(); } + +void WebSocketClient::SendStartSignal() { + json::value start_tag = {{"signal", "start"}}; + std::string start_message = json::serialize(start_tag); + this->SendTextData(start_message); +} + +void WebSocketClient::SendDataEnd() { + json::value end_tag = {{"data", "end"}}; + std::string end_message = json::serialize(end_tag); + this->SendTextData(end_message); +} + +void WebSocketClient::SendEndSignal() { + json::value end_tag = {{"signal", "end"}}; + std::string end_message = json::serialize(end_tag); + this->SendTextData(end_message); +} + +} // namespace ppspeech diff --git a/speechx/speechx/websocket/websocket_client.h b/speechx/speechx/websocket/websocket_client.h new file mode 100644 index 000000000..35def076d --- /dev/null +++ b/speechx/speechx/websocket/websocket_client.h @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/common.h" + +#include "boost/asio/connect.hpp" +#include "boost/asio/ip/tcp.hpp" +#include "boost/beast/core.hpp" +#include "boost/beast/websocket.hpp" + +namespace beast = boost::beast; // from +namespace http = beast::http; // from +namespace websocket = beast::websocket; // from +namespace asio = boost::asio; // from +using tcp = boost::asio::ip::tcp; // from + +namespace ppspeech { + +class WebSocketClient { + public: + WebSocketClient(const std::string& host, int port); + + void SendTextData(const std::string& data); + void SendBinaryData(const void* data, size_t size); + void ReadLoopFunc(); + void Close(); + void Join(); + void SendStartSignal(); + void SendEndSignal(); + void SendDataEnd(); + bool Done() const { return done_; } + std::string GetResult() { return result_; } + + private: + void Connect(); + std::string host_; + std::string result_; + int port_; + bool done_ = false; + asio::io_context ioc_; + websocket::stream ws_{ioc_}; + std::unique_ptr t_{nullptr}; +}; +} \ No newline at end of file diff --git a/speechx/speechx/websocket/websocket_server.cc b/speechx/speechx/websocket/websocket_server.cc new file mode 100644 index 000000000..3f6da894b --- /dev/null +++ b/speechx/speechx/websocket/websocket_server.cc @@ -0,0 +1,192 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "websocket/websocket_server.h" + +#include "base/common.h" +#include "boost/json/src.hpp" + +namespace json = boost::json; + +namespace ppspeech { + +ConnectionHandler::ConnectionHandler( + tcp::socket&& socket, const RecognizerResource& recognizer_resource) + : ws_(std::move(socket)), recognizer_resource_(recognizer_resource) {} + +void ConnectionHandler::OnSpeechStart() { + LOG(INFO) << "Recieved speech start signal, start reading speech"; + got_start_tag_ = true; + json::value rv = {{"status", "ok"}, {"type", "server_ready"}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + recognizer_ = std::make_shared(recognizer_resource_); + // Start decoder thread + decode_thread_ = std::make_shared( + &ConnectionHandler::DecodeThreadFunc, this); +} + +void ConnectionHandler::OnSpeechEnd() { + LOG(INFO) << "Recieved speech end signal"; + CHECK(recognizer_ != nullptr); + recognizer_->SetFinished(); + got_end_tag_ = true; +} + +void ConnectionHandler::OnFinalResult(const std::string& result) { + LOG(INFO) << "Final result: " << result; + json::value rv = { + {"status", "ok"}, {"type", "final_result"}, {"result", result}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); +} + +void ConnectionHandler::OnFinish() { + // Send finish tag + json::value rv = {{"status", "ok"}, {"type", "speech_end"}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); +} + +void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) { + // Read binary PCM data + int num_samples = buffer.size() / sizeof(int16_t); + kaldi::Vector pcm_data(num_samples); + const int16_t* pdata = static_cast(buffer.data().data()); + for (int i = 0; i < num_samples; i++) { + pcm_data(i) = static_cast(*pdata); + pdata++; + } + VLOG(2) << "Recieved " << num_samples << " samples"; + LOG(INFO) << "Recieved " << num_samples << " samples"; + CHECK(recognizer_ != nullptr); + recognizer_->Accept(pcm_data); +} + +void ConnectionHandler::DecodeThreadFunc() { + try { + while (true) { + recognizer_->Decode(); + if (recognizer_->IsFinished()) { + LOG(INFO) << "enter finish"; + recognizer_->Decode(); + LOG(INFO) << "finish"; + std::string result = recognizer_->GetFinalResult(); + OnFinalResult(result); + OnFinish(); + stop_recognition_ = true; + break; + } + } + } catch (std::exception const& e) { + LOG(ERROR) << e.what(); + } +} + +void ConnectionHandler::OnError(const std::string& message) { + json::value rv = {{"status", "failed"}, {"message", message}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + // Close websocket + ws_.close(websocket::close_code::normal); +} + +void ConnectionHandler::OnText(const std::string& message) { + json::value v = json::parse(message); + if (v.is_object()) { + json::object obj = v.get_object(); + if (obj.find("signal") != obj.end()) { + json::string signal = obj["signal"].as_string(); + if (signal == "start") { + OnSpeechStart(); + } else if (signal == "end") { + OnSpeechEnd(); + } else { + OnError("Unexpected signal type"); + } + } else { + OnError("Wrong message header"); + } + } else { + OnError("Wrong protocol"); + } +} + +void ConnectionHandler::operator()() { + try { + // Accept the websocket handshake + ws_.accept(); + for (;;) { + // This buffer will hold the incoming message + beast::flat_buffer buffer; + // Read a message + ws_.read(buffer); + if (ws_.got_text()) { + std::string message = beast::buffers_to_string(buffer.data()); + LOG(INFO) << message; + OnText(message); + if (got_end_tag_) { + break; + } + } else { + if (!got_start_tag_) { + OnError("Start signal is expected before binary data"); + } else { + if (stop_recognition_) { + break; + } + OnSpeechData(buffer); + } + } + } + + LOG(INFO) << "Read all pcm data, wait for decoding thread"; + if (decode_thread_ != nullptr) { + decode_thread_->join(); + } + } catch (beast::system_error const& se) { + // This indicates that the session was closed + if (se.code() != websocket::error::closed) { + if (decode_thread_ != nullptr) { + decode_thread_->join(); + } + OnSpeechEnd(); + LOG(ERROR) << se.code().message(); + } + } catch (std::exception const& e) { + LOG(ERROR) << e.what(); + } +} + +void WebSocketServer::Start() { + try { + auto const address = asio::ip::make_address("0.0.0.0"); + tcp::acceptor acceptor{ioc_, {address, static_cast(port_)}}; + for (;;) { + // This will receive the new connection + tcp::socket socket{ioc_}; + // Block until we get a connection + acceptor.accept(socket); + // Launch the session, transferring ownership of the socket + ConnectionHandler handler(std::move(socket), recognizer_resource_); + std::thread t(std::move(handler)); + t.detach(); + } + } catch (const std::exception& e) { + LOG(FATAL) << e.what(); + } +} + +} // namespace ppspeech diff --git a/speechx/speechx/websocket/websocket_server.h b/speechx/speechx/websocket/websocket_server.h new file mode 100644 index 000000000..469f123f1 --- /dev/null +++ b/speechx/speechx/websocket/websocket_server.h @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#pragma once + +#include "base/common.h" + +#include "boost/asio/connect.hpp" +#include "boost/asio/ip/tcp.hpp" +#include "boost/beast/core.hpp" +#include "boost/beast/websocket.hpp" + +#include "decoder/recognizer.h" +#include "frontend/audio/feature_pipeline.h" + +namespace beast = boost::beast; // from +namespace http = beast::http; // from +namespace websocket = beast::websocket; // from +namespace asio = boost::asio; // from +using tcp = boost::asio::ip::tcp; // from + +namespace ppspeech { +class ConnectionHandler { + public: + ConnectionHandler(tcp::socket&& socket, + const RecognizerResource& recognizer_resource_); + void operator()(); + + private: + void OnSpeechStart(); + void OnSpeechEnd(); + void OnText(const std::string& message); + void OnFinish(); + void OnSpeechData(const beast::flat_buffer& buffer); + void OnError(const std::string& message); + void OnPartialResult(const std::string& result); + void OnFinalResult(const std::string& result); + void DecodeThreadFunc(); + std::string SerializeResult(bool finish); + + bool continuous_decoding_ = false; + int nbest_ = 1; + websocket::stream ws_; + RecognizerResource recognizer_resource_; + + bool got_start_tag_ = false; + bool got_end_tag_ = false; + // When endpoint is detected, stop recognition, and stop receiving data. + bool stop_recognition_ = false; + std::shared_ptr recognizer_ = nullptr; + std::shared_ptr decode_thread_ = nullptr; +}; + +class WebSocketServer { + public: + WebSocketServer(int port, const RecognizerResource& recognizer_resource) + : port_(port), recognizer_resource_(recognizer_resource) {} + + void Start(); + + private: + int port_; + RecognizerResource recognizer_resource_; + // The io_context is required for all I/O + asio::io_context ioc_{1}; +}; + +} // namespace ppspeech