[engine]add recognizer api && clean params && make a shared decoder resource (#3165)
parent
11ce08b260
commit
b05ead51d7
@ -0,0 +1,168 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "common/base/thread_pool.h"
|
||||
#include "common/utils/file_utils.h"
|
||||
#include "common/utils/strings.h"
|
||||
#include "decoder/param.h"
|
||||
#include "frontend/wave-reader.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/u2_nnet.h"
|
||||
#include "recognizer/recognizer.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||
DEFINE_int32(njob, 3, "njob");
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
// Distributes the entries of a wav list file round-robin over `njob` jobs.
//
// Every line read from `wavlist_file` is split on the delimiters " \t" and
// must yield exactly two fields: the utterance id and the wav path.
//
// wavlist_file: path of the list file to read.
// uttlists:     output, resized to `njob`; uttlists[j] receives the utterance
//               ids assigned to job j.
// wavlists:     output, resized to `njob`; wavlists[j] receives the matching
//               wav paths, in the same order as uttlists[j].
// njob:         number of jobs, must be positive.
void SplitUtt(string wavlist_file,
              vector<vector<string>>* uttlists,
              vector<vector<string>>* wavlists,
              int njob) {
    // A non-positive job count would turn the resize() calls and the
    // `idx % njob` below into invalid operations.
    CHECK_GT(njob, 0) << "njob must be positive";

    vector<string> wavlist;
    wavlists->resize(njob);
    uttlists->resize(njob);
    ppspeech::ReadFileToVector(wavlist_file, &wavlist);
    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        const string& utt_str = wavlist[idx];
        vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
        // Validate the field count before touching utt_wav[0]; an empty or
        // malformed line must fail here rather than index out of bounds.
        CHECK_EQ(utt_wav.size(), size_t(2));
        LOG(INFO) << utt_wav[0];
        uttlists->at(idx % njob).push_back(utt_wav[0]);
        wavlists->at(idx % njob).push_back(utt_wav[1]);
    }
}
|
||||
|
||||
void recognizer_func(std::vector<string> wavlist,
|
||||
std::vector<string> uttlist,
|
||||
std::vector<string>* results) {
|
||||
int32 num_done = 0, num_err = 0;
|
||||
double tot_wav_duration = 0.0;
|
||||
double tot_attention_rescore_time = 0.0;
|
||||
double tot_decode_time = 0.0;
|
||||
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
|
||||
if (wavlist.empty()) return;
|
||||
|
||||
results->reserve(wavlist.size());
|
||||
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
||||
std::string utt = uttlist[idx];
|
||||
std::string wav_file = wavlist[idx];
|
||||
std::ifstream infile;
|
||||
infile.open(wav_file, std::ifstream::in);
|
||||
kaldi::WaveData wave_data;
|
||||
wave_data.Read(infile);
|
||||
int32 recog_id = -1;
|
||||
while (recog_id == -1) {
|
||||
recog_id = GetRecognizerInstanceId();
|
||||
}
|
||||
InitDecoder(recog_id);
|
||||
LOG(INFO) << "utt: " << utt;
|
||||
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||
double dur = wave_data.Duration();
|
||||
tot_wav_duration += dur;
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
kaldi::Timer local_timer;
|
||||
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk[i] = waveform(sample_offset + i);
|
||||
}
|
||||
|
||||
AcceptData(wav_chunk, recog_id);
|
||||
// no overlap
|
||||
sample_offset += cur_chunk_size;
|
||||
}
|
||||
SetInputFinished(recog_id);
|
||||
CHECK(sample_offset == tot_samples);
|
||||
std::string result = GetFinalResult(recog_id);
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
LOG(INFO) << " the result of " << utt << " is empty";
|
||||
result = " ";
|
||||
}
|
||||
|
||||
tot_decode_time += local_timer.Elapsed();
|
||||
LOG(INFO) << utt << " " << result;
|
||||
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
|
||||
<< " cost: " << local_timer.Elapsed();
|
||||
|
||||
results->push_back(result);
|
||||
++num_done;
|
||||
}
|
||||
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
||||
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
||||
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
|
||||
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
|
||||
int sample_rate = FLAGS_sample_rate;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
int njob = FLAGS_njob;
|
||||
LOG(INFO) << "sr: " << sample_rate;
|
||||
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||
|
||||
InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
|
||||
ThreadPool threadpool(njob);
|
||||
vector<vector<string>> wavlist;
|
||||
vector<vector<string>> uttlist;
|
||||
vector<vector<string>> resultlist(njob);
|
||||
vector<std::future<void>> futurelist;
|
||||
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
|
||||
for (size_t i = 0; i < njob; ++i) {
|
||||
std::future<void> f = threadpool.enqueue(recognizer_func,
|
||||
wavlist[i],
|
||||
uttlist[i],
|
||||
&resultlist[i]);
|
||||
futurelist.push_back(std::move(f));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < njob; ++i) {
|
||||
futurelist[i].get();
|
||||
}
|
||||
|
||||
for (size_t idx = 0; idx < njob; ++idx) {
|
||||
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
|
||||
string utt = uttlist[idx][utt_idx];
|
||||
string result = resultlist[idx][utt_idx];
|
||||
result_writer.Write(utt, result);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
@ -1,13 +0,0 @@
|
||||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "recognizer/recognizer_instance.h"
|
||||
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
RecognizerInstance& RecognizerInstance::GetInstance() {
|
||||
static RecognizerInstance instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool RecognizerInstance::Init(const std::string& model_file,
|
||||
const std::string& word_symbol_table_file,
|
||||
const std::string& fst_file,
|
||||
int num_instance) {
|
||||
RecognizerResource resource = RecognizerResource::InitFromFlags();
|
||||
resource.model_opts.model_path = model_file;
|
||||
//resource.vocab_path = word_symbol_table_file;
|
||||
if (!fst_file.empty()) {
|
||||
resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file;
|
||||
resource.decoder_opts.tlg_decoder_opts.fst_path = word_symbol_table_file;
|
||||
} else {
|
||||
resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
|
||||
word_symbol_table_file;
|
||||
}
|
||||
recognizer_controller_ = std::make_unique<RecognizerController>(num_instance, resource);
|
||||
return true;
|
||||
}
|
||||
|
||||
void RecognizerInstance::InitDecoder(int idx) {
|
||||
recognizer_controller_->InitDecoder(idx);
|
||||
return;
|
||||
}
|
||||
|
||||
int RecognizerInstance::GetRecognizerInstanceId() {
|
||||
return recognizer_controller_->GetRecognizerInstanceId();
|
||||
}
|
||||
|
||||
void RecognizerInstance::Accept(const std::vector<float>& waves, int idx) const {
|
||||
recognizer_controller_->Accept(waves, idx);
|
||||
return;
|
||||
}
|
||||
|
||||
void RecognizerInstance::SetInputFinished(int idx) const {
|
||||
recognizer_controller_->SetInputFinished(idx);
|
||||
return;
|
||||
}
|
||||
|
||||
std::string RecognizerInstance::GetResult(int idx) const {
|
||||
return recognizer_controller_->GetFinalResult(idx);
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "recognizer/recognizer_controller.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class RecognizerInstance {
|
||||
public:
|
||||
static RecognizerInstance& GetInstance();
|
||||
RecognizerInstance() {}
|
||||
~RecognizerInstance() {}
|
||||
bool Init(const std::string& model_file,
|
||||
const std::string& word_symbol_table_file,
|
||||
const std::string& fst_file,
|
||||
int num_instance);
|
||||
int GetRecognizerInstanceId();
|
||||
void InitDecoder(int idx);
|
||||
void Accept(const std::vector<float>& waves, int idx) const;
|
||||
void SetInputFinished(int idx) const;
|
||||
std::string GetResult(int idx) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<RecognizerController> recognizer_controller_;
|
||||
};
|
||||
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,41 @@
|
||||
#!/bin/bash
# Decodes the AISHELL test list with recognizer_main using a TLG graph,
# split over $nj parallel jobs, then scores the merged result with
# utils/compute-wer.py.
set -e

data=data
exp=exp
nj=40

# Sourced after the defaults above (presumably so they can be overridden
# from the command line - confirm against utils/parse_options.sh).
. utils/parse_options.sh

mkdir -p "$exp"
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text

./local/split_data.sh "$data" "$data/$aishell_wav_scp" "$aishell_wav_scp" "$nj"

lang_dir=./data/lang_test/
graph=$lang_dir/TLG.fst
word_table=$lang_dir/words.txt

# JOB is replaced by run.pl with the job index 1..$nj.
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer_wfst.log \
recognizer_main \
    --use_fbank=true \
    --num_bins=80 \
    --cmvn_file=$model_dir/mean_std.json \
    --model_path=$model_dir/export.jit \
    --graph_path=$graph \
    --word_symbol_table=$word_table \
    --nnet_decoder_chunk=16 \
    --receptive_field_length=7 \
    --subsampling_rate=4 \
    --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
    --result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer_wfst.ark


# Merge the per-job results and compute the character error rate.
cat "$data/split${nj}"/*/result_recognizer_wfst.ark > "$exp/aishell_recognizer_wfst"
utils/compute-wer.py --char=1 --v=1 "$text" "$exp/aishell_recognizer_wfst" > "$exp/aishell.recognizer_wfst.err"
echo "recognizer test have finished!!!"
echo "please checkout in $exp/aishell.recognizer_wfst.err"
tail -n 7 "$exp/aishell.recognizer_wfst.err"
|
Loading…
Reference in new issue