[Engine] recognizer controller refactor (#3139)

* refactor recognizer_controller
* clean frontend file
pull/3156/head
YangZhou 2 years ago committed by GitHub
parent 591b957b96
commit f35a87ab89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -60,6 +60,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# https://github.com/google/brotli/pull/655
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_PPS_DEBUG "debug option" OFF)
if (WITH_PPS_DEBUG)
add_definitions("-DPPS_DEBUG")

@ -16,7 +16,7 @@ set(TEST_BINS
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils libgflags_nothreads.so glog kaldi-base kaldi-matrix kaldi-util)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)

@ -16,18 +16,6 @@ target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURC
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
set(bin_name u2_nnet_thread_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet frontend)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "kaldi/decoder/decodable-itf.h"
#include "matrix/kaldi-matrix.h"

@ -24,42 +24,11 @@ using std::vector;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend)
: nnet_(nnet), frontend_(frontend) {
abort_ = false;
Reset();
if (nnet_ != nullptr) thread_ = std::thread(RunNnetEvaluation, this);
}
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
condition_variable_.notify_one();
}
void NnetProducer::WaitProduce() {
std::unique_lock<std::mutex> lock(read_mutex_);
while (frontend_->IsFinished() == false && cache_.empty()) {
condition_read_ready_.wait(lock);
}
return;
}
void NnetProducer::RunNnetEvaluation(NnetProducer* me) {
me->RunNnetEvaluationInteral();
}
void NnetProducer::RunNnetEvaluationInteral() {
bool result = false;
LOG(INFO) << "NnetEvaluationInteral begin";
while (!abort_) {
std::unique_lock<std::mutex> lock(mutex_);
condition_variable_.wait(lock);
do {
result = Compute();
} while (result);
if (frontend_->IsFinished() == true) {
if (cache_.empty()) finished_ = true;
}
}
LOG(INFO) << "NnetEvaluationInteral exit";
}
void NnetProducer::Acceptlikelihood(
@ -76,14 +45,7 @@ void NnetProducer::Acceptlikelihood(
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
condition_variable_.notify_one();
return flag;
}
bool NnetProducer::ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob) {
Compute();
if (frontend_->IsFinished() && cache_.empty()) finished_ = true;
bool flag = cache_.pop(nnet_prob);
LOG(INFO) << "nnet cache_ size: " << cache_.size();
return flag;
}
@ -91,7 +53,10 @@ bool NnetProducer::Compute() {
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(2) << "no feat avalible";
LOG(INFO) << "no feat avalible";
if (frontend_->IsFinished() == true) {
finished_ = true;
}
return false;
}
CHECK_GE(frontend_->Dim(), 0);
@ -107,7 +72,6 @@ bool NnetProducer::Compute() {
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob);
condition_read_ready_.notify_one();
}
return true;
}

@ -25,7 +25,6 @@ class NnetProducer {
public:
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend = NULL);
// Feed feats or waves
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
@ -33,36 +32,24 @@ class NnetProducer {
// nnet
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
bool ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob);
static void RunNnetEvaluation(NnetProducer* me);
void RunNnetEvaluationInteral();
void WaitProduce();
void Wait() {
abort_ = true;
condition_variable_.notify_one();
if (thread_.joinable()) thread_.join();
}
bool Empty() const { return cache_.empty(); }
void SetInputFinished() {
LOG(INFO) << "set finished";
frontend_->SetFinished();
condition_variable_.notify_one();
}
// the compute thread exit
bool IsFinished() const { return finished_; }
~NnetProducer() {
if (thread_.joinable()) thread_.join();
bool IsFinished() const {
return (frontend_->IsFinished() && finished_);
}
~NnetProducer() {}
void Reset() {
if (frontend_ != NULL) frontend_->Reset();
if (nnet_ != NULL) nnet_->Reset();
VLOG(3) << "feature cache reset: cache size: " << cache_.size();
cache_.clear();
finished_ = false;
}
@ -71,19 +58,13 @@ class NnetProducer {
float reverse_weight,
std::vector<float>* rescoring_score);
private:
bool Compute();
private:
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::mutex mutex_;
std::mutex read_mutex_;
std::condition_variable condition_variable_;
std::condition_variable condition_read_ready_;
std::thread thread_;
bool finished_;
bool abort_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
};

@ -1,23 +1,22 @@
set(srcs)
list(APPEND srcs
u2_recognizer.cc
recognizer_controller.cc
recognizer_controller_impl.cc
)
add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)
set(TEST_BINS
u2_recognizer_main
u2_recognizer_batch_main
recognizer_batch_main
recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils libgflags_nothreads.so glog kaldi-base kaldi-matrix kaldi-util)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)

