[speechx] add cls (#2839)

* fix nnet thread crash && rescore cost time

* add nnet thread main
pull/2854/head
YangZhou 3 years ago committed by MarsMeng
parent ee7c266f13
commit 16b9894edf

@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(speechx LANGUAGES CXX)
add_compile_options(-fPIC)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common)
@ -10,3 +12,4 @@ add_subdirectory(asr)
add_subdirectory(common)
add_subdirectory(kaldi)
add_subdirectory(codelab)
add_subdirectory(cls)

@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() {
times_.emplace_back(empty);
}
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::InitDecoder() {
Reset();
}
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
feat_nnet_cost += timer.Elapsed();
if (flag == false) {
VLOG(3) << "decoder advance decode exit." << frame_prob.size();
VLOG(2) << "decoder advance decode exit." << frame_prob.size();
break;
}
@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
AdvanceDecoding(likelihood);
search_cost += timer.Elapsed();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
}
VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
<< " sec.";

@ -8,14 +8,21 @@ target_link_libraries(nnet utils)
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# test bin
#if(USING_U2)
# set(bin_name u2_nnet_main)
# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
#endif()
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
set(bin_name u2_nnet_thread_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet frontend)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})

@ -33,7 +33,6 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_producer_->Acceptlikelihood(likelihood);
}
// return the size of frame have computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; }

@ -22,14 +22,43 @@ using kaldi::BaseFloat;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend)
: nnet_(nnet), frontend_(frontend) {}
: nnet_(nnet), frontend_(frontend) {
abort_ = false;
Reset();
thread_ = std::thread(RunNnetEvaluation, this);
}
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
condition_variable_.notify_one();
}
void NnetProducer::UnLock() {
std::unique_lock<std::mutex> lock(read_mutex_);
while (frontend_->IsFinished() == false && cache_.empty()) {
condition_read_ready_.wait(lock);
}
return;
}
void NnetProducer::RunNnetEvaluation(NnetProducer *me) {
me->RunNnetEvaluationInteral();
}
void NnetProducer::RunNnetEvaluationInteral() {
bool result = false;
do {
result = Compute();
} while (result);
LOG(INFO) << "NnetEvaluationInteral begin";
while (!abort_) {
std::unique_lock<std::mutex> lock(mutex_);
condition_variable_.wait(lock);
do {
result = Compute();
} while (result);
if (frontend_->IsFinished() == true) {
if (cache_.empty()) finished_ = true;
}
}
LOG(INFO) << "NnetEvaluationInteral exit";
}
void NnetProducer::Acceptlikelihood(
@ -39,12 +68,20 @@ void NnetProducer::Acceptlikelihood(
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
for (size_t col = 0; col < likelihood.NumCols(); ++col) {
prob[col] = likelihood(idx, col);
cache_.push_back(prob);
}
cache_.push_back(prob);
}
}
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
condition_variable_.notify_one();
return flag;
}
bool NnetProducer::ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob) {
Compute();
if (frontend_->IsFinished() && cache_.empty()) finished_ = true;
bool flag = cache_.pop(nnet_prob);
return flag;
}
@ -53,22 +90,23 @@ bool NnetProducer::Compute() {
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(3) << "no feat avalible";
VLOG(2) << "no feat avalible";
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
size_t nframes = out.logprobs.size() / vocab_dim;
VLOG(2) << "Forward out " << nframes << " decoder frames.";
VLOG(1) << "Forward out " << nframes << " decoder frames.";
for (size_t idx = 0; idx < nframes; ++idx) {
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob);
condition_read_ready_.notify_one();
}
return true;
}

