add websocket

pull/1720/head
Yang Zhou 3 years ago
parent b78bc6375b
commit 1133540682

@ -63,7 +63,8 @@ include(libsndfile)
# include(boost) # not work # include(boost) # not work
set(boost_SOURCE_DIR ${fc_patch}/boost-src) set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(BOOST_ROOT ${boost_SOURCE_DIR}) set(BOOST_ROOT ${boost_SOURCE_DIR})
# #find_package(boost REQUIRED PATHS ${BOOST_ROOT}) include_directories(${boost_SOURCE_DIR})
link_directories(${boost_SOURCE_DIR}/stage/lib)
# Eigen # Eigen
include(eigen) include(eigen)
@ -141,4 +142,4 @@ set(DEPS ${DEPS}
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx) add_subdirectory(speechx)
add_subdirectory(examples) add_subdirectory(examples)

@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat) add_subdirectory(feat)
add_subdirectory(nnet) add_subdirectory(nnet)
add_subdirectory(decoder) add_subdirectory(decoder)
add_subdirectory(websocket)

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -42,7 +42,7 @@ fi
if [ ! -d $ckpt_dir ]; then if [ ! -d $ckpt_dir ]; then
mkdir -p $ckpt_dir mkdir -p $ckpt_dir
wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi fi
lm=$data/zh_giga.no_cna_cmn.prune01244.klm lm=$data/zh_giga.no_cna_cmn.prune01244.klm
@ -79,7 +79,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
ctc-prefix-beam-search-decoder-ol \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result --result_wspecifier=ark,t:$data/split${nj}/JOB/result
@ -92,7 +92,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
ctc-prefix-beam-search-decoder-ol \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \ --lm_path=$lm \
@ -104,9 +104,9 @@ utils/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
graph_dir=./aishell_graph graph_dir=./aishell_graph
if [ ! -d $ ]; then if [ ! -d $graph_dir ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip -d aishell_graph.zip unzip aishell_graph.zip
fi fi
@ -115,7 +115,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
wfst-decoder-ol \ wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \ --word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \ --graph_path=$graph_dir/TLG.fst --max_active=7500 \

@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})

@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate, DEFINE_int32(downsampling_rate,
4, 4,
"two CNN(kernel=5) module downsampling rate."); "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names, DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", "save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names"); "model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -52,18 +50,14 @@ int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
CHECK(FLAGS_result_wspecifier != "");
CHECK(FLAGS_feature_rspecifier != "");
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_path = FLAGS_model_path;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file; std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path; std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph; LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params; LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path; LOG(INFO) << "lm path: " << lm_path;
@ -76,10 +70,9 @@ int main(int argc, char* argv[]) {
ppspeech::CTCBeamSearch decoder(opts); ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph; model_opts.model_path = model_path;
model_opts.params_path = model_params; model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names; model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names; model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
@ -125,7 +118,6 @@ int main(int argc, char* argv[]) {
if (feature_chunk_size < receptive_field_length) break; if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride; int32 start = chunk_idx * chunk_stride;
int32 end = start + chunk_size;
for (int row_id = 0; row_id < chunk_size; ++row_id) { for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start); kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);

@ -73,9 +73,9 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary; LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) { } catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what(); LOG(ERROR) << err.what();
} }
return 0; return 0;
} }

@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
@ -66,7 +65,8 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn( std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram))); new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000; int sample_rate = 16000;

@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})

@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_client.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(host, "127.0.0.1", "host of websocket server");
DEFINE_int32(port, 201314, "port of websocket server");
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
using kaldi::int16;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
const int sample_rate = 16000;
const float streaming_chunk = FLAGS_streaming_chunk;
const int chunk_sample_size = streaming_chunk * sample_rate;
for (; !wav_reader.Done(); wav_reader.Next()) {
client.SendStartSignal();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
CHECK_EQ(wave_data.SampFreq(), sample_rate);
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
const int tot_samples = waveform.Dim();
int sample_offset = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<int16> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = static_cast<int16>(waveform(sample_offset + i));
}
client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for(
std::chrono::milliseconds(static_cast<int>(1 * 1000)));
if (cur_chunk_size < chunk_sample_size) {
client.SendEndSignal();
}
}
while (!client.Done()) {
}
std::string result = client.GetResult();
LOG(INFO) << "utt: " << utt << " " << result;
client.Join();
return 0;
}
return 0;
}

@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_server.h"
#include "decoder/param.h"
DEFINE_int32(port, 201314, "websocket listening port");
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::WebSocketServer server(FLAGS_port, resource);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
}

@ -30,4 +30,10 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder ${CMAKE_CURRENT_SOURCE_DIR}/decoder
) )
add_subdirectory(decoder) add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/websocket
)
add_subdirectory(websocket)

@ -28,8 +28,10 @@
#include <sstream> #include <sstream>
#include <stack> #include <stack>
#include <string> #include <string>
#include <thread>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "base/basic_types.h" #include "base/basic_types.h"

@ -7,5 +7,6 @@ add_library(decoder STATIC
ctc_decoders/path_trie.cpp ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc ctc_tlg_decoder.cc
recognizer.cc
) )
target_link_libraries(decoder PUBLIC kenlm utils fst) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)

