u2 recog test main ok

pull/2524/head
Hui Zhang 2 years ago
parent 86eb718908
commit 17ea30e7ca

@ -0,0 +1,22 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--downsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--wav_rspecifier=scp:$data/wav.scp \
--result_wspecifier=ark,t:$exp/result.ark

@ -52,11 +52,12 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box", "chunk_state_h_box,chunk_state_c_box",
"model cache names"); "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_string(vocab_path, "", "nnet vocab path.");
// decoder // decoder
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_int32(max_active, 7500, "max active"); DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam");
@ -72,13 +73,14 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
frame_opts.dither = 0.0; frame_opts.dither = 0.0;
frame_opts.frame_shift_ms = 10; frame_opts.frame_shift_ms = 10;
opts.use_fbank = FLAGS_use_fbank; opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << opts.use_fbank ? "fbank" : "linear"; LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) { if (opts.use_fbank) {
opts.to_float32 = false; opts.to_float32 = false;
frame_opts.window_type = "povey"; frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25; frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.frame_opts = frame_opts; opts.fbank_opts.frame_opts = frame_opts;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
} else { } else {
opts.to_float32 = true; opts.to_float32 = true;
frame_opts.remove_dc_offset = false; frame_opts.remove_dc_offset = false;

@ -33,12 +33,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource
BaseFloat am_scale = resource.acoustic_scale; BaseFloat am_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
CHECK(resource.vocab_path != "");
decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
unit_table_ = decoder_->VocabTable(); unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_; symbol_table_ = unit_table_;
input_finished_ = false; input_finished_ = false;
Reset();
} }
void U2Recognizer::Reset() { void U2Recognizer::Reset() {
@ -69,6 +72,7 @@ void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
void U2Recognizer::Decode() { void U2Recognizer::Decode() {
decoder_->AdvanceDecode(decodable_); decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
} }
void U2Recognizer::Rescoring() { void U2Recognizer::Rescoring() {

@ -92,12 +92,13 @@ struct DecodeOptions {
struct U2RecognizerResource { struct U2RecognizerResource {
kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{};
FeaturePipelineOptions feature_pipeline_opts{}; FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{}; ModelOptions model_opts{};
DecodeOptions decoder_opts{}; DecodeOptions decoder_opts{};
// CTCBeamSearchOptions beam_search_opts; // CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{};
}; };

@ -25,13 +25,16 @@ DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::U2RecognizerResource InitOpts() { ppspeech::U2RecognizerResource InitOpts() {
ppspeech::U2RecognizerResource resource; ppspeech::U2RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale; resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();
LOG(INFO) << "feature!";
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path; model_opts.model_path = FLAGS_model_path;
resource.model_opts = model_opts; resource.model_opts = model_opts;
LOG(INFO) << "model!";
ppspeech::DecodeOptions decoder_opts; ppspeech::DecodeOptions decoder_opts;
decoder_opts.chunk_size=16; decoder_opts.chunk_size=16;
@ -44,6 +47,7 @@ ppspeech::U2RecognizerResource InitOpts() {
decoder_opts.ctc_prefix_search_opts.second_beam_size = 10; decoder_opts.ctc_prefix_search_opts.second_beam_size = 10;
resource.decoder_opts = decoder_opts; resource.decoder_opts = decoder_opts;
LOG(INFO) << "decoder!";
return resource; return resource;
} }
@ -57,9 +61,6 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0; double tot_wav_duration = 0.0;
ppspeech::U2RecognizerResource resource = InitOpts();
ppspeech::U2Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier); FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
@ -71,8 +72,10 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
kaldi::Timer timer; ppspeech::U2RecognizerResource resource = InitOpts();
ppspeech::U2Recognizer recognizer(resource);
kaldi::Timer timer;
for (; !wav_reader.Done(); wav_reader.Next()) { for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key(); std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value(); const kaldi::WaveData& wave_data = wav_reader.Value();

@ -29,7 +29,9 @@ using std::unique_ptr;
CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor) CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) { : var_norm_(true) {
CHECK(cmvn_file != "");
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
bool binary; bool binary;
kaldi::Input ki(cmvn_file, &binary); kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary); stats_.Read(ki.Stream(), binary);
@ -55,11 +57,11 @@ bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
// feats contain num_frames feature. // feats contain num_frames feature.
void CMVN::Compute(VectorBase<BaseFloat>* feats) const { void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
KALDI_ASSERT(feats != NULL); KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim != 0) { feats->Dim() % dim_ != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x' KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ','
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x'; << stats_.NumCols() - 1 << ", feats " << feats->Dim() << 'x';
} }
if (stats_.NumRows() == 1 && var_norm_) { if (stats_.NumRows() == 1 && var_norm_) {
KALDI_ERR KALDI_ERR
@ -67,7 +69,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
<< "are supplied."; << "are supplied.";
} }
double count = stats_(0, dim); double count = stats_(0, dim_);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one. // computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0) if (count < 1.0)
@ -77,14 +79,14 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
if (!var_norm_) { if (!var_norm_) {
Vector<BaseFloat> offset(feats->Dim()); Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim); SubVector<double> mean_stats(stats_.RowData(0), dim_);
Vector<double> mean_stats_apply(feats->Dim()); Vector<double> mean_stats_apply(feats->Dim());
// fill the datat of mean_stats in mean_stats_appy whose dim is equal // fill the datat of mean_stats in mean_stats_appy whose dim_ is equal
// with the dim of feature. // with the dim_ of feature.
// the dim of feats = dim * num_frames; // the dim_ of feats = dim_ * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) { for (int32 idx = 0; idx < feats->Dim() / dim_; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx, SubVector<double> stats_tmp(mean_stats_apply.Data() + dim_ * idx,
dim); dim_);
stats_tmp.CopyFromVec(mean_stats); stats_tmp.CopyFromVec(mean_stats);
} }
offset.AddVec(-1.0 / count, mean_stats_apply); offset.AddVec(-1.0 / count, mean_stats_apply);
@ -94,7 +96,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
// norm(0, d) = mean offset; // norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim()); kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) { for (int32 d = 0; d < dim_; d++) {
double mean, offset, scale; double mean, offset, scale;
mean = stats_(0, d) / count; mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20; double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
@ -111,7 +113,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
for (int32 d_skip = d; d_skip < feats->Dim();) { for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset; norm(0, d_skip) = offset;
norm(1, d_skip) = scale; norm(1, d_skip) = scale;
d_skip = d_skip + dim; d_skip = d_skip + dim_;
} }
} }
// Apply the normalization. // Apply the normalization.

@ -32,6 +32,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opt
opts.linear_spectrogram_opts, std::move(data_source))); opts.linear_spectrogram_opts, std::move(data_source)));
} }
CHECK(opts.cmvn_file != "");
unique_ptr<FrontendInterface> cmvn( unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));

Loading…
Cancel
Save