@ -33,27 +33,38 @@ class NnetProducer {
// nnet
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
bool ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob);
static void RunNnetEvaluation(NnetProducer *me);
void RunNnetEvaluationInteral();
void UnLock();
void Wait() {
abort_ = true;
condition_variable_.notify_one();
if (thread_.joinable()) thread_.join();
}
bool Empty() const { return cache_.empty(); }
void SetFinished() {
void SetInputFinished() {
LOG(INFO) << "set finished";
// std::unique_lock<std::mutex> lock(mutex_);
frontend_->SetFinished();
// read the last chunk data
Compute();
// ready_feed_condition_.notify_one();
LOG(INFO) << "compute last feats done.";
condition_variable_.notify_one();
}
bool IsFinished() const { return frontend_->IsFinished(); }
// the compute thread exit
bool IsFinished() const { return finished_; }
~NnetProducer() {
if (thread_.joinable()) thread_.join();
}
void Reset() {
frontend_->Reset();
nnet_->Reset();
VLOG(3) << "feature cache reset: cache size: " << cache_.size();
cache_.clear();
finished_ = false;
}
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
@ -66,6 +77,13 @@ class NnetProducer {
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::mutex mutex_;
std::mutex read_mutex_;
std::condition_variable condition_variable_;
std::condition_variable condition_read_ready_;
std::thread thread_;
bool finished_;
bool abort_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
};

@ -0,0 +1,137 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "frontend/feature_pipeline.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int32 num_done = 0, num_err = 0;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
CHECK_GT(FLAGS_model_path.size(), 0);
LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::FeaturePipelineOptions feature_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
feature_opts.assembler_opts.fill_zero = false;
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
new ppspeech::FeaturePipeline(feature_opts));
std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
new ppspeech::NnetProducer(nnet, feature_pipeline));
kaldi::Timer timer;
float tot_wav_duration = 0;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
nnet_producer->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
nnet_producer->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
while(1) {
std::vector<kaldi::BaseFloat> logprobs;
bool isok = nnet_producer->Read(&logprobs);
if (nnet_producer->IsFinished()) break;
if (isok == false) continue;
prob_vec.push_back(logprobs);
}
{
// writer nnet output
kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].size();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
}
}
nnet_out_writer.Write(utt, nnet_out);
}
nnet_producer->Reset();
}
nnet_producer->Wait();
double elapsed = timer.Elapsed();
LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -39,12 +39,28 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_;
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
result_.clear();
}
U2Recognizer::~U2Recognizer() {
SetInputFinished();
WaitDecodeFinished();
}
Reset();
void U2Recognizer::WaitDecodeFinished() {
if (thread_.joinable()) thread_.join();
}
void U2Recognizer::Reset() {
void U2Recognizer::WaitFinished() {
if (thread_.joinable()) thread_.join();
nnet_producer_->Wait();
}
void U2Recognizer::InitDecoder() {
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
@ -52,6 +68,7 @@ void U2Recognizer::Reset() {
decodable_->Reset();
decoder_->Reset();
thread_ = std::thread(RunDecoderSearch, this);
}
void U2Recognizer::ResetContinuousDecoding() {
@ -63,6 +80,19 @@ void U2Recognizer::ResetContinuousDecoding() {
decoder_->Reset();
}
void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
me->RunDecoderSearchInternal();
}
void U2Recognizer::RunDecoderSearchInternal() {
LOG(INFO) << "DecoderSearchInteral begin";
while (!nnet_producer_->IsFinished()) {
nnet_producer_->UnLock();
decoder_->AdvanceDecode(decodable_);
}
Decode();
LOG(INFO) << "DecoderSearchInteral exit";
}
void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
kaldi::Timer timer;
@ -71,7 +101,6 @@ void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
<< " samples.";
}
void U2Recognizer::Decode() {
decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
@ -207,8 +236,8 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
void U2Recognizer::SetFinished() {
nnet_producer_->SetFinished();
void U2Recognizer::SetInputFinished() {
nnet_producer_->SetInputFinished();
input_finished_ = true;
}

@ -112,19 +112,21 @@ struct U2RecognizerResource {
class U2Recognizer {
public:
explicit U2Recognizer(const U2RecognizerResource& resouce);
void Reset();
~U2Recognizer();
void InitDecoder();
void ResetContinuousDecoding();
void Accept(const std::vector<kaldi::BaseFloat>& waves);
void Decode();
void Rescoring();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
void SetInputFinished();
bool IsFinished() { return input_finished_; }
void WaitDecodeFinished();
void WaitFinished();
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
@ -137,18 +139,17 @@ class U2Recognizer {
// feature_pipeline_->FrameShift();
}
const std::vector<DecodeResult>& Result() const { return result_; }
void AttentionRescoring();
private:
void AttentionRescoring();
static void RunDecoderSearch(U2Recognizer *me);
void RunDecoderSearchInternal();
void UpdateResult(bool finish = false);
private:
U2RecognizerResource opts_;
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std::shared_ptr<NnetProducer> nnet_producer_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
@ -167,6 +168,7 @@ class U2Recognizer {
const int time_stamp_gap_ = 100;
bool input_finished_;
std::thread thread_;
};
} // namespace ppspeech

@ -49,6 +49,7 @@ int main(int argc, char* argv[]) {
ppspeech::U2Recognizer recognizer(resource);
for (; !wav_reader.Done(); wav_reader.Next()) {
recognizer.InitDecoder();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
@ -79,7 +80,7 @@ int main(int argc, char* argv[]) {
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
recognizer.SetInputFinished();
}
recognizer.Decode();
if (recognizer.DecodedSomething()) {
@ -100,7 +101,6 @@ int main(int argc, char* argv[]) {
std::string result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.

@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
void decode_func(std::shared_ptr<ppspeech::U2Recognizer> recognizer) {
while (!recognizer->IsFinished()) {
recognizer->Decode();
usleep(100);
}
recognizer->Decode();
recognizer->Rescoring();
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -40,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
@ -59,7 +51,7 @@ int main(int argc, char* argv[]) {
new ppspeech::U2Recognizer(resource));
for (; !wav_reader.Done(); wav_reader.Next()) {
std::thread recognizer_thread(decode_func, recognizer_ptr);
recognizer_ptr->InitDecoder();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
@ -74,7 +66,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer timer;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
@ -85,21 +76,23 @@ int main(int argc, char* argv[]) {
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer_ptr->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer_ptr->SetFinished();
recognizer_ptr->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
recognizer_ptr->WaitDecodeFinished();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
tot_attention_rescore_time += timer.Elapsed();
recognizer_thread.join();
std::string result = recognizer_ptr->GetFinalResult();
recognizer_ptr->Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
@ -107,6 +100,7 @@ int main(int argc, char* argv[]) {
continue;
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
@ -115,9 +109,11 @@ int main(int argc, char* argv[]) {
++num_done;
}
recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}

@ -0,0 +1,46 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(cls)
set(ARCH "mserver_x86_64" CACHE STRING "Target Architecture:
android_arm, android_armv7, android_armv8, android_x86, android_x86_64,
mserver_x86_64, ubuntu_x86_64,
ios_armv7, ios_armv7s, ios_armv8, ios_x86_64, ios_x86,
windows_x86")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(FASTDEPLOY_DIR ${CMAKE_SOURCE_DIR}/fc_patch/fastdeploy)
if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz)
exec_program("mkdir -p ${FASTDEPLOY_DIR} &&
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.2.tgz -P ${FASTDEPLOY_DIR} &&
tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz -C ${FASTDEPLOY_DIR} &&
mv ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2 ${FASTDEPLOY_DIR}/linux-x64")
endif()
if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz)
exec_program("mkdir -p ${FASTDEPLOY_DIR} &&
wget https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.0-shared.tgz -P ${FASTDEPLOY_DIR} &&
tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz -C ${FASTDEPLOY_DIR} &&
mv ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared ${FASTDEPLOY_DIR}/android-armv7v8")
endif()
if (ARCH STREQUAL "mserver_x86_64")
set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/linux-x64)
add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
# add_definitions("-DUSE_ORT_BACKEND")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
elseif (ARCH STREQUAL "android_armv7")
set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/android-armv7v8)
add_definitions("-DUSE_PADDLE_LITE_BAKEND")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
endif()
# add_definitions("-DTEST_DEBUG")
# add_definitions("-DPRINT_TIME")
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_subdirectory(nnet)

@ -0,0 +1,12 @@
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# FastDeploy
include_directories(${FASTDEPLOY_INCS})
set(srcs cls_nnet.cc cls_interface.cc)
add_library(cls SHARED ${srcs})
target_link_libraries(cls -static-libstdc++;-Wl,-Bsymbolic ${FASTDEPLOY_LIBS} kaldi-native-fbank-core)
set(bin_name cls_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_link_libraries(${bin_name} -static-libstdc++;-Wl,-Bsymbolic cls)

@ -0,0 +1,59 @@
#include "cls/nnet/cls_nnet.h"
#include "cls/nnet/config.h"
namespace ppspeech{
void* cls_create_instance(const char* conf_path){
Config conf(conf_path);
//cls init
ppspeech::ClsNnetConf cls_nnet_conf;
cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true);
cls_nnet_conf.wav_normal_type_ = conf.Read("wav_normal_type", std::string("linear"));
cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0);
cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string(""));
cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string(""));
cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string(""));
cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12);
cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000);
cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32);
cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10);
cls_nnet_conf.num_bins = conf.Read("num_bins", 64);
cls_nnet_conf.low_freq = conf.Read("low_freq", 50);
cls_nnet_conf.high_freq = conf.Read("high_freq", 14000);
cls_nnet_conf.dither = conf.Read("dither", 0.0);
ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet();
int ret = cls_model->init(cls_nnet_conf);
return (void*)cls_model;
};
int cls_destroy_instance(void* instance){
ppspeech::ClsNnet* cls_model = (ppspeech::ClsNnet*)instance;
if(cls_model != NULL){
delete cls_model;
cls_model = NULL;
}
return 0;
};
int cls_feedforward(void* instance, const char* wav_path, int topk, char* result, int result_max_len){
ppspeech::ClsNnet* cls_model = (ppspeech::ClsNnet*)instance;
if(cls_model == NULL){
printf("instance is null\n");
return -1;
}
int ret = cls_model->forward(wav_path, topk, result, result_max_len);
return 0;
};
int cls_reset(void* instance){
ppspeech::ClsNnet* cls_model = (ppspeech::ClsNnet*)instance;
if(cls_model == NULL){
printf("instance is null\n");
return -1;
}
cls_model->reset();
return 0;
};
}

@ -0,0 +1,10 @@
#pragma once
namespace ppspeech{
void* cls_create_instance(const char* conf_path);
int cls_destroy_instance(void* instance);
int cls_feedforward(void* instance, const char* wav_path, int topk, char* result, int result_max_len);
int cls_reset(void* instance);
}

@ -0,0 +1,405 @@
#include "cls/nnet/cls_nnet.h"
#ifdef PRINT_TIME
#include <time.h>
#endif
namespace ppspeech {
ClsNnet::ClsNnet(){
// wav_reader_ = NULL;
runtime_ = NULL;
};
void ClsNnet::reset(){
// wav_reader_->Clear();
ss_.str("");
};
int ClsNnet::init(ClsNnetConf& conf){
conf_ = conf;
//init fbank opts
fbank_opts_.frame_opts.samp_freq = conf.samp_freq;
fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms;
fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms;
fbank_opts_.mel_opts.num_bins = conf.num_bins;
fbank_opts_.mel_opts.low_freq = conf.low_freq;
fbank_opts_.mel_opts.high_freq = conf.high_freq;
fbank_opts_.frame_opts.dither = conf.dither;
fbank_opts_.use_log_fbank = false;
//init dict
if (conf.dict_file_path_ != ""){
init_dict(conf.dict_file_path_);
}
// init model
fastdeploy::RuntimeOption runtime_option;
#ifdef USE_ORT_BACKEND
runtime_option.SetModelPath(conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX); // onnx
runtime_option.UseOrtBackend(); // onnx
#endif
#ifdef USE_PADDLE_LITE_BACKEND
runtime_option.SetModelPath(conf.model_file_path_, conf.param_file_path_, fastdeploy::ModelFormat::PADDLE);
runtime_option.UseLiteBackend();
#endif
#ifdef USE_PADDLE_INFERENCE_BACKEND
runtime_option.SetModelPath(conf.model_file_path_, conf.param_file_path_, fastdeploy::ModelFormat::PADDLE);
runtime_option.UsePaddleInferBackend();
#endif
runtime_option.SetCpuThreadNum(conf.num_cpu_thread_);
runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass");
runtime_ = std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
if (!runtime_->Init(runtime_option)) {
std::cerr << "--- Init FastDeploy Runitme Failed! "
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
return -1;
} else {
std::cout << "--- Init FastDeploy Runitme Done! "
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
}
reset();
return 0;
};
int ClsNnet::init_dict(std::string& dict_path){
std::ifstream fp(dict_path);
std::string line = "";
while(getline(fp, line)){
dict_.push_back(line);
}
return 0;
};
int ClsNnet::forward(const char* wav_path, int topk, char* result, int result_max_len){
#ifdef PRINT_TIME
double duration = 0;
std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now();
#endif
//read wav
WavReader wav_reader(wav_path);
int wavform_len = wav_reader.num_samples();
std::vector<float> wavform(wavform_len);
memcpy(wavform.data(), wav_reader.data(), wavform_len * sizeof(float));
waveform_float_normal(wavform);
waveform_normal(wavform, conf_.wav_normal_, conf_.wav_normal_type_, conf_.wav_norm_mul_factor_);
#ifdef TEST_DEBUG
{
std::ofstream fp("cls.wavform", std::ios::out);
for (int i = 0; i < wavform.size(); ++i) {
fp << std::setprecision(18) << wavform[i] << " ";
}
fp << "\n";
}
#endif
#ifdef PRINT_TIME
std::chrono::high_resolution_clock::time_point end_time = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration<double, std::micro>(end_time - start_time).count();
printf("wav read consume: %fs\n", duration / 1000000);
#endif
#ifdef PRINT_TIME
start_time = std::chrono::high_resolution_clock::now();
#endif
std::vector<float> feats;
std::unique_ptr<ppspeech::FrontendInterface> data_source(new ppspeech::DataCache());
ppspeech::Fbank fbank(fbank_opts_, std::move(data_source));
fbank.Accept(wavform);
fbank.SetFinished();
fbank.Read(&feats);
int feat_dim = fbank_opts_.mel_opts.num_bins;
int num_frames = feats.size() / feat_dim;
for (int i = 0; i < num_frames; ++i){
for(int j = 0; j < feat_dim; ++j){
feats[i * feat_dim + j] = power_to_db(feats[i * feat_dim + j]);
}
}
#ifdef TEST_DEBUG
{
std::ofstream fp("cls.feat", std::ios::out);
for (int i = 0; i < num_frames; ++i) {
for (int j = 0; j < feat_dim; ++j){
fp << std::setprecision(18) << feats[i * feat_dim + j] << " ";
}
fp << "\n";
}
}
#endif
#ifdef PRINT_TIME
end_time = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration<double, std::micro>(end_time - start_time).count();
printf("extract fbank consume: %fs\n", duration / 1000000);
#endif
// model_forward_stream(feats);
//infer
std::vector<float> model_out;
#ifdef PRINT_TIME
start_time = std::chrono::high_resolution_clock::now();
#endif
model_forward(feats.data(), num_frames, feat_dim, model_out);
#ifdef PRINT_TIME
end_time = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration<double, std::micro>(end_time - start_time).count();
printf("fast deploy infer consume: %fs\n", duration / 1000000);
#endif
#ifdef TEST_DEBUG
{
std::ofstream fp("cls.logits", std::ios::out);
for (int i = 0; i < model_out.size(); ++i) {
fp << std::setprecision(18) << model_out[i] << "\n";
}
}
#endif
// construct result str
ss_ << "{";
get_topk(topk, model_out);
ss_ << "}";
if (result_max_len <= ss_.str().size()){
printf("result_max_len is short than result len\n");
}
snprintf(result, result_max_len, "%s", ss_.str().c_str());
return 0;
};
int ClsNnet::model_forward(const float* features, const int num_frames, const int feat_dim, std::vector<float>& model_out){
// init input tensor shape
fastdeploy::TensorInfo info = runtime_->GetInputInfo(0);
info.shape = {1, num_frames, feat_dim};
std::vector<fastdeploy::FDTensor> input_tensors(1);
std::vector<fastdeploy::FDTensor> output_tensors(1);
input_tensors[0].SetExternalData({1, num_frames, feat_dim}, fastdeploy::FDDataType::FP32, (void*)features);
//get input name
input_tensors[0].name = info.name;
runtime_->Infer(input_tensors, &output_tensors);
// output_tensors[0].PrintInfo();
std::vector<int64_t> output_shape = output_tensors[0].Shape();
model_out.resize(output_shape[0] * output_shape[1]);
memcpy((void*)model_out.data(), output_tensors[0].Data(), output_shape[0] * output_shape[1] * sizeof(float));
return 0;
};
int ClsNnet::model_forward_stream(std::vector<float>& feats){
// init input tensor shape
std::vector<fastdeploy::TensorInfo> input_infos = runtime_->GetInputInfos();
std::vector<fastdeploy::TensorInfo> output_infos = runtime_->GetOutputInfos();
std::vector<fastdeploy::FDTensor> input_tensors(14);
std::vector<fastdeploy::FDTensor> output_tensors(13);
{
std::vector<float> feats_tmp(feats.begin(), feats.begin() + 400 * 64);
std::vector<int> flag({0});
std::vector<float> block1_conv1_cache(1 * 64 * 400 * 64, 0);
std::vector<float> block1_conv2_cache(1 * 64 * 400 * 64, 0);
std::vector<float> block2_conv1_cache(1 * 128 * 200 * 32, 0);
std::vector<float> block2_conv2_cache(1 * 128 * 200 * 32, 0);
std::vector<float> block3_conv1_cache(1 * 256 * 100 * 16, 0);
std::vector<float> block3_conv2_cache(1 * 256 * 100 * 16, 0);
std::vector<float> block4_conv1_cache(1 * 512 * 50 * 8, 0);
std::vector<float> block4_conv2_cache(1 * 512 * 50 * 8, 0);
std::vector<float> block5_conv1_cache(1 * 1024 * 25 * 4, 0);
std::vector<float> block5_conv2_cache(1 * 1024 * 25 * 4, 0);
std::vector<float> block6_conv1_cache(1 * 2048 * 12 * 2, 0);
std::vector<float> block6_conv2_cache(1 * 2048 * 12 * 2, 0);
input_tensors[0].name = input_infos[0].name;
input_tensors[0].SetExternalData({1, 400, 64}, fastdeploy::FDDataType::FP32, (void*)feats_tmp.data());
input_tensors[1].name = input_infos[1].name;
input_tensors[1].SetExternalData({1}, fastdeploy::FDDataType::INT32, (void*)flag.data());
input_tensors[2].name = input_infos[2].name;
input_tensors[2].SetExternalData({1, 64, 400, 64}, fastdeploy::FDDataType::FP32, (void*)block1_conv1_cache.data());
input_tensors[3].name = input_infos[3].name;
input_tensors[3].SetExternalData({1, 64, 400, 64}, fastdeploy::FDDataType::FP32, (void*)block1_conv2_cache.data());
input_tensors[4].name = input_infos[4].name;
input_tensors[4].SetExternalData({1, 128, 200, 32}, fastdeploy::FDDataType::FP32, (void*)block2_conv1_cache.data());
input_tensors[5].name = input_infos[5].name;
input_tensors[5].SetExternalData({1, 128, 200, 32}, fastdeploy::FDDataType::FP32, (void*)block2_conv2_cache.data());
input_tensors[6].name = input_infos[6].name;
input_tensors[6].SetExternalData({1, 256, 100, 16}, fastdeploy::FDDataType::FP32, (void*)block3_conv1_cache.data());
input_tensors[7].name = input_infos[7].name;
input_tensors[7].SetExternalData({1, 256, 100, 16}, fastdeploy::FDDataType::FP32, (void*)block3_conv2_cache.data());
input_tensors[8].name = input_infos[8].name;
input_tensors[8].SetExternalData({1, 512, 50, 8}, fastdeploy::FDDataType::FP32, (void*)block4_conv1_cache.data());
input_tensors[9].name = input_infos[9].name;
input_tensors[9].SetExternalData({1, 512, 50, 8}, fastdeploy::FDDataType::FP32, (void*)block4_conv2_cache.data());
input_tensors[10].name = input_infos[10].name;
input_tensors[10].SetExternalData({1, 1024, 25, 4}, fastdeploy::FDDataType::FP32, (void*)block5_conv1_cache.data());
input_tensors[11].name = input_infos[11].name;
input_tensors[11].SetExternalData({1, 1024, 25, 4}, fastdeploy::FDDataType::FP32, (void*)block5_conv2_cache.data());
input_tensors[12].name = input_infos[12].name;
input_tensors[12].SetExternalData({1, 2048, 12, 2}, fastdeploy::FDDataType::FP32, (void*)block6_conv1_cache.data());
input_tensors[13].name = input_infos[13].name;
input_tensors[13].SetExternalData({1, 2048, 12, 2}, fastdeploy::FDDataType::FP32, (void*)block6_conv2_cache.data());
std::vector<float> model_out_tmp;
#ifdef PRINT_TIME
double duration = 0;
std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now();
#endif
runtime_->Infer(input_tensors, &output_tensors);
#ifdef PRINT_TIME
std::chrono::high_resolution_clock::time_point end_time = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration<double, std::micro>(end_time - start_time).count();
printf("infer %d:%d consume: %fs\n", 0, 400, duration / 1000000);
#endif
// output_tensors[0].PrintInfo();
std::vector<int64_t> output_shape = output_tensors[0].Shape();
model_out_tmp.resize(output_shape[0] * output_shape[1]);
memcpy((void*)model_out_tmp.data(), output_tensors[0].Data(), output_shape[0] * output_shape[1] * sizeof(float));
#ifdef TEST_DEBUG
{
std::stringstream ss;
ss << "cls.logits." << 0 << "-" << 400;
std::ofstream fp(ss.str(), std::ios::out);
for (int i = 0; i < model_out_tmp.size(); ++i) {
fp << std::setprecision(18) << model_out_tmp[i] << "\n";
}
}
#endif
}
{
std::vector<float> feats_tmp(feats.begin() + 32 * 64, feats.begin() + 432 * 64);
std::vector<int> flag({1});
input_tensors[0].SetExternalData({1, 400, 64}, fastdeploy::FDDataType::FP32, (void*)feats_tmp.data());
input_tensors[1].SetExternalData({1}, fastdeploy::FDDataType::INT32, (void*)flag.data());
input_tensors[2] = output_tensors[1];
input_tensors[2].name = input_infos[2].name;
input_tensors[3] = output_tensors[2];
input_tensors[3].name = input_infos[3].name;
input_tensors[4] = output_tensors[3];
input_tensors[4].name = input_infos[4].name;
input_tensors[5] = output_tensors[4];
input_tensors[5].name = input_infos[5].name;
input_tensors[6] = output_tensors[5];
input_tensors[6].name = input_infos[6].name;
input_tensors[7] = output_tensors[6];
input_tensors[7].name = input_infos[7].name;
input_tensors[8] = output_tensors[7];
input_tensors[8].name = input_infos[8].name;
input_tensors[9] = output_tensors[8];
input_tensors[9].name = input_infos[9].name;
input_tensors[10] = output_tensors[9];
input_tensors[10].name = input_infos[10].name;
input_tensors[11] = output_tensors[10];
input_tensors[11].name = input_infos[11].name;
input_tensors[12] = output_tensors[11];
input_tensors[12].name = input_infos[12].name;
input_tensors[13] = output_tensors[12];
input_tensors[13].name = input_infos[13].name;
std::vector<float> model_out_tmp;
#ifdef PRINT_TIME
double duration = 0;
std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now();
#endif
runtime_->Infer(input_tensors, &output_tensors);
#ifdef PRINT_TIME
std::chrono::high_resolution_clock::time_point end_time = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration<double, std::micro>(end_time - start_time).count();
printf("infer %d:%d consume: %fs\n", 32, 432, duration / 1000000);
#endif
// output_tensors[0].PrintInfo();
std::vector<int64_t> output_shape = output_tensors[0].Shape();
model_out_tmp.resize(output_shape[0] * output_shape[1]);
memcpy((void*)model_out_tmp.data(), output_tensors[0].Data(), output_shape[0] * output_shape[1] * sizeof(float));
#ifdef TEST_DEBUG
{
std::stringstream ss;
ss << "cls.logits." << 32 << "-" << 432;
std::ofstream fp(ss.str(), std::ios::out);
for (int i = 0; i < model_out_tmp.size(); ++i) {
fp << std::setprecision(18) << model_out_tmp[i] << "\n";
}
}
#endif
}
exit(1);
return 0;
};
int ClsNnet::get_topk(int k, std::vector<float>& model_out){
std::vector<std::pair<float, int>> sort_vec;
for (int i = 0; i < model_out.size(); ++i){
sort_vec.push_back({-1 * model_out[i], i});
}
std::sort(sort_vec.begin(), sort_vec.end());
for (int i = 0; i < k; ++i){
if (i != 0){
ss_ << ",";
}
ss_ << "\"" << dict_[sort_vec[i].second] << "\":\"" << -1 * sort_vec[i].first << "\"";
}
return 0;
};
int ClsNnet::waveform_float_normal(std::vector<float>& waveform){
int tot_samples = waveform.size();
for (int i = 0; i < tot_samples; i++){
waveform[i] = waveform[i] / 32768.0;
}
return 0;
}
int ClsNnet::waveform_normal(std::vector<float>& waveform, bool wav_normal, std::string& wav_normal_type, float wav_norm_mul_factor){
if (wav_normal == false){
return 0;
}
if (wav_normal_type == "linear"){
float amax = INT32_MIN;
for (int i = 0; i < waveform.size(); ++i){
float tmp = std::abs(waveform[i]);
amax = std::max(amax, tmp);
}
float factor = 1.0 / (amax + 1e-8);
for (int i = 0; i < waveform.size(); ++i){
waveform[i] = waveform[i] * factor * wav_norm_mul_factor;
}
} else if (wav_normal_type == "gaussian") {
double sum = std::accumulate(waveform.begin(), waveform.end(), 0.0);
double mean = sum / waveform.size(); //均值
double accum = 0.0;
std::for_each (waveform.begin(), waveform.end(), [&](const double d) {
accum += (d-mean)*(d-mean);
});
double stdev = sqrt(accum/(waveform.size()-1)); //方差
stdev = std::max(stdev, 1e-8);
for (int i = 0; i < waveform.size(); ++i){
waveform[i] = wav_norm_mul_factor * (waveform[i] - mean) / stdev;
}
} else {
printf("don't support\n");
return -1;
}
return 0;
}
float ClsNnet::power_to_db(float in, float ref_value, float amin, float top_db){
if(amin <= 0){
printf("amin must be strictly positive\n");
return -1;
};
if(ref_value <= 0){
printf("ref_value must be strictly positive\n");
return -1;
}
float out = 10.0 * log10(std::max(amin, in));
out -= 10.0 * log10(std::max(ref_value, amin));
return out;
}
}

@ -0,0 +1,73 @@
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.h
#pragma once
#include <iomanip>
#include <algorithm>
#include <numeric>
#include "fastdeploy/runtime.h"
#include "cls/nnet/wav.h"
#include "common/frontend/frontend_itf.h"
#include "common/frontend/data_cache.h"
#include "common/frontend/feature-fbank.h"
#include "frontend/fbank.h"
namespace ppspeech {
struct ClsNnetConf {
//wav
bool wav_normal_;
std::string wav_normal_type_;
float wav_norm_mul_factor_;
//model
std::string model_file_path_;
std::string param_file_path_;
std::string dict_file_path_;
int num_cpu_thread_;
//fbank
float samp_freq;
float frame_length_ms;
float frame_shift_ms;
int num_bins;
float low_freq;
float high_freq;
float dither;
};
class ClsNnet {
public:
ClsNnet();
int init(ClsNnetConf& conf);
int forward(const char* wav_path, int topk, char* result, int result_max_len);
void reset();
private:
int init_dict(std::string& dict_path);
int model_forward(const float* features, const int num_frames, const int feat_dim, std::vector<float>& model_out);
int model_forward_stream(std::vector<float>& feats);
int get_topk(int k, std::vector<float>& model_out);
int waveform_float_normal(std::vector<float>& waveform);
int waveform_normal(std::vector<float>& waveform, bool wav_normal, std::string& wav_normal_type, float wav_norm_mul_factor);
float power_to_db(float in, float ref_value = 1.0, float amin = 1e-10, float top_db = 80.0);
ClsNnetConf conf_;
knf::FbankOptions fbank_opts_;
std::unique_ptr<fastdeploy::Runtime> runtime_;
std::vector<std::string> dict_;
std::stringstream ss_;
};
} // namespace ppspeech

@ -0,0 +1,45 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <fstream>
#include "nnet/cls_interface.h"
int main(int argc, char* argv[]) {
if (argc != 4){
printf("usage : cls_nnet_main conf_path wav_path topk\n");
return 0;
}
const char* conf_path = argv[1];
const char* wav_path = argv[2];
int topk = std::atoi(argv[3]);
void* instance = ppspeech::cls_create_instance(conf_path);
int ret = 0;
//read wav
std::ifstream ifs(wav_path);
std::string line = "";
while(getline(ifs, line)){
//read wav
char result[1024] = {0};
ret = ppspeech::cls_feedforward(instance, line.c_str(), topk, result, 1024);
printf("%s %s\n", line.c_str(), result);
ret = ppspeech::cls_reset(instance);
}
ret = ppspeech::cls_destroy_instance(instance);
return 0;
}

@ -0,0 +1,352 @@
#include <iostream>
#include <fstream>
#include <map>
#include <string>
#include <sstream>
using namespace std;
#pragma once
#pragma region ParseIniFile
/*
* \brief Generic configuration Class
*
*/
class Config {
// Data
protected:
std::string m_Delimiter; //!< separator between key and value
std::string m_Comment; //!< separator between value and comments
std::map<std::string, std::string> m_Contents; //!< extracted keys and values
typedef std::map<std::string, std::string>::iterator mapi;
typedef std::map<std::string, std::string>::const_iterator mapci;
// Methods
public:
Config(std::string filename, std::string delimiter = "=", std::string comment = "#");
Config();
template<class T> T Read(const std::string& in_key) const; //!<Search for key and read value or optional default value, call as read<T>
template<class T> T Read(const std::string& in_key, const T& in_value) const;
template<class T> bool ReadInto(T& out_var, const std::string& in_key) const;
template<class T>
bool ReadInto(T& out_var, const std::string& in_key, const T& in_value) const;
bool FileExist(std::string filename);
void ReadFile(std::string filename, std::string delimiter = "=", std::string comment = "#");
// Check whether key exists in configuration
bool KeyExists(const std::string& in_key) const;
// Modify keys and values
template<class T> void Add(const std::string& in_key, const T& in_value);
void Remove(const std::string& in_key);
// Check or change configuration syntax
std::string GetDelimiter() const { return m_Delimiter; }
std::string GetComment() const { return m_Comment; }
std::string SetDelimiter(const std::string& in_s)
{
std::string old = m_Delimiter; m_Delimiter = in_s; return old;
}
std::string SetComment(const std::string& in_s)
{
std::string old = m_Comment; m_Comment = in_s; return old;
}
// Write or read configuration
friend std::ostream& operator<<(std::ostream& os, const Config& cf);
friend std::istream& operator >> (std::istream& is, Config& cf);
protected:
template<class T> static std::string T_as_string(const T& t);
template<class T> static T string_as_T(const std::string& s);
static void Trim(std::string& inout_s);
// Exception types
public:
struct File_not_found {
std::string filename;
File_not_found(const std::string& filename_ = std::string())
: filename(filename_) {}
};
struct Key_not_found { // thrown only by T read(key) variant of read()
std::string key;
Key_not_found(const std::string& key_ = std::string())
: key(key_) {}
};
};
/* static */
template<class T>
std::string Config::T_as_string(const T& t)
{
// Convert from a T to a string
// Type T must support << operator
std::ostringstream ost;
ost << t;
return ost.str();
}
/* static */
template<class T>
T Config::string_as_T(const std::string& s)
{
// Convert from a string to a T
// Type T must support >> operator
T t;
std::istringstream ist(s);
ist >> t;
return t;
}
/* static */
template<>
inline std::string Config::string_as_T<std::string>(const std::string& s)
{
// Convert from a string to a string
// In other words, do nothing
return s;
}
/* static */
template<>
inline bool Config::string_as_T<bool>(const std::string& s)
{
// Convert from a string to a bool
// Interpret "false", "F", "no", "n", "0" as false
// Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true
bool b = true;
std::string sup = s;
for (std::string::iterator p = sup.begin(); p != sup.end(); ++p)
*p = toupper(*p); // make string all caps
if (sup == std::string("FALSE") || sup == std::string("F") ||
sup == std::string("NO") || sup == std::string("N") ||
sup == std::string("0") || sup == std::string("NONE"))
b = false;
return b;
}
template<class T>
T Config::Read(const std::string& key) const
{
// Read the value corresponding to key
mapci p = m_Contents.find(key);
if (p == m_Contents.end()) throw Key_not_found(key);
return string_as_T<T>(p->second);
}
template<class T>
T Config::Read(const std::string& key, const T& value) const
{
// Return the value corresponding to key or given default value
// if key is not found
mapci p = m_Contents.find(key);
if (p == m_Contents.end()) {
printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str());
return value;
} else {
printf("%s = %s\n", key.c_str(), T_as_string(p->second).c_str());
return string_as_T<T>(p->second);
}
}
template<class T>
bool Config::ReadInto(T& var, const std::string& key) const
{
// Get the value corresponding to key and store in var
// Return true if key is found
// Otherwise leave var untouched
mapci p = m_Contents.find(key);
bool found = (p != m_Contents.end());
if (found) var = string_as_T<T>(p->second);
return found;
}
template<class T>
bool Config::ReadInto(T& var, const std::string& key, const T& value) const
{
// Get the value corresponding to key and store in var
// Return true if key is found
// Otherwise set var to given default
mapci p = m_Contents.find(key);
bool found = (p != m_Contents.end());
if (found)
var = string_as_T<T>(p->second);
else
var = value;
return found;
}
template<class T>
void Config::Add(const std::string& in_key, const T& value)
{
// Add a key with given value
std::string v = T_as_string(value);
std::string key = in_key;
Trim(key);
Trim(v);
m_Contents[key] = v;
return;
}
Config::Config(string filename, string delimiter,
string comment)
: m_Delimiter(delimiter), m_Comment(comment)
{
// Construct a Config, getting keys and values from given file
std::ifstream in(filename.c_str());
if (!in) throw File_not_found(filename);
in >> (*this);
}
Config::Config()
: m_Delimiter(string(1, '=')), m_Comment(string(1, '#'))
{
// Construct a Config without a file; empty
}
bool Config::KeyExists(const string& key) const
{
// Indicate whether key is found
mapci p = m_Contents.find(key);
return (p != m_Contents.end());
}
/* static */
void Config::Trim(string& inout_s)
{
// Remove leading and trailing whitespace
static const char whitespace[] = " \n\t\v\r\f";
inout_s.erase(0, inout_s.find_first_not_of(whitespace));
inout_s.erase(inout_s.find_last_not_of(whitespace) + 1U);
}
std::ostream& operator<<(std::ostream& os, const Config& cf)
{
// Save a Config to os
for (Config::mapci p = cf.m_Contents.begin();
p != cf.m_Contents.end();
++p)
{
os << p->first << " " << cf.m_Delimiter << " ";
os << p->second << std::endl;
}
return os;
}
void Config::Remove(const string& key)
{
// Remove key and its value
m_Contents.erase(m_Contents.find(key));
return;
}
std::istream& operator >> (std::istream& is, Config& cf)
{
// Load a Config from is
// Read in keys and values, keeping internal whitespace
typedef string::size_type pos;
const string& delim = cf.m_Delimiter; // separator
const string& comm = cf.m_Comment; // comment
const pos skip = delim.length(); // length of separator
string nextline = ""; // might need to read ahead to see where value ends
while (is || nextline.length() > 0)
{
// Read an entire line at a time
string line;
if (nextline.length() > 0)
{
line = nextline; // we read ahead; use it now
nextline = "";
}
else
{
std::getline(is, line);
}
// Ignore comments
line = line.substr(0, line.find(comm));
// Parse the line if it contains a delimiter
pos delimPos = line.find(delim);
if (delimPos < string::npos)
{
// Extract the key
string key = line.substr(0, delimPos);
line.replace(0, delimPos + skip, "");
// See if value continues on the next line
// Stop at blank line, next line with a key, end of stream,
// or end of file sentry
bool terminate = false;
while (!terminate && is)
{
std::getline(is, nextline);
terminate = true;
string nlcopy = nextline;
Config::Trim(nlcopy);
if (nlcopy == "") continue;
nextline = nextline.substr(0, nextline.find(comm));
if (nextline.find(delim) != string::npos)
continue;
nlcopy = nextline;
Config::Trim(nlcopy);
if (nlcopy != "") line += "\n";
line += nextline;
terminate = false;
}
// Store key and value
Config::Trim(key);
Config::Trim(line);
cf.m_Contents[key] = line; // overwrites if key is repeated
}
}
return is;
}
bool Config::FileExist(std::string filename)
{
bool exist = false;
std::ifstream in(filename.c_str());
if (in)
exist = true;
return exist;
}
void Config::ReadFile(string filename, string delimiter,
string comment)
{
m_Delimiter = delimiter;
m_Comment = comment;
std::ifstream in(filename.c_str());
if (!in) throw File_not_found(filename);
in >> (*this);
}
#pragma endregion ParseIniFIle

@ -0,0 +1,241 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#pragma once
namespace ppspeech {
struct WavHeader {
char riff[4] = {'R', 'I', 'F', 'F'};
unsigned int size = 0;
char wav[4] = {'W', 'A', 'V', 'E'};
char fmt[4] = {'f', 'm', 't', ' '};
unsigned int fmt_size = 16;
uint16_t format = 1;
uint16_t channels = 0;
unsigned int sample_rate = 0;
unsigned int bytes_per_second = 0;
uint16_t block_size = 0;
uint16_t bit = 0;
char data[4] = {'d', 'a', 't', 'a'};
unsigned int data_size = 0;
WavHeader() {}
WavHeader(int num_samples, int num_channel, int sample_rate,
int bits_per_sample) {
data_size = num_samples * num_channel * (bits_per_sample / 8);
size = sizeof(WavHeader) - 8 + data_size;
channels = num_channel;
this->sample_rate = sample_rate;
bytes_per_second = sample_rate * num_channel * (bits_per_sample / 8);
block_size = num_channel * (bits_per_sample / 8);
bit = bits_per_sample;
}
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb");
if (NULL == fp) {
//LOG(WARNING) << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
fprintf(stderr,
"WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "RIFF" "WAVE" "fmt " "data"
// Skip any sub-chunks between "fmt" and "data". Usually there will
// be a single "fact" sub chunk, but on Windows there can also be a
// "list" sub chunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next sub chunk
fread(header.data, 8, sizeof(char), fp);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data];
num_samples_ = num_data / num_channel_;
for (int i = 0; i < num_data; ++i) {
switch (bits_per_sample_) {
case 8: {
char sample;
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample);
break;
}
case 16: {
int16_t sample;
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample);
break;
}
case 32: {
int sample;
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample);
break;
}
default:
fprintf(stderr, "unsupported quantization bits");
exit(1);
}
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "w");
WavHeader header(num_samples_, num_channel_, sample_rate_,
bits_per_sample_);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
class StreamWavWriter {
public:
StreamWavWriter(int num_channel, int sample_rate, int bits_per_sample)
: num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample),
total_num_samples_(0) {}
StreamWavWriter(const std::string& filename, int num_channel,
int sample_rate, int bits_per_sample)
: StreamWavWriter(num_channel, sample_rate, bits_per_sample) {
Open(filename);
}
void Open(const std::string& filename) {
fp_ = fopen(filename.c_str(), "wb");
fseek(fp_, sizeof(WavHeader), SEEK_SET);
}
void Write(const int16_t* sample_data, size_t num_samples) {
fwrite(sample_data, sizeof(int16_t), num_samples, fp_);
total_num_samples_ += num_samples;
}
void Close() {
WavHeader header(total_num_samples_, num_channel_, sample_rate_,
bits_per_sample_);
fseek(fp_, 0L, SEEK_SET);
fwrite(&header, 1, sizeof(header), fp_);
fclose(fp_);
}
private:
FILE* fp_;
int num_channel_;
int sample_rate_;
int bits_per_sample_;
size_t total_num_samples_;
};
} // namespace wenet
#endif // FRONTEND_WAV_H_