@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() {
void TLGDecoder::AdvanceDecode( void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) { while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get()); AdvanceDecoding(decodable.get());
} }
} }
@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() {
} }
return words; return words;
} }
} }

@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
DEFINE_bool(convert2PCM32, true, "audio convert to pcm32");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
opts.convert2PCM32 = FLAGS_convert2PCM32;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.frame_length_ms = 20;
frame_opts.frame_shift_ms = 10;
frame_opts.remove_dc_offset = false;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
frame_opts.dither = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.params_path = FLAGS_params_path;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = InitFeaturePipelineOptions();
resource.model_opts = InitModelOptions();
resource.tlg_opts = InitDecoderOptions();
return resource;
}
}

@ -0,0 +1,60 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
Recognizer::Recognizer(const RecognizerResource& resource) {
// resource_ = resource;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
BaseFloat ac_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
decoder_.reset(new TLGDecoder(resource.tlg_opts));
input_finished_ = false;
}
void Recognizer::Accept(const Vector<BaseFloat>& waves) {
feature_pipeline_->Accept(waves);
}
void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }
std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
bool Recognizer::IsFinished() { return input_finished_; }
void Recognizer::Reset() {
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
} // namespace ppspeech

@ -0,0 +1,59 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
namespace ppspeech {
struct RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts;
ModelOptions model_opts;
TLGDecoderOptions tlg_opts;
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale;
RecognizerResource()
: acoustic_scale(1.0),
feature_pipeline_opts(),
model_opts(),
tlg_opts() {}
};
class Recognizer {
public:
explicit Recognizer(const RecognizerResource& resouce);
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
void SetFinished();
bool IsFinished();
void Reset();
private:
// std::shared_ptr<RecognizerResource> resource_;
// RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<TLGDecoder> decoder_;
bool input_finished_;
};
} // namespace ppspeech

@ -6,6 +6,7 @@ add_library(frontend STATIC
linear_spectrogram.cc linear_spectrogram.cc
audio_cache.cc audio_cache.cc
feature_cache.cc feature_cache.cc
feature_pipeline.cc
) )
target_link_libraries(frontend PUBLIC kaldi-matrix) target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common)

@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
ready_feed_condition_.wait(lock); ready_feed_condition_.wait(lock);
} }
for (size_t idx = 0; idx < waves.Dim(); ++idx) { for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + offset_) % ring_buffer_.size(); int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx); ring_buffer_[buffer_idx] = waves(idx);
if (convert2PCM32_) if (convert2PCM32_)
ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));

@ -24,7 +24,7 @@ namespace ppspeech {
class AudioCache : public FrontendInterface { class AudioCache : public FrontendInterface {
public: public:
explicit AudioCache(int buffer_size = 1000 * kint16max, explicit AudioCache(int buffer_size = 1000 * kint16max,
bool convert2PCM32 = false); bool convert2PCM32 = true);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves); virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);

@ -23,10 +23,13 @@ using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using std::unique_ptr; using std::unique_ptr;
FeatureCache::FeatureCache(int max_size, FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size; max_size_ = opts.max_size;
frame_chunk_stride_ = opts.frame_chunk_stride;
frame_chunk_size_ = opts.frame_chunk_size;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
} }
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) { void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
@ -44,13 +47,14 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) { while (cache_.empty() && base_extractor_->IsFinished() == false) {
ready_read_condition_.wait(lock); // todo refactor: wait
BaseFloat elapsed = timer.Elapsed() * 1000; // ready_read_condition_.wait(lock);
// todo replace 1.0 with timeout_ int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
if (elapsed > 1.0) { // todo replace 1 with timeout_, 1 ms
if (elapsed > 1) {
return false; return false;
} }
usleep(1000); // sleep 1 ms usleep(100); // sleep 0.1 ms
} }
if (cache_.empty()) return false; if (cache_.empty()) return false;
feats->Resize(cache_.front().Dim()); feats->Resize(cache_.front().Dim());
@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// read all data from base_feature_extractor_ into cache_ // read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() { bool FeatureCache::Compute() {
// compute and feed // compute and feed
Vector<BaseFloat> feature_chunk; Vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature_chunk); bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
int32 joint_len = feature.Dim() + remained_feature_.Dim();
int32 num_chunk =
((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1;
std::unique_lock<std::mutex> lock(mutex_); Vector<BaseFloat> joint_feature(joint_len);
while (cache_.size() >= max_size_) { joint_feature.Range(0, remained_feature_.Dim())
ready_feed_condition_.wait(lock); .CopyFromVec(remained_feature_);
} joint_feature.Range(remained_feature_.Dim(), feature.Dim())
.CopyFromVec(feature);
// feed cache for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
if (feature_chunk.Dim() != 0) { int32 start = chunk_idx * frame_chunk_stride_ * dim_;
Vector<BaseFloat> feature_chunk(frame_chunk_size_ * dim_);
SubVector<BaseFloat> tmp(joint_feature.Data() + start,
frame_chunk_size_ * dim_);
feature_chunk.CopyFromVec(tmp);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
// feed cache
cache_.push(feature_chunk); cache_.push(feature_chunk);
ready_read_condition_.notify_one();
} }
ready_read_condition_.notify_one(); int32 remained_feature_len =
joint_len - num_chunk * frame_chunk_stride_ * dim_;
remained_feature_.Resize(remained_feature_len);
remained_feature_.CopyFromVec(joint_feature.Range(
frame_chunk_stride_ * num_chunk * dim_, remained_feature_len));
return result; return result;
} }
void Reset() {
// std::lock_guard<std::mutex> lock(mutex_);
return;
}
} // namespace ppspeech } // namespace ppspeech

