Merge pull request #1982 from SmileGoat/refactor_file_struct

[speechx] add partail result, test=doc
pull/1986/head
Hui Zhang 2 years ago committed by GitHub
commit 6a8d6c2c7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_BUILD/websocket
SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -45,7 +45,7 @@ export GLOG_logtostderr=1
# 3. gen cmvn
cmvn=$data/cmvn.ark
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
wfst=$data/wfst/

@ -34,9 +34,9 @@ add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/websocket
${CMAKE_CURRENT_SOURCE_DIR}/protocol
)
add_subdirectory(websocket)
add_subdirectory(protocol)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}

@ -47,6 +47,26 @@ void TLGDecoder::Reset() {
return;
}
std::string TLGDecoder::GetPartialResult() {
if (frame_decoded_size_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, false);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
}
std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call

@ -38,6 +38,7 @@ class TLGDecoder {
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
std::string GetPartialResult();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);

@ -44,6 +44,10 @@ std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
std::string Recognizer::GetPartialResult() {
return decoder_->GetPartialResult();
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;

@ -43,6 +43,7 @@ class Recognizer {
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
bool IsFinished();
void Reset();

@ -0,0 +1,3 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(websocket)

@ -67,6 +67,9 @@ void WebSocketClient::ReadLoopFunc() {
if (obj["type"] == "final_result") {
result_ = obj["result"].as_string().c_str();
}
if (obj["type"] == "partial_result") {
partial_result_ = obj["result"].as_string().c_str();
}
if (obj["type"] == "speech_end") {
done_ = true;
break;

@ -40,12 +40,14 @@ class WebSocketClient {
void SendEndSignal();
void SendDataEnd();
bool Done() const { return done_; }
std::string GetResult() { return result_; }
std::string GetResult() const { return result_; }
std::string GetPartialResult() const { return partial_result_;}
private:
void Connect();
std::string host_;
std::string result_;
std::string partial_result_;
int port_;
bool done_ = false;
asio::io_context ioc_;

@ -59,7 +59,6 @@ int main(int argc, char* argv[]) {
client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for(

@ -75,9 +75,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
CHECK(recognizer_ != nullptr);
recognizer_->Accept(pcm_data);
// TODO: return lpartial result
std::string partial_result = recognizer_->GetPartialResult();
json::value rv = {
{"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}};
{"status", "ok"}, {"type", "partial_result"}, {"result", partial_result}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
}

@ -44,7 +44,6 @@ class ConnectionHandler {
void OnFinish();
void OnSpeechData(const beast::flat_buffer& buffer);
void OnError(const std::string& message);
void OnPartialResult(const std::string& result);
void OnFinalResult(const std::string& result);
void DecodeThreadFunc();
std::string SerializeResult(bool finish);
Loading…
Cancel
Save