@ -73,8 +73,7 @@ int main(int argc, char* argv[]) {
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
// the feature cache output feature chunk by chunk.
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim();

@ -20,10 +20,9 @@ using kaldi::BaseFloat;
using std::unique_ptr;
using std::vector;
FeatureCache::FeatureCache(FeatureCacheOptions opts,
FeatureCache::FeatureCache(size_t max_size,
unique_ptr<FrontendInterface> base_extractor) {
max_size_ = opts.max_size;
timeout_ = opts.timeout; // ms
max_size_ = max_size;
base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
}
@ -31,34 +30,25 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts,
void FeatureCache::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
// read inputs
base_extractor_->Accept(inputs);
// feed current data
bool result = false;
do {
result = Compute();
} while (result);
}
// pop feature chunk
bool FeatureCache::Read(std::vector<kaldi::BaseFloat>* feats) {
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) {
// todo refactor: wait
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000); // ms
if (elapsed > timeout_) {
return false;
}
usleep(100); // sleep 0.1 ms
// feed current data
if (cache_.empty()) {
bool result = false;
do {
result = Compute();
} while (result);
}
if (cache_.empty()) return false;
// read from cache
*feats = cache_.front();
cache_.pop();
ready_feed_condition_.notify_one();
VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
return true;
}
@ -73,23 +63,15 @@ bool FeatureCache::Compute() {
kaldi::Timer timer;
int32 num_chunk = feature.size() / dim_;
nframe_ += num_chunk;
VLOG(3) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_;
vector<BaseFloat> feature_chunk(feature.data() + start,
feature.data() + start + dim_);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
// cache full, wait
ready_feed_condition_.wait(lock);
}
// feed cache
cache_.push(feature_chunk);
ready_read_condition_.notify_one();
++nframe_;
}
VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. "
@ -97,4 +79,4 @@ bool FeatureCache::Compute() {
return true;
}
} // namespace ppspeech
} // namespace ppspeech