@ -0,0 +1,13 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,13 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -19,7 +19,6 @@
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/u2_recognizer.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
@ -69,9 +68,10 @@ void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id != -1) {
while (recog_id == -1) {
recog_id = recognizer_controller->GetRecognizerInstanceId();
}
recognizer_controller->InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
@ -96,13 +96,10 @@ void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
}
recognizer_controller->Accept(wav_chunk, recog_id);
if (cur_chunk_size < chunk_sample_size) {
recognizer_controller->SetInputFinished(recog_id);
}
// no overlap
sample_offset += cur_chunk_size;
}
recognizer_controller->SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = recognizer_controller->GetFinalResult(recog_id);
if (result.empty()) {
@ -142,8 +139,8 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::RecognizerResource resource =
ppspeech::RecognizerResource::InitFromFlags();
ppspeech::RecognizerController recognizer_controller(njob, resource);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;

@ -13,17 +13,15 @@
// limitations under the License.
#include "recognizer/recognizer_controller.h"
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
RecognizerController::RecognizerController(int num_worker, U2RecognizerResource resource) {
RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) {
nnet_ = std::make_shared<ppspeech::U2Nnet>(resource.model_opts);
recognizer_workers.resize(num_worker);
for (size_t i = 0; i < num_worker; ++i) {
recognizer_workers[i].reset(new ppspeech::U2Recognizer(resource, nnet_->Clone()));
recognizer_workers[i]->InitDecoder();
recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource, nnet_->Clone()));
waiting_workers.push(i);
}
}
@ -43,16 +41,18 @@ int RecognizerController::GetRecognizerInstanceId() {
RecognizerController::~RecognizerController() {
for (size_t i = 0; i < recognizer_workers.size(); ++i) {
recognizer_workers[i]->SetInputFinished();
recognizer_workers[i]->WaitDecodeFinished();
recognizer_workers[i]->WaitFinished();
}
}
void RecognizerController::InitDecoder(int idx) {
recognizer_workers[idx]->InitDecoder();
}
std::string RecognizerController::GetFinalResult(int idx) {
recognizer_workers[idx]->WaitDecodeFinished();
recognizer_workers[idx]->WaitDecoderFinished();
recognizer_workers[idx]->AttentionRescoring();
std::string result = recognizer_workers[idx]->GetFinalResult();
recognizer_workers[idx]->InitDecoder();
{
std::unique_lock<std::mutex> lock(mutex_);
waiting_workers.push(idx);
@ -68,4 +68,4 @@ void RecognizerController::SetInputFinished(int idx) {
recognizer_workers[idx]->SetInputFinished();
}
}
}

@ -12,19 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <memory>
#include "recognizer/u2_recognizer.h"
#include "recognizer/recognizer_controller_impl.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
class RecognizerController {
public:
explicit RecognizerController(int num_worker, U2RecognizerResource resource);
explicit RecognizerController(int num_worker, RecognizerResource resource);
~RecognizerController();
int GetRecognizerInstanceId();
void InitDecoder(int idx);
void Accept(std::vector<float> data, int idx);
void SetInputFinished(int idx);
std::string GetFinalResult(int idx);
@ -33,7 +36,9 @@ class RecognizerController {
std::queue<int> waiting_workers;
std::shared_ptr<ppspeech::U2Nnet> nnet_;
std::mutex mutex_;
std::vector<std::unique_ptr<ppspeech::U2Recognizer>> recognizer_workers;
std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;
DISALLOW_COPY_AND_ASSIGN(RecognizerController);
};
}

@ -1,4 +1,4 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,21 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
#ifdef USE_ONNX
#include "nnet/u2_onnx_nnet.h"
#endif
#include "recognizer/recognizer_controller_impl.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "common/utils/strings.h"
namespace ppspeech {
using kaldi::BaseFloat;
using std::unique_ptr;
using std::vector;
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
: opts_(resource) {
RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource)
: opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
std::shared_ptr<FeaturePipeline> feature_pipeline(
@ -42,8 +35,9 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
}
#endif
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline));
decodable_.reset(new Decodable(nnet_producer_, am_scale));
nnet_thread_ = std::thread(RunNnetEvaluation, this);
decodable_.reset(new Decodable(nnet_producer_, am_scale));
CHECK_NE(resource.vocab_path, "");
if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) {
LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path;
@ -55,21 +49,21 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
}
symbol_table_ = decoder_->WordSymbolTable();
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
result_.clear();
result_.clear();
}
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
std::shared_ptr<NnetBase> nnet)
: opts_(resource) {
RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource,
std::shared_ptr<NnetBase> nnet)
:opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
std::shared_ptr<FeaturePipeline> feature_pipeline =
std::make_shared<FeaturePipeline>(feature_opts);
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline));
nnet_producer_ = std::make_shared<NnetProducer>(nnet, feature_pipeline);
nnet_thread_ = std::thread(RunNnetEvaluation, this);
decodable_.reset(new Decodable(nnet_producer_, am_scale));
CHECK_NE(resource.vocab_path, "");
@ -88,21 +82,72 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
result_.clear();
}
U2Recognizer::~U2Recognizer() {
SetInputFinished();
WaitDecodeFinished();
RecognizerControllerImpl::~RecognizerControllerImpl() {
WaitFinished();
}
void RecognizerControllerImpl::Reset() {
nnet_producer_->Reset();
}
void RecognizerControllerImpl::RunDecoder(RecognizerControllerImpl* me) {
me->RunDecoderInternal();
}
void RecognizerControllerImpl::RunDecoderInternal() {
LOG(INFO) << "DecoderInternal begin";
while (!nnet_producer_->IsFinished()) {
nnet_condition_.notify_one();
decoder_->AdvanceDecode(decodable_);
}
decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
LOG(INFO) << "DecoderInternal exit";
}
void RecognizerControllerImpl::WaitDecoderFinished() {
if (decoder_thread_.joinable()) decoder_thread_.join();
}
void RecognizerControllerImpl::RunNnetEvaluation(RecognizerControllerImpl* me) {
me->RunNnetEvaluationInternal();
}
void RecognizerControllerImpl::SetInputFinished() {
nnet_producer_->SetInputFinished();
nnet_condition_.notify_one();
LOG(INFO) << "Set Input Finished";
}
void RecognizerControllerImpl::WaitFinished() {
abort_ = true;
LOG(INFO) << "nnet wait finished";
nnet_condition_.notify_one();
if (nnet_thread_.joinable()) {
nnet_thread_.join();
}
}
void U2Recognizer::WaitDecodeFinished() {
if (thread_.joinable()) thread_.join();
void RecognizerControllerImpl::RunNnetEvaluationInternal() {
bool result = false;
LOG(INFO) << "NnetEvaluationInteral begin";
while (!abort_) {
std::unique_lock<std::mutex> lock(nnet_mutex_);
nnet_condition_.wait(lock);
do {
result = nnet_producer_->Compute();
decoder_condition_.notify_one();
} while (result);
}
LOG(INFO) << "NnetEvaluationInteral exit";
}
void U2Recognizer::WaitFinished() {
if (thread_.joinable()) thread_.join();
nnet_producer_->Wait();
void RecognizerControllerImpl::Accept(std::vector<float> data) {
nnet_producer_->Accept(data);
nnet_condition_.notify_one();
}
void U2Recognizer::InitDecoder() {
void RecognizerControllerImpl::InitDecoder() {
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
@ -110,51 +155,56 @@ void U2Recognizer::InitDecoder() {
decodable_->Reset();
decoder_->Reset();
thread_ = std::thread(RunDecoderSearch, this);
decoder_thread_ = std::thread(RunDecoder, this);
}
void U2Recognizer::ResetContinuousDecoding() {
global_frame_offset_ = num_frames_;
num_frames_ = 0;
result_.clear();
void RecognizerControllerImpl::AttentionRescoring() {
decoder_->FinalizeSearch();
UpdateResult(false);
decodable_->Reset();
decoder_->Reset();
}
// No need to do rescoring
if (0.0 == opts_.decoder_opts.rescoring_weight) {
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
return;
}
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
me->RunDecoderSearchInternal();
}
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const auto& hypotheses = decoder_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
void U2Recognizer::RunDecoderSearchInternal() {
LOG(INFO) << "DecoderSearchInteral begin";
while (!nnet_producer_->IsFinished()) {
nnet_producer_->WaitProduce();
decoder_->AdvanceDecode(decodable_);
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score
<< " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
<< " ctc_weight: " << opts_.decoder_opts.ctc_weight;
result_[i].score =
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
VLOG(3) << "hyp: " << result_[0].sentence
<< " score: " << result_[0].score;
}
decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
LOG(INFO) << "DecoderSearchInteral exit";
}
void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
kaldi::Timer timer;
nnet_producer_->Accept(waves);
VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. "
<< waves.size() << " samples.";
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
VLOG(3) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
void U2Recognizer::Decode() {
decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
}
std::string RecognizerControllerImpl::GetFinalResult() { return result_[0].sentence; }
void U2Recognizer::Rescoring() {
// Do attention Rescoring
AttentionRescoring();
}
std::string RecognizerControllerImpl::GetPartialResult() { return result_[0].sentence; }
void U2Recognizer::UpdateResult(bool finish) {
void RecognizerControllerImpl::UpdateResult(bool finish) {
const auto& hypotheses = decoder_->Outputs();
const auto& inputs = decoder_->Inputs();
const auto& likelihood = decoder_->Likelihood();
@ -169,10 +219,9 @@ void U2Recognizer::UpdateResult(bool finish) {
path.score = likelihood[i];
for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// path.sentence += (" " + word); // todo SmileGoat: add blank
// processor
path.sentence += word; // todo SmileGoat: add blank processor
path.sentence += (" " + word);
}
path.sentence = DelBlank(path.sentence);
// TimeStamp is only supported in final result
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
@ -229,56 +278,4 @@ void U2Recognizer::UpdateResult(bool finish) {
}
}
void U2Recognizer::AttentionRescoring() {
decoder_->FinalizeSearch();
UpdateResult(false);
// No need to do rescoring
if (0.0 == opts_.decoder_opts.rescoring_weight) {
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
return;
}
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const auto& hypotheses = decoder_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score
<< " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
<< " ctc_weight: " << opts_.decoder_opts.ctc_weight;
result_[i].score =
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
VLOG(3) << "hyp: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
VLOG(3) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
void U2Recognizer::SetInputFinished() {
nnet_producer_->SetInputFinished();
input_finished_ = true;
}
} // namespace ppspeech
} // namespace ppspeech

@ -0,0 +1,91 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
#ifdef USE_ONNX
#include "nnet/u2_onnx_nnet.h"
#endif
#include "nnet/decodable.h"
#include "recognizer/recognizer_resource.h"
#include <memory>
namespace ppspeech {
class RecognizerControllerImpl {
public:
explicit RecognizerControllerImpl(const RecognizerResource& resource);
explicit RecognizerControllerImpl(const RecognizerResource& resource,
std::shared_ptr<NnetBase> nnet);
~RecognizerControllerImpl();
void Accept(std::vector<float> data);
void InitDecoder();
void SetInputFinished();
std::string GetFinalResult();
std::string GetPartialResult();
void Rescoring();
void Reset();
void WaitDecoderFinished();
void WaitFinished();
void AttentionRescoring();
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
}
int FrameShiftInMs() const {
return 1; //todo
}
private:
static void RunNnetEvaluation(RecognizerControllerImpl* me);
void RunNnetEvaluationInternal();
static void RunDecoder(RecognizerControllerImpl* me);
void RunDecoderInternal();
void UpdateResult(bool finish = false);
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<DecoderBase> decoder_;
std::shared_ptr<NnetProducer> nnet_producer_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;
RecognizerResource opts_;
bool abort_ = false;
// global decoded frame offset
int global_frame_offset_;
// cur decoded frame num
int num_frames_;
// timestamp gap between words in a sentence
const int time_stamp_gap_ = 100;
bool input_finished_;
std::mutex nnet_mutex_;
std::mutex decoder_mutex_;
std::condition_variable nnet_condition_;
std::condition_variable decoder_condition_;
std::thread nnet_thread_;
std::thread decoder_thread_;
DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl);
};
}

@ -0,0 +1,13 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,13 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -15,7 +15,7 @@
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "recognizer/u2_recognizer.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
@ -45,10 +45,10 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr(
new ppspeech::U2Recognizer(resource));
ppspeech::RecognizerResource resource =
ppspeech::RecognizerResource::InitFromFlags();
std::shared_ptr<ppspeech::RecognizerControllerImpl> recognizer_ptr(
new ppspeech::RecognizerControllerImpl(resource));
for (; !wav_reader.Done(); wav_reader.Next()) {
recognizer_ptr->InitDecoder();
@ -84,7 +84,7 @@ int main(int argc, char* argv[]) {
}
CHECK(sample_offset == tot_samples);
recognizer_ptr->SetInputFinished();
recognizer_ptr->WaitDecodeFinished();
recognizer_ptr->WaitDecoderFinished();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();

@ -1,28 +1,8 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/decoder_itf.h"
#include "frontend/feature_pipeline.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/decodable.h"
DECLARE_int32(nnet_decoder_chunk);
DECLARE_int32(num_left_chunks);
@ -87,7 +67,7 @@ struct DecodeOptions {
}
};
struct U2RecognizerResource {
struct RecognizerResource {
kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{};
@ -95,8 +75,8 @@ struct U2RecognizerResource {
ModelOptions model_opts{};
DecodeOptions decoder_opts{};
static U2RecognizerResource InitFromFlags() {
U2RecognizerResource resource;
static RecognizerResource InitFromFlags() {
RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale;
LOG(INFO) << "vocab path: " << resource.vocab_path;
@ -113,68 +93,4 @@ struct U2RecognizerResource {
}
};
class U2Recognizer {
public:
explicit U2Recognizer(const U2RecognizerResource& resouce);
explicit U2Recognizer(const U2RecognizerResource& resource,
std::shared_ptr<NnetBase> nnet);
~U2Recognizer();
void InitDecoder();
void ResetContinuousDecoding();
void Accept(const std::vector<kaldi::BaseFloat>& waves);
void Decode();
void Rescoring();
std::string GetFinalResult();
std::string GetPartialResult();
void SetInputFinished();
bool IsFinished() { return input_finished_; }
void WaitDecodeFinished();
void WaitFinished();
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
}
int FrameShiftInMs() const {
// one decoder frame length in ms, todo
return 1;
// return decodable_->Nnet()->SubsamplingRate() *
// feature_pipeline_->FrameShift();
}
const std::vector<DecodeResult>& Result() const { return result_; }
void AttentionRescoring();
private:
static void RunDecoderSearch(U2Recognizer* me);
void RunDecoderSearchInternal();
void UpdateResult(bool finish = false);
private:
U2RecognizerResource opts_;
std::shared_ptr<NnetProducer> nnet_producer_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<DecoderBase> decoder_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;
// global decoded frame offset
int global_frame_offset_;
// cur decoded frame num
int num_frames_;
// timestamp gap between words in a sentence
const int time_stamp_gap_ = 100;
bool input_finished_;
std::thread thread_;
};
} // namespace ppspeech
} //namespace ppspeech

@ -1,185 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/u2_recognizer.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "njob");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
void recognizer_func(const ppspeech::U2RecognizerResource& resource,
std::shared_ptr<ppspeech::NnetBase> nnet,
std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr =
std::make_shared<ppspeech::U2Recognizer>(resource, nnet);
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
recognizer_ptr->InitDecoder();
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
recognizer_ptr->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer_ptr->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
recognizer_ptr->WaitDecodeFinished();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
tot_attention_rescore_time += timer.Elapsed();
std::string result = recognizer_ptr->GetFinalResult();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
std::shared_ptr<ppspeech::U2Nnet> nnet(
new ppspeech::U2Nnet(resource.model_opts));
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
resource,
nnet->Clone(),
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}

@ -52,23 +52,20 @@ bool Assembler::Compute(vector<BaseFloat>* feats) {
vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.size() == 0) {
VLOG(3) << "result: " << result
VLOG(1) << "result: " << result
<< " feature dim: " << feature.size();
if (IsFinished() == false) {
VLOG(3) << "finished reading feature. cache size: "
VLOG(1) << "finished reading feature. cache size: "
<< feature_cache_.size();
return false;
} else {
VLOG(3) << "break";
VLOG(1) << "break";
break;
}
}
CHECK(feature.size() == dim_);
feature_cache_.push(feature);
nframes_ += 1;
VLOG(3) << "nframes: " << nframes_;
VLOG(1) << "nframes: " << nframes_;
}
if (feature_cache_.size() < receptive_filed_length_) {

@ -56,28 +56,14 @@ bool AudioCache::Read(vector<BaseFloat>* waves) {
kaldi::Timer timer;
size_t chunk_size = waves->size();
std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > size_) {
// when audio is empty and no more data feed
// ready_read_condition will block in dead lock,
// so replace with timeout_
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
if (elapsed > timeout_) {
if (finished_ == true) {
// read last chunk data
break;
}
if (chunk_size > size_) {
return false;
}
}
usleep(100); // sleep 0.1 ms
}
// read last chunk data
if (chunk_size > size_) {
chunk_size = size_;
waves->resize(chunk_size);
if (finished_ == false) {
return false;
} else {
// read last chunk data
chunk_size = size_;
waves->resize(chunk_size);
}
}
for (size_t idx = 0; idx < chunk_size; ++idx) {

@ -39,7 +39,7 @@ class AudioCache : public FrontendInterface {
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
virtual bool IsFinished() const { return finished_ && (size_ == 0); }
void Reset() override {
offset_ = 0;

@ -1,62 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/fbank.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h"
#include "kaldi/matrix/matrix-functions.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::vector;
FbankComputer::FbankComputer(const Options& opts)
: opts_(opts), computer_(opts) {}
int32 FbankComputer::Dim() const {
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
}
bool FbankComputer::NeedRawLogEnergy() {
return opts_.use_energy && opts_.raw_energy;
}
// Compute feat
bool FbankComputer::Compute(Vector<BaseFloat>* window,
Vector<BaseFloat>* feat) {
RealFft(window, true);
kaldi::ComputePowerSpectrum(window);
const kaldi::MelBanks& mel_bank = *(computer_.GetMelBanks(1.0));
SubVector<BaseFloat> power_spectrum(*window, 0, window->Dim() / 2 + 1);
if (!opts_.use_power) {
power_spectrum.ApplyPow(0.5);
}
int32 mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);
SubVector<BaseFloat> mel_energies(
*feat, mel_offset, opts_.mel_opts.num_bins);
mel_bank.Compute(power_spectrum, &mel_energies);
mel_energies.ApplyFloor(1e-07);
mel_energies.ApplyLog();
return true;
}
} // namespace ppspeech

@ -49,7 +49,8 @@ bool FeatureCache::Read(std::vector<kaldi::BaseFloat>* feats) {
// read from cache
*feats = cache_.front();
cache_.pop();
VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
VLOG(2) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
VLOG(1) << "FeatureCache::size : " << cache_.size();
return true;
}
@ -74,7 +75,7 @@ bool FeatureCache::Compute() {
++nframe_;
}
VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. "
VLOG(2) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. "
<< num_chunk << " feats.";
return true;
}

@ -36,21 +36,19 @@ class FeatureCache : public FrontendInterface {
virtual void SetFinished() {
std::unique_lock<std::mutex> lock(mutex_);
LOG(INFO) << "set finished";
// read the last chunk data
Compute();
base_extractor_->SetFinished();
LOG(INFO) << "compute last feats done.";
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual bool IsFinished() const {
return base_extractor_->IsFinished() && cache_.empty();
}
void Reset() override {
std::queue<std::vector<BaseFloat>> empty;
VLOG(1) << "feature cache size: " << cache_.size();
std::swap(cache_, empty);
nframe_ = 0;
base_extractor_->Reset();
VLOG(3) << "feature cache reset: cache size: " << cache_.size();
}
private:

@ -1,109 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/mfcc.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h"
#include "kaldi/matrix/matrix-functions.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::vector;
Mfcc::Mfcc(const MfccOptions& opts,
std::unique_ptr<FrontendInterface> base_extractor)
: opts_(opts),
computer_(opts.mfcc_opts),
window_function_(computer_.GetFrameOptions()) {
base_extractor_ = std::move(base_extractor);
chunk_sample_size_ =
static_cast<int32>(opts.streaming_chunk * opts.frame_opts.samp_freq);
}
void Mfcc::Accept(const VectorBase<BaseFloat>& inputs) {
base_extractor_->Accept(inputs);
}
bool Mfcc::Read(Vector<BaseFloat>* feats) {
Vector<BaseFloat> wav(chunk_sample_size_);
bool flag = base_extractor_->Read(&wav);
if (flag == false || wav.Dim() == 0) return false;
// append remaned waves
int32 wav_len = wav.Dim();
int32 left_len = remained_wav_.Dim();
Vector<BaseFloat> waves(left_len + wav_len);
waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, wav_len).CopyFromVec(wav);
// compute speech feature
Compute(waves, feats);
// cache remaned waves
kaldi::FrameExtractionOptions frame_opts = computer_.GetFrameOptions();
int32 num_frames = kaldi::NumFrames(waves.Dim(), frame_opts);
int32 frame_shift = frame_opts.WindowShift();
int32 left_samples = waves.Dim() - frame_shift * num_frames;
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
return true;
}
// Compute spectrogram feat
bool Mfcc::Compute(const Vector<BaseFloat>& waves, Vector<BaseFloat>* feats) {
const FrameExtractionOptions& frame_opts = computer_.GetFrameOptions();
int32 num_samples = waves.Dim();
int32 frame_length = frame_opts.WindowSize();
int32 sample_rate = frame_opts.samp_freq;
if (num_samples < frame_length) {
return true;
}
int32 num_frames = kaldi::NumFrames(num_samples, frame_opts);
feats->Rsize(num_frames * Dim());
Vector<BaseFloat> window;
bool need_raw_log_energy = computer_.NeedRawLogEnergy();
for (int32 frame = 0; frame < num_frames; frame++) {
BaseFloat raw_log_energy = 0.0;
kaldi::ExtractWindow(0,
waves,
frame,
frame_opts,
window_function_,
&window,
need_raw_log_energy ? &raw_log_energy : NULL);
Vector<BaseFloat> this_feature(computer_.Dim(), kUndefined);
// note: this online feature-extraction code does not support VTLN.
BaseFloat vtln_warp = 1.0;
computer_.Compute(raw_log_energy, vtln_warp, &window, &this_feature);
SubVector<BaseFloat> output_row(feats->Data() + frame * Dim(), Dim());
output_row.CopyFromVec(this_feature);
}
return true;
}
} // namespace ppspeech

@ -1,75 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "kaldi/feat/feature-mfcc.h"
#include "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {
struct MfccOptions {
kaldi::MfccOptions mfcc_opts;
kaldi::BaseFloat streaming_chunk; // second
MfccOptions() : streaming_chunk(0.1), mfcc_opts() {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("streaming-chunk",
&streaming_chunk,
"streaming chunk size, default: 0.1 sec");
mfcc_opts.Register(opts);
}
};
class Mfcc : public FrontendInterface {
public:
explicit Mfcc(const MfccOptions& opts,
unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the dim of single frame feature
virtual size_t Dim() const { return computer_.Dim(); }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
remained_wav_.Resize(0);
}
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
kaldi::Vector<kaldi::BaseFloat>* feats);
MfccOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_;
FeatureWindowFunction window_function_;
kaldi::MfccComputer computer_;
// features_ is the Mfcc or Plp or Fbank features that we have already
// computed.
kaldi::Vector<kaldi::BaseFloat> features_;
kaldi::Vector<kaldi::BaseFloat> remained_wav_;
DISALLOW_COPY_AND_ASSIGN(Fbank);
};
} // namespace ppspeech

@ -1,26 +0,0 @@
#include "utils/blank_process.h"
namespace ppspeech {
std::string BlankProcess(const std::string& str) {
std::string out = "";
int p = 0;
int end = str.size();
int q = -1; // last char of the output string
while (p != end) {
while (p != end && str[p] == ' ') {
p += 1;
}
if (p == end)
return out;
if (q != -1 && isalpha(str[p]) && isalpha(str[q]) && str[p-1] == ' ')
// add a space when the last and current chars are in English and there have space(s) between them
out += ' ';
out += str[p];
q = p;
p += 1;
}
return out;
}
} // namespace ppspeech

@ -1,9 +0,0 @@
#include <string>
#include <vector>
#include <cctype>
namespace ppspeech {
std::string BlankProcess(const std::string& str);
} // namespace ppspeech

@ -49,6 +49,75 @@ std::string StrJoin(const std::vector<std::string>& strs, const char* delim) {
return ss.str();
}
std::string DelBlank(const std::string& str) {
std::string out = "";
int ptr_in = 0; // the pointer of input string (for traversal)
int end = str.size();
int ptr_out = -1; // the pointer of output string (last char)
while (ptr_in != end) {
while (ptr_in != end && str[ptr_in] == ' ') {
ptr_in += 1;
}
if (ptr_in == end)
return out;
if (ptr_out != -1 && isalpha(str[ptr_in]) && isalpha(str[ptr_out]) && str[ptr_in-1] == ' ')
// add a space when the last and current chars are in English and there have space(s) between them
out += ' ';
out += str[ptr_in];
ptr_out = ptr_in;
ptr_in += 1;
}
return out;
}
std::string AddBlank(const std::string& str) {
std::string out = "";
int ptr = 0; // the pointer of the input string
int end = str.size();
while (ptr != end) {
if (isalpha(str[ptr])) {
if (ptr == 0 or str[ptr-1] != ' ')
out += " "; // add pre-space for an English word
while (isalpha(str[ptr])) {
out += str[ptr];
ptr += 1;
}
out += " "; // add post-space for an English word
} else {
out += str[ptr];
ptr += 1;
}
}
return out;
}
std::string ReverseFraction(const std::string& str) {
std::string out = "";
int ptr = 0; // the pointer of the input string
int end = str.size();
int left, right, frac; // the start index of the left tag, right tag and '/'.
left = right = frac = 0;
int len_tag = 5; // length of "<tag>"
while (ptr != end) {
// find the position of left tag, right tag and '/'. (xxx<tag>num1/num2</tag>)
left = str.find("<tag>", ptr);
if (left == -1)
break;
out += str.substr(ptr, left - ptr); // content before left tag (xxx)
frac = str.find("/", left);
right = str.find("<tag>", frac);
out += str.substr(frac + 1, right - frac - 1) + '/' +
str.substr(left + len_tag, frac - left - len_tag); // num2/num1
ptr = right + len_tag;
}
if (ptr != end) {
out += str.substr(ptr, end - ptr);
}
return out;
}
#ifdef _MSC_VER
std::wstring ToWString(const std::string& str) {
unsigned len = str.size() * 2;
@ -61,4 +130,4 @@ std::wstring ToWString(const std::string& str) {
}
#endif
} // namespace ppspeech
} // namespace ppspeech