@ -19,10 +19,18 @@
namespace ppspeech { namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 frame_chunk_size;
int32 frame_chunk_stride;
FeatureCacheOptions()
: max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {}
};
class FeatureCache : public FrontendInterface { class FeatureCache : public FrontendInterface {
public: public:
explicit FeatureCache( explicit FeatureCache(
int32 max_size = kint16max, FeatureCacheOptions opts,
std::unique_ptr<FrontendInterface> base_extractor = NULL); std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves // Feed feats or waves
@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface {
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats); virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feat dim // feat dim
virtual size_t Dim() const { return base_extractor_->Dim(); } virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { virtual void SetFinished() {
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished(); base_extractor_->SetFinished();
LOG(INFO) << "set finished";
// read the last chunk data // read the last chunk data
Compute(); Compute();
// ready_feed_condition_.notify_one();
} }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
@ -52,9 +63,13 @@ class FeatureCache : public FrontendInterface {
private: private:
bool Compute(); bool Compute();
int32 dim_;
size_t max_size_; size_t max_size_;
std::unique_ptr<FrontendInterface> base_extractor_; int32 frame_chunk_size_;
int32 frame_chunk_stride_;
kaldi::Vector<kaldi::BaseFloat> remained_feature_;
std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_; std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_; std::queue<kaldi::Vector<BaseFloat>> cache_;
std::condition_variable ready_feed_condition_; std::condition_variable ready_feed_condition_;

@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/feature_pipeline.h"
namespace ppspeech {
using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32));
unique_ptr<FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram)));
base_extractor_.reset(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
}
} // ppspeech

@ -0,0 +1,57 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool convert2PCM32;
LinearSpectrogramOptions linear_spectrogram_opts;
FeatureCacheOptions feature_cache_opts;
FeaturePipelineOptions()
: cmvn_file(""),
convert2PCM32(false),
linear_spectrogram_opts(),
feature_cache_opts() {}
};
class FeaturePipeline : public FrontendInterface {
public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
return base_extractor_->Read(feats);
}
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
private:
std::unique_ptr<FrontendInterface> base_extractor_;
};
}

@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
if (flag == false || input_feats.Dim() == 0) return false; if (flag == false || input_feats.Dim() == 0) return false;
int32 feat_len = input_feats.Dim(); int32 feat_len = input_feats.Dim();
int32 left_len = reminded_wav_.Dim(); int32 left_len = remained_wav_.Dim();
Vector<BaseFloat> waves(feat_len + left_len); Vector<BaseFloat> waves(feat_len + left_len);
waves.Range(0, left_len).CopyFromVec(reminded_wav_); waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, feat_len).CopyFromVec(input_feats); waves.Range(left_len, feat_len).CopyFromVec(input_feats);
Compute(waves, feats); Compute(waves, feats);
int32 frame_shift = opts_.frame_opts.WindowShift(); int32 frame_shift = opts_.frame_opts.WindowShift();
int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts); int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts);
int32 left_samples = waves.Dim() - frame_shift * num_frames; int32 left_samples = waves.Dim() - frame_shift * num_frames;
reminded_wav_.Resize(left_samples); remained_wav_.Resize(left_samples);
reminded_wav_.CopyFromVec( remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples)); waves.Range(frame_shift * num_frames, left_samples));
return true; return true;
} }

@ -25,12 +25,12 @@ struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts; kaldi::FrameExtractionOptions frame_opts;
kaldi::BaseFloat streaming_chunk; // second kaldi::BaseFloat streaming_chunk; // second
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {} LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("streaming-chunk", opts->Register("streaming-chunk",
&streaming_chunk, &streaming_chunk,
"streaming chunk size, default: 0.36 sec"); "streaming chunk size, default: 0.1 sec");
frame_opts.Register(opts); frame_opts.Register(opts);
} }
}; };
@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { virtual void Reset() {
base_extractor_->Reset(); base_extractor_->Reset();
reminded_wav_.Resize(0); remained_wav_.Resize(0);
} }
private: private:
@ -60,7 +60,7 @@ class LinearSpectrogram : public FrontendInterface {
kaldi::BaseFloat hanning_window_energy_; kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_; LinearSpectrogramOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> reminded_wav_; kaldi::Vector<kaldi::BaseFloat> remained_wav_;
int chunk_sample_size_; int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
}; };

@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() {
} }
int32 nnet_dim = 0; int32 nnet_dim = 0;
Vector<BaseFloat> inferences; Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences); nnet_cache_.CopyRowsFromVec(inferences);

Loading…
Cancel
Save