@ -19,16 +19,10 @@
namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 timeout; // ms
FeatureCacheOptions() : max_size(kint16max), timeout(1) {}
};
class FeatureCache : public FrontendInterface {
public:
explicit FeatureCache(
FeatureCacheOptions opts,
size_t max_size = kint16max,
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
@ -41,13 +35,11 @@ class FeatureCache : public FrontendInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() {
std::unique_lock<std::mutex> lock(mutex_);
LOG(INFO) << "set finished";
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished();
// read the last chunk data
Compute();
// ready_feed_condition_.notify_one();
base_extractor_->SetFinished();
LOG(INFO) << "compute last feats done.";
}
@ -66,16 +58,10 @@ class FeatureCache : public FrontendInterface {
int32 dim_;
size_t max_size_; // cache capacity
int32 frame_chunk_size_; // window
int32 frame_chunk_stride_; // stride
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::int32 timeout_; // ms
std::vector<kaldi::BaseFloat> remained_feature_;
std::queue<std::vector<BaseFloat>> cache_; // feature cache
std::mutex mutex_;
std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_;
int32 nframe_; // num of feature computed
DISALLOW_COPY_AND_ASSIGN(FeatureCache);

@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
unique_ptr<FrontendInterface> cache(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
new ppspeech::FeatureCache(kint16max, std::move(cmvn)));
base_extractor_.reset(
new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));

@ -39,7 +39,6 @@ namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file{};
knf::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts{};
static FeaturePipelineOptions InitFromFlags() {

Loading…
Cancel
Save