@ -25,8 +25,14 @@ std::vector<std::string> StrSplit(const std::string& str,
std::string StrJoin(const std::vector<std::string>& strs, const char* delim);
std::string DelBlank(const std::string& str);
std::string AddBlank(const std::string& str);
std::string ReverseFraction(const std::string& str);
#ifdef _MSC_VER
std::wstring ToWString(const std::string& str);
#endif
} // namespace ppspeech
} // namespace ppspeech

@ -32,4 +32,47 @@ TEST(StringTest, StrJoinTest) {
std::vector<std::string> ins{"hello", "world"};
std::string out = ppspeech::StrJoin(ins, " ");
EXPECT_THAT(out, "hello world");
}
}
TEST(StringText, DelBlankTest) {
std::string test_str = "我 今天 去 了 超市 花了 120 元。";
std::string out_str = ppspeech::DelBlank(test_str);
int ret = out_str.compare("我今天去了超市花了120元。");
EXPECT_EQ(ret, 0);
test_str = "how are you today";
out_str = ppspeech::DelBlank(test_str);
ret = out_str.compare("how are you today");
EXPECT_EQ(ret, 0);
test_str = "我 的 paper 在 哪里?";
out_str = ppspeech::DelBlank(test_str);
ret = out_str.compare("我的paper在哪里");
EXPECT_EQ(ret, 0);
}
TEST(StringTest, AddBlankTest) {
std::string test_str = "how are you";
std::string out_str = ppspeech::AddBlank(test_str);
int ret = out_str.compare(" how are you ");
EXPECT_EQ(ret, 0);
test_str = "欢迎来到China。";
out_str = ppspeech::AddBlank(test_str);
ret = out_str.compare("欢迎来到 China 。");
EXPECT_EQ(ret, 0);
}
TEST(StringTest, ReverseFractionTest) {
std::string test_str = "<tag>3/1<tag>";
std::string out_str = ppspeech::ReverseFraction(test_str);
int ret = out_str.compare("1/3");
std::cout<<out_str<<std::endl;
EXPECT_EQ(ret, 0);
test_str = "<tag>3/1<tag> <tag>100/10000<tag>";
out_str = ppspeech::ReverseFraction(test_str);
ret = out_str.compare("1/3 10000/100");
std::cout<<out_str<<std::endl;
EXPECT_EQ(ret, 0);
}

