diff --git a/speechx/examples/ds2_ol/websocket/websocket_server.sh b/speechx/examples/ds2_ol/websocket/websocket_server.sh index fc57e326..f798dfd4 100755 --- a/speechx/examples/ds2_ol/websocket/websocket_server.sh +++ b/speechx/examples/ds2_ol/websocket/websocket_server.sh @@ -45,7 +45,7 @@ export GLOG_logtostderr=1 # 3. gen cmvn cmvn=$data/cmvn.ark -cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn +cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn wfst=$data/wfst/ diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 02e64316..3f8bdd5a 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -47,6 +47,26 @@ void TLGDecoder::Reset() { return; } +std::string TLGDecoder::GetPartialResult() { + if (frame_decoded_size_ == 0) { + // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call + // BestPathEnd if no frames were decoded.") + return std::string(""); + } + kaldi::Lattice lat; + kaldi::LatticeWeight weight; + std::vector alignment; + std::vector words_id; + decoder_->GetBestPath(&lat, false); + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + std::string words; + for (int32 idx = 0; idx < words_id.size(); ++idx) { + std::string word = word_symbol_table_->Find(words_id[idx]); + words += word; + } + return words; +} + std::string TLGDecoder::GetFinalBestPath() { if (frame_decoded_size_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index 361c44af..1ac46ac6 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -38,6 +38,7 @@ class TLGDecoder { std::string GetBestPath(); std::vector> GetNBestPath(); std::string GetFinalBestPath(); + std::string GetPartialResult(); int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc index 2c90ada9..44c3911c 100644 --- a/speechx/speechx/decoder/recognizer.cc +++ b/speechx/speechx/decoder/recognizer.cc @@ -44,6 +44,10 @@ std::string Recognizer::GetFinalResult() { return decoder_->GetFinalBestPath(); } +std::string Recognizer::GetPartialResult() { + return decoder_->GetPartialResult(); +} + void Recognizer::SetFinished() { feature_pipeline_->SetFinished(); input_finished_ = true; diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h index 9a7e7d11..35e1e167 100644 --- a/speechx/speechx/decoder/recognizer.h +++ b/speechx/speechx/decoder/recognizer.h @@ -43,6 +43,7 @@ class Recognizer { void Accept(const kaldi::Vector& waves); void Decode(); std::string GetFinalResult(); + std::string GetPartialResult(); void SetFinished(); bool IsFinished(); void Reset(); diff --git a/speechx/speechx/websocket/websocket_client.cc b/speechx/speechx/websocket/websocket_client.cc index 6bd930b8..3a852305 100644 --- a/speechx/speechx/websocket/websocket_client.cc +++ b/speechx/speechx/websocket/websocket_client.cc @@ -67,6 +67,9 @@ void WebSocketClient::ReadLoopFunc() { if (obj["type"] == "final_result") { result_ = obj["result"].as_string().c_str(); } + if (obj["type"] == "partial_result") { + partial_result_ = obj["partial_result"].as_string().c_str(); + } if (obj["type"] == "speech_end") { done_ = true; break; diff --git a/speechx/speechx/websocket/websocket_client.h b/speechx/speechx/websocket/websocket_client.h index ac0aed31..7d05448e 100644 --- a/speechx/speechx/websocket/websocket_client.h +++ b/speechx/speechx/websocket/websocket_client.h @@ -41,11 +41,13 @@ class WebSocketClient { void SendDataEnd(); bool Done() const { return done_; } std::string GetResult() { return result_; } + std::string GetPartialResult() { return partial_result_; } private: void Connect(); std::string host_; std::string result_; + std::string partial_result_; int port_; bool done_ = false; asio::io_context ioc_; diff --git a/speechx/speechx/websocket/websocket_client_main.cc b/speechx/speechx/websocket/websocket_client_main.cc index df658b0a..7ad36e3a 100644 --- a/speechx/speechx/websocket/websocket_client_main.cc +++ b/speechx/speechx/websocket/websocket_client_main.cc @@ -59,7 +59,6 @@ int main(int argc, char* argv[]) { client.SendBinaryData(wav_chunk.data(), wav_chunk.size() * sizeof(int16)); - sample_offset += cur_chunk_size; LOG(INFO) << "Send " << cur_chunk_size << " samples"; std::this_thread::sleep_for( diff --git a/speechx/speechx/websocket/websocket_server.cc b/speechx/speechx/websocket/websocket_server.cc index 28c9eca4..569f5378 100644 --- a/speechx/speechx/websocket/websocket_server.cc +++ b/speechx/speechx/websocket/websocket_server.cc @@ -75,9 +75,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) { CHECK(recognizer_ != nullptr); recognizer_->Accept(pcm_data); - // TODO: return lpartial result + std::string partial_result = recognizer_->GetPartialResult(); + json::value rv = { - {"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}}; + {"status", "ok"}, {"type", "partial_result"}, {"partial_result", partial_result}}; ws_.text(true); ws_.write(asio::buffer(json::serialize(rv))); } diff --git a/speechx/speechx/websocket/websocket_server.h b/speechx/speechx/websocket/websocket_server.h index 9ea88282..009fc42e 100644 --- a/speechx/speechx/websocket/websocket_server.h +++ b/speechx/speechx/websocket/websocket_server.h @@ -44,7 +44,6 @@ class ConnectionHandler { void OnFinish(); void OnSpeechData(const beast::flat_buffer& buffer); void OnError(const std::string& message); - void OnPartialResult(const std::string& result); void OnFinalResult(const std::string& result); void DecodeThreadFunc(); std::string SerializeResult(bool finish);