@ -1,74 +0,0 @@
#include "utils/text_process.h"
namespace ppspeech {
std::string DelBlank(const std::string& str) {
std::string out = "";
int ptr_in = 0; // the pointer of input string (for traversal)
int end = str.size();
int ptr_out = -1; // the pointer of output string (last char)
while (ptr_in != end) {
while (ptr_in != end && str[ptr_in] == ' ') {
ptr_in += 1;
}
if (ptr_in == end)
return out;
if (ptr_out != -1 && isalpha(str[ptr_in]) && isalpha(str[ptr_out]) && str[ptr_in-1] == ' ')
// add a space when the last and current chars are in English and there have space(s) between them
out += ' ';
out += str[ptr_in];
ptr_out = ptr_in;
ptr_in += 1;
}
return out;
}
std::string AddBlank(const std::string& str) {
std::string out = "";
int ptr = 0; // the pointer of the input string
int end = str.size();
while (ptr != end) {
if (isalpha(str[ptr])) {
if (ptr == 0 or str[ptr-1] != ' ')
out += " "; // add pre-space for an English word
while (isalpha(str[ptr])) {
out += str[ptr];
ptr += 1;
}
out += " "; // add post-space for an English word
} else {
out += str[ptr];
ptr += 1;
}
}
return out;
}
std::string ReverseFraction(const std::string& str) {
std::string out = "";
int ptr = 0; // the pointer of the input string
int end = str.size();
int left, right, frac; // the start index of the left tag, right tag and '/'.
left = right = frac = 0;
int len_tag = 5; // length of "<tag>"
while (ptr != end) {
// find the position of left tag, right tag and '/'. (xxx<tag>num1/num2</tag>)
left = str.find("<tag>", ptr);
if (left == -1)
break;
out += str.substr(ptr, left - ptr); // content before left tag (xxx)
frac = str.find("/", left);
right = str.find("<tag>", frac);
out += str.substr(frac + 1, right - frac - 1) + '/' +
str.substr(left + len_tag, frac - left - len_tag); // num2/num1
ptr = right + len_tag;
}
if (ptr != end) {
out += str.substr(ptr, end - ptr);
}
return out;
}
} // namespace ppspeech

@ -1,13 +0,0 @@
#include <string>
#include <vector>
#include <cctype>
namespace ppspeech {
std::string DelBlank(const std::string& str);
std::string AddBlank(const std::string& str);
std::string ReverseFraction(const std::string& str);
} // namespace ppspeech

@ -1,47 +0,0 @@
#include "utils/text_process.h"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
TEST(TextProcess, DelBlankTest) {
std::string test_str = "我 今天 去 了 超市 花了 120 元。";
std::string out_str = ppspeech::DelBlank(test_str);
int ret = out_str.compare("我今天去了超市花了120元。");
EXPECT_EQ(ret, 0);
test_str = "how are you today";
out_str = ppspeech::DelBlank(test_str);
ret = out_str.compare("how are you today");
EXPECT_EQ(ret, 0);
test_str = "我 的 paper 在 哪里?";
out_str = ppspeech::DelBlank(test_str);
ret = out_str.compare("我的paper在哪里");
EXPECT_EQ(ret, 0);
}
TEST(TextProcess, AddBlankTest) {
std::string test_str = "how are you";
std::string out_str = ppspeech::AddBlank(test_str);
int ret = out_str.compare(" how are you ");
EXPECT_EQ(ret, 0);
test_str = "欢迎来到China。";
out_str = ppspeech::AddBlank(test_str);
ret = out_str.compare("欢迎来到 China 。");
EXPECT_EQ(ret, 0);
}
TEST(TextProcess, ReverseFractionTest) {
std::string test_str = "<tag>3/1<tag>";
std::string out_str = ppspeech::ReverseFraction(test_str);
int ret = out_str.compare("1/3");
std::cout<<out_str<<std::endl;
EXPECT_EQ(ret, 0);
test_str = "<tag>3/1<tag> <tag>100/10000<tag>";
out_str = ppspeech::ReverseFraction(test_str);
ret = out_str.compare("1/3 10000/100");
std::cout<<out_str<<std::endl;
EXPECT_EQ(ret, 0);
}

@ -16,7 +16,7 @@ text=$data/test/text
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
u2_recognizer_main \
recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$model_dir/mean_std.json \

Loading…
Cancel
Save