parent 418cc37ffb
commit ae629e2fe6
@@ -1,6 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
add_subdirectory(websocket)
@@ -1,2 +0,0 @@
data
exp
@@ -1,22 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

set(bin_name ctc-prefix-beam-search-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

set(bin_name wfst-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})

set(bin_name nnet-logprob-decoder-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
@@ -1,12 +0,0 @@
# ASR Decoder

ASR decoder test binaries. We use these binaries to test the CTC beam search decoder and the WFST decoder (a minimal sketch of the decode loop they share follows this list).

* decoder_test_main.cc
feeds nnet output log-probabilities and tests only the decoder.

* offline_decoder_sliding_chunk_main.cc
feeds streaming audio features and decodes in a streaming manner.

* offline_wfst_decoder_main.cc
feeds streaming audio features and decodes with a WFST in a streaming manner.
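These bins all follow one decode loop: accept input, advance the decoder, read the final best path, then reset for the next utterance. Below is a minimal sketch of that loop, distilled from the nnet-logprob bin later in this diff; the ppspeech calls are the ones exercised in these sources, while the wrapper function and the hard-coded vocab path are illustrative assumptions.

    #include "decoder/ctc_beam_search_decoder.h"
    #include "nnet/decodable.h"

    // Sketch only: decode one utterance from precomputed nnet log-probs.
    std::string DecodeLogprob(const kaldi::Matrix<kaldi::BaseFloat>& logprob) {
        ppspeech::CTCBeamSearchOptions opts;
        opts.dict_file = "vocab.txt";  // LM vocabulary path (assumption)
        ppspeech::CTCBeamSearch decoder(opts);
        decoder.InitDecoder();

        // No nnet/frontend attached: feed precomputed log-probabilities.
        std::shared_ptr<ppspeech::Decodable> decodable(
            new ppspeech::Decodable(nullptr, nullptr));
        decodable->Acceptlikelihood(logprob);
        decoder.AdvanceDecode(decodable);

        std::string result = decoder.GetFinalBestPath();
        decodable->Reset();
        decoder.Reset();
        return result;
    }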
@@ -1,167 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: refactor, replace with gtest

#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length,
             7,
             "receptive field of two CNN(kernel=5) downsampling modules.");
DEFINE_int32(downsampling_rate,
             4,
             "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
    model_input_names,
    "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
    "model input names");
DEFINE_string(model_output_names,
              "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
              "model output names");
DEFINE_string(model_cache_names,
              "chunk_state_h_box,chunk_state_c_box",
              "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");

using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

// test ds2 online decoder by feeding speech features
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    CHECK(FLAGS_result_wspecifier != "");
    CHECK(FLAGS_feature_rspecifier != "");

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_rspecifier);
    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
    std::string model_path = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    std::string dict_file = FLAGS_dict_file;
    std::string lm_path = FLAGS_lm_path;
    LOG(INFO) << "model path: " << model_path;
    LOG(INFO) << "model param: " << model_params;
    LOG(INFO) << "dict path: " << dict_file;
    LOG(INFO) << "lm path: " << lm_path;

    int32 num_done = 0, num_err = 0;

    ppspeech::CTCBeamSearchOptions opts;
    opts.dict_file = dict_file;
    opts.lm_path = lm_path;
    ppspeech::CTCBeamSearch decoder(opts);

    ppspeech::ModelOptions model_opts;
    model_opts.model_path = model_path;
    model_opts.param_path = model_params;
    model_opts.cache_names = FLAGS_model_cache_names;
    model_opts.cache_shape = FLAGS_model_cache_shapes;
    model_opts.input_names = FLAGS_model_input_names;
    model_opts.output_names = FLAGS_model_output_names;
    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet, raw_data));

    int32 chunk_size = FLAGS_receptive_field_length;
    int32 chunk_stride = FLAGS_downsampling_rate;
    int32 receptive_field_length = FLAGS_receptive_field_length;
    LOG(INFO) << "chunk size (frame): " << chunk_size;
    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
    decoder.InitDecoder();

    kaldi::Timer timer;
    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
        kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
        raw_data->SetDim(feature.NumCols());
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << feature.NumRows();
        LOG(INFO) << "cols: " << feature.NumCols();

        int32 row_idx = 0;
        int32 padding_len = 0;
        int32 ori_feature_len = feature.NumRows();
        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
            padding_len =
                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
            feature.Resize(feature.NumRows() + padding_len,
                           feature.NumCols(),
                           kaldi::kCopyData);
        }
        int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
                                                          feature.NumCols());
            int32 feature_chunk_size = 0;
            if (ori_feature_len > chunk_idx * chunk_stride) {
                feature_chunk_size = std::min(
                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
            }
            if (feature_chunk_size < receptive_field_length) break;

            int32 start = chunk_idx * chunk_stride;

            for (int row_id = 0; row_id < chunk_size; ++row_id) {
                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
                    feature_chunk.Data() + row_id * feature.NumCols(),
                    feature.NumCols());
                f_chunk_tmp.CopyFromVec(tmp);
                ++start;
            }
            raw_data->Accept(feature_chunk);
            if (chunk_idx == num_chunks - 1) {
                raw_data->SetFinished();
            }
            decoder.AdvanceDecode(decodable);
        }
        std::string result;
        result = decoder.GetFinalBestPath();
        decodable->Reset();
        decoder.Reset();
        if (result.empty()) {
            // the TokenWriter can not write an empty string.
            ++num_err;
            KALDI_LOG << " the result of " << utt << " is empty";
            continue;
        }
        KALDI_LOG << " the result of " << utt << " is " << result;
        result_writer.Write(utt, result);
        ++num_done;
    }

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    double elapsed = timer.Elapsed();
    KALDI_LOG << " cost:" << elapsed << " s";
    return (num_done != 0 ? 0 : 1);
}
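The chunked reader above pads the feature matrix so that (rows - chunk_size) is a multiple of chunk_stride, then slides a chunk_size window forward by chunk_stride. A standalone illustration of that arithmetic, using the flag defaults (chunk_size = 7, chunk_stride = 4) and a made-up 100-frame utterance:

    #include <iostream>

    int main() {
        int chunk_size = 7, chunk_stride = 4, rows = 100;
        int padding_len = 0;
        if ((rows - chunk_size) % chunk_stride != 0) {
            // pad so the last window lands exactly on the stride grid
            padding_len = chunk_stride - (rows - chunk_size) % chunk_stride;
        }
        rows += padding_len;  // (100 - 7) % 4 == 1, so pad by 3 -> 103 rows
        int num_chunks = (rows - chunk_size) / chunk_stride + 1;  // 25 chunks
        std::cout << "padding: " << padding_len
                  << ", chunks: " << num_chunks << std::endl;
        return 0;
    }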
@@ -1,3 +0,0 @@
#!/bin/bash
@@ -1,74 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: refactor, replace with gtest

#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"

DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");

using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

// test decoder by feeding nnet posterior probabilities
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
        FLAGS_nnet_prob_respecifier);
    std::string dict_file = FLAGS_dict_file;
    std::string lm_path = FLAGS_lm_path;
    LOG(INFO) << "dict path: " << dict_file;
    LOG(INFO) << "lm path: " << lm_path;

    int32 num_done = 0, num_err = 0;

    ppspeech::CTCBeamSearchOptions opts;
    opts.dict_file = dict_file;
    opts.lm_path = lm_path;
    ppspeech::CTCBeamSearch decoder(opts);

    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nullptr, nullptr));

    decoder.InitDecoder();

    for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
        string utt = likelihood_reader.Key();
        const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << likelihood.NumRows();
        LOG(INFO) << "cols: " << likelihood.NumCols();
        decodable->Acceptlikelihood(likelihood);
        decoder.AdvanceDecode(decodable);
        std::string result;
        result = decoder.GetFinalBestPath();
        KALDI_LOG << " the result of " << utt << " is " << result;
        decodable->Reset();
        decoder.Reset();
        ++num_done;
    }

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
@@ -1,14 +0,0 @@
# This contains the locations of the binaries built for running the examples.

SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project builds successfully"; }

export LC_ALL=C

SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
@@ -1,99 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "decoder/recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"

DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_int32(sample_rate, 16000, "sample rate");

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
    ppspeech::Recognizer recognizer(resource);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);

    int sample_rate = FLAGS_sample_rate;
    float streaming_chunk = FLAGS_streaming_chunk;
    int chunk_sample_size = streaming_chunk * sample_rate;
    LOG(INFO) << "sr: " << sample_rate;
    LOG(INFO) << "chunk size (s): " << streaming_chunk;
    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;

    int32 num_done = 0, num_err = 0;
    double tot_wav_duration = 0.0;

    kaldi::Timer timer;

    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        tot_wav_duration += tot_samples * 1.0 / sample_rate;
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        std::vector<kaldi::Vector<BaseFloat>> feats;
        int feature_rows = 0;
        while (sample_offset < tot_samples) {
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
            }
            // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);

            recognizer.Accept(wav_chunk);
            if (cur_chunk_size < chunk_sample_size) {
                recognizer.SetFinished();
            }
            recognizer.Decode();

            // no overlap
            sample_offset += cur_chunk_size;
        }

        std::string result;
        result = recognizer.GetFinalResult();
        recognizer.Reset();
        if (result.empty()) {
            // the TokenWriter can not write an empty string.
            ++num_err;
            KALDI_LOG << " the result of " << utt << " is empty";
            continue;
        }
        KALDI_LOG << " the result of " << utt << " is " << result;
        result_writer.Write(utt, result);
        ++num_done;
    }
    double elapsed = timer.Elapsed();
    KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done);
    KALDI_LOG << " cost:" << elapsed << " s";
    KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s";
    KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration;
}
@@ -1,78 +0,0 @@
#!/bin/bash
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/

lm=$data/zh_giga.no_cna_cmn.prune01244.klm

# output
exp_dir=./exp
mkdir -p $exp_dir

# 2. download model
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
    mkdir -p data/model
    pushd data/model
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

# produce wav scp
if [ ! -f data/wav.scp ]; then
    pushd data
    wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
    echo "utt1 " $PWD/zh.wav > wav.scp
    popd
fi

# download lm
if [ ! -f $lm ]; then
    pushd data
    wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
    popd
fi

feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark

export GLOG_logtostderr=1

# dump json cmvn to kaldi
cmvn-json2kaldi \
    --json_file $ckpt_dir/data/mean_std.json \
    --cmvn_write_path $cmvn \
    --binary=false
echo "convert json cmvn to kaldi ark."

# generate linear feature as streaming
linear-spectrogram-wo-db-norm-ol \
    --wav_rspecifier=scp:$data/wav.scp \
    --feature_wspecifier=ark,t:$feat_wspecifier \
    --cmvn_file=$cmvn
echo "compute linear spectrogram feature."

# run ctc beam search decoder as streaming
ctc-prefix-beam-search-decoder-ol \
    --result_wspecifier=ark,t:$exp_dir/result.txt \
    --feature_rspecifier=ark:$feat_wspecifier \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdiparams \
    --dict_file=$vocb_dir/vocab.txt \
    --lm_path=$lm
@@ -1,26 +0,0 @@
#!/bin/bash

# this script is for memory checking, so please run ./run.sh first.

set +x
set -e

. ./path.sh

if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
    echo "please install valgrind in the speechx tools dir."
    exit 1
fi

model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark

valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
    offline_decoder_main \
    --feature_respecifier=ark:$feat_wspecifier \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdparams \
    --dict_file=$model_dir/vocab.txt \
    --lm_path=$model_dir/avg_1.jit.klm
@@ -1,169 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: refactor, replace with gtest

#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");

DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active states in decoding");
DEFINE_int32(receptive_field_length,
             7,
             "receptive field of two CNN(kernel=5) downsampling modules.");
DEFINE_int32(downsampling_rate,
             4,
             "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
    model_input_names,
    "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
    "model input names");
DEFINE_string(model_output_names,
              "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
              "model output names");
DEFINE_string(model_cache_names,
              "chunk_state_h_box,chunk_state_c_box",
              "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");

using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

// test TLG decoder by feeding speech features.
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_rspecifier);
    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    std::string word_symbol_table = FLAGS_word_symbol_table;
    std::string graph_path = FLAGS_graph_path;
    LOG(INFO) << "model path: " << model_graph;
    LOG(INFO) << "model param: " << model_params;
    LOG(INFO) << "word symbol path: " << word_symbol_table;
    LOG(INFO) << "graph path: " << graph_path;

    int32 num_done = 0, num_err = 0;

    ppspeech::TLGDecoderOptions opts;
    opts.word_symbol_table = word_symbol_table;
    opts.fst_path = graph_path;
    opts.opts.max_active = FLAGS_max_active;
    opts.opts.beam = 15.0;
    opts.opts.lattice_beam = 7.5;
    ppspeech::TLGDecoder decoder(opts);

    ppspeech::ModelOptions model_opts;
    model_opts.model_path = model_graph;
    model_opts.param_path = model_params;
    model_opts.cache_names = FLAGS_model_cache_names;
    model_opts.cache_shape = FLAGS_model_cache_shapes;
    model_opts.input_names = FLAGS_model_input_names;
    model_opts.output_names = FLAGS_model_output_names;
    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));

    int32 chunk_size = FLAGS_receptive_field_length;
    int32 chunk_stride = FLAGS_downsampling_rate;
    int32 receptive_field_length = FLAGS_receptive_field_length;
    LOG(INFO) << "chunk size (frame): " << chunk_size;
    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
    decoder.InitDecoder();
    kaldi::Timer timer;
    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
        kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
        raw_data->SetDim(feature.NumCols());
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << feature.NumRows();
        LOG(INFO) << "cols: " << feature.NumCols();

        int32 row_idx = 0;
        int32 padding_len = 0;
        int32 ori_feature_len = feature.NumRows();
        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
            padding_len =
                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
            feature.Resize(feature.NumRows() + padding_len,
                           feature.NumCols(),
                           kaldi::kCopyData);
        }
        int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
                                                          feature.NumCols());
            int32 feature_chunk_size = 0;
            if (ori_feature_len > chunk_idx * chunk_stride) {
                feature_chunk_size = std::min(
                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
            }
            if (feature_chunk_size < receptive_field_length) break;

            int32 start = chunk_idx * chunk_stride;
            for (int row_id = 0; row_id < chunk_size; ++row_id) {
                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
                    feature_chunk.Data() + row_id * feature.NumCols(),
                    feature.NumCols());
                f_chunk_tmp.CopyFromVec(tmp);
                ++start;
            }
            raw_data->Accept(feature_chunk);
            if (chunk_idx == num_chunks - 1) {
                raw_data->SetFinished();
            }
            decoder.AdvanceDecode(decodable);
        }
        std::string result;
        result = decoder.GetFinalBestPath();
        decodable->Reset();
        decoder.Reset();
        if (result.empty()) {
            // the TokenWriter can not write an empty string.
            ++num_err;
            KALDI_LOG << " the result of " << utt << " is empty";
            continue;
        }
        KALDI_LOG << " the result of " << utt << " is " << result;
        result_writer.Write(utt, result);
        ++num_done;
    }

    double elapsed = timer.Elapsed();
    KALDI_LOG << " cost:" << elapsed << " s";

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
@@ -1,2 +0,0 @@
exp
data
@@ -1,16 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

set(bin_name linear-spectrogram-wo-db-norm-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)

set(bin_name compute_fbank_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)

set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
@@ -1,7 +0,0 @@
# Deepspeech2 Streaming Audio Feature

ASR audio feature test binaries. We use these binaries to test linear/fbank/mfcc ASR features in a streaming manner (see the pipeline sketch after this section).

* linear_spectrogram_without_db_norm_main.cc

computes the linear spectrogram without dB norm in a streaming manner.
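These bins assemble the frontend as a chain of FrontendInterface stages, each wrapping the previous one, so Accept() and Read() on the outermost FeatureCache drive the whole pipeline. A condensed sketch of that assembly, taken from linear-spectrogram-wo-db-norm-ol.cc later in this diff (option values elided; the wrapper function is illustrative only):

    #include "frontend/audio/audio_cache.h"
    #include "frontend/audio/feature_cache.h"
    #include "frontend/audio/frontend_itf.h"
    #include "frontend/audio/linear_spectrogram.h"
    #include "frontend/audio/normalizer.h"

    // Sketch only: build the stage chain and drain it for one wav chunk.
    void RunPipeline(const std::string& cmvn_file,
                     const kaldi::Vector<kaldi::BaseFloat>& wav_chunk) {
        std::unique_ptr<ppspeech::FrontendInterface> data_source(
            new ppspeech::AudioCache(3600 * 1600, true));

        ppspeech::LinearSpectrogramOptions opt;  // defaults; see the source below
        std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
            new ppspeech::LinearSpectrogram(opt, std::move(data_source)));

        std::unique_ptr<ppspeech::FrontendInterface> cmvn(
            new ppspeech::CMVN(cmvn_file, std::move(linear_spectrogram)));

        ppspeech::FeatureCacheOptions cache_opts;
        ppspeech::FeatureCache feature_cache(cache_opts, std::move(cmvn));

        feature_cache.Accept(wav_chunk);  // push samples into the chain
        feature_cache.SetFinished();
        kaldi::Vector<kaldi::BaseFloat> features;
        while (feature_cache.Read(&features) && features.Dim() != 0) {
            // one feature chunk per iteration
        }
    }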
@@ -1,85 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Note: Do not print/log an on-demand object.

#include "base/common.h"
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
// #include "boost/json.hpp"
#include <boost/json/src.hpp>

DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text (false)");

using namespace boost::json;  // from <boost/json.hpp>

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    LOG(INFO) << "cmvn json path: " << FLAGS_json_file;

    auto ifs = std::ifstream(FLAGS_json_file);
    std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
    auto value = boost::json::parse(json_str);
    if (!value.is_object()) {
        LOG(ERROR) << "Input json file format error.";
    }

    for (auto obj : value.as_object()) {
        if (obj.key() == "mean_stat") {
            LOG(INFO) << "mean_stat: " << obj.value();
        }
        if (obj.key() == "var_stat") {
            LOG(INFO) << "var_stat: " << obj.value();
        }
        if (obj.key() == "frame_num") {
            LOG(INFO) << "frame_num: " << obj.value();
        }
    }

    boost::json::array mean_stat = value.at("mean_stat").as_array();
    std::vector<kaldi::BaseFloat> mean_stat_vec;
    for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
        mean_stat_vec.push_back(it->as_double());
    }

    boost::json::array var_stat = value.at("var_stat").as_array();
    std::vector<kaldi::BaseFloat> var_stat_vec;
    for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
        var_stat_vec.push_back(it->as_double());
    }

    kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64());
    LOG(INFO) << "nframe: " << frame_num;

    size_t mean_size = mean_stat_vec.size();
    kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
    for (size_t idx = 0; idx < mean_size; ++idx) {
        cmvn_stats(0, idx) = mean_stat_vec[idx];
        cmvn_stats(1, idx) = var_stat_vec[idx];
    }
    cmvn_stats(0, mean_size) = frame_num;
    LOG(INFO) << cmvn_stats;

    kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
    LOG(INFO) << "cmvn stats have been written to: " << FLAGS_cmvn_write_path;
    LOG(INFO) << "binary: " << FLAGS_binary;
    return 0;
}
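The 2 x (dim+1) matrix written above follows the usual Kaldi CMVN-stats layout: row 0 carries the per-dimension feature sums with the frame count in the last column, and row 1 carries the per-dimension sums of squares. A sketch of how a consumer would recover mean and variance from it (illustrative helper, not part of this repo):

    #include <vector>

    // Illustrative only: recover per-dimension mean/variance from a
    // Kaldi-style 2 x (dim+1) CMVN stats matrix laid out as above.
    void StatsToMeanVar(const std::vector<std::vector<double>>& stats,
                        std::vector<double>* mean,
                        std::vector<double>* var) {
        size_t dim = stats[0].size() - 1;
        double frame_num = stats[0][dim];  // count lives in row 0, last column
        for (size_t d = 0; d < dim; ++d) {
            double m = stats[0][d] / frame_num;               // E[x]
            mean->push_back(m);
            var->push_back(stats[1][d] / frame_num - m * m);  // E[x^2] - E[x]^2
        }
    }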
@@ -1,143 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: refactor, replace with gtest

#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"

#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/normalizer.h"

DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(num_bins, 161, "fbank num bins");

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);

    int32 num_done = 0, num_err = 0;

    // feature pipeline: wave cache --> povey window
    // --> fbank --> global cmvn --> feat cache

    std::unique_ptr<ppspeech::FrontendInterface> data_source(
        new ppspeech::AudioCache(3600 * 1600, false));

    ppspeech::FbankOptions opt;
    opt.fbank_opts.frame_opts.frame_length_ms = 25;
    opt.fbank_opts.frame_opts.frame_shift_ms = 10;
    opt.streaming_chunk = FLAGS_streaming_chunk;
    opt.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
    opt.fbank_opts.frame_opts.dither = 0.0;

    std::unique_ptr<ppspeech::FrontendInterface> fbank(
        new ppspeech::Fbank(opt, std::move(data_source)));

    std::unique_ptr<ppspeech::FrontendInterface> cmvn(
        new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));

    ppspeech::FeatureCacheOptions feat_cache_opts;
    // the feature cache outputs features chunk by chunk.
    // frame_chunk_size : number of frames in a chunk.
    // frame_chunk_stride: chunk sliding-window stride.
    feat_cache_opts.frame_chunk_stride = 1;
    feat_cache_opts.frame_chunk_size = 1;
    ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
    LOG(INFO) << "fbank: " << true;
    LOG(INFO) << "feat dim: " << feature_cache.Dim();

    int sample_rate = 16000;
    float streaming_chunk = FLAGS_streaming_chunk;
    int chunk_sample_size = streaming_chunk * sample_rate;
    LOG(INFO) << "sr: " << sample_rate;
    LOG(INFO) << "chunk size (s): " << streaming_chunk;
    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;

    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();
        LOG(INFO) << "process utt: " << utt;

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        std::vector<kaldi::Vector<BaseFloat>> feats;
        int feature_rows = 0;
        while (sample_offset < tot_samples) {
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
            }

            kaldi::Vector<BaseFloat> features;
            feature_cache.Accept(wav_chunk);
            if (cur_chunk_size < chunk_sample_size) {
                feature_cache.SetFinished();
            }
            bool flag = true;
            do {
                flag = feature_cache.Read(&features);
                feats.push_back(features);
                feature_rows += features.Dim() / feature_cache.Dim();
            } while (flag == true && features.Dim() != 0);
            sample_offset += cur_chunk_size;
        }

        int cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
                                                 feature_cache.Dim());
        for (auto feat : feats) {
            int num_rows = feat.Dim() / feature_cache.Dim();
            for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
                for (size_t col_idx = 0; col_idx < feature_cache.Dim();
                     ++col_idx) {
                    features(cur_idx, col_idx) =
                        feat(row_idx * feature_cache.Dim() + col_idx);
                }
                ++cur_idx;
            }
        }
        feat_writer.Write(utt, features);
        feature_cache.Reset();

        if (num_done % 50 == 0 && num_done != 0)
            KALDI_VLOG(2) << "Processed " << num_done << " utterances";
        num_done++;
    }
    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
@@ -1,147 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: refactor, replace with gtest

#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"

#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"

DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);

    int32 num_done = 0, num_err = 0;

    // feature pipeline: wave cache --> hanning window
    // --> linear_spectrogram --> global cmvn --> feat cache

    std::unique_ptr<ppspeech::FrontendInterface> data_source(
        new ppspeech::AudioCache(3600 * 1600, true));

    ppspeech::LinearSpectrogramOptions opt;
    opt.frame_opts.frame_length_ms = 20;
    opt.frame_opts.frame_shift_ms = 10;
    opt.streaming_chunk = FLAGS_streaming_chunk;
    opt.frame_opts.dither = 0.0;
    opt.frame_opts.remove_dc_offset = false;
    opt.frame_opts.window_type = "hanning";
    opt.frame_opts.preemph_coeff = 0.0;
    LOG(INFO) << "linear feature: " << true;
    LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
    LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;

    std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
        new ppspeech::LinearSpectrogram(opt, std::move(data_source)));

    std::unique_ptr<ppspeech::FrontendInterface> cmvn(
        new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));

    ppspeech::FeatureCacheOptions feat_cache_opts;
    // the feature cache outputs features chunk by chunk.
    // frame_chunk_size : number of frames in a chunk.
    // frame_chunk_stride: chunk sliding-window stride.
    feat_cache_opts.frame_chunk_stride = 1;
    feat_cache_opts.frame_chunk_size = 1;
    ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
    LOG(INFO) << "feat dim: " << feature_cache.Dim();

    int sample_rate = 16000;
    float streaming_chunk = FLAGS_streaming_chunk;
    int chunk_sample_size = streaming_chunk * sample_rate;
    LOG(INFO) << "sample rate: " << sample_rate;
    LOG(INFO) << "chunk size (s): " << streaming_chunk;
    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;

    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();
        LOG(INFO) << "process utt: " << utt;

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        std::vector<kaldi::Vector<BaseFloat>> feats;
        int feature_rows = 0;
        while (sample_offset < tot_samples) {
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
            }

            kaldi::Vector<BaseFloat> features;
            feature_cache.Accept(wav_chunk);
            if (cur_chunk_size < chunk_sample_size) {
                feature_cache.SetFinished();
            }
            bool flag = true;
            do {
                flag = feature_cache.Read(&features);
                feats.push_back(features);
                feature_rows += features.Dim() / feature_cache.Dim();
            } while (flag == true && features.Dim() != 0);
            sample_offset += cur_chunk_size;
        }

        int cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
                                                 feature_cache.Dim());
        for (auto feat : feats) {
            int num_rows = feat.Dim() / feature_cache.Dim();
            for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
                for (size_t col_idx = 0; col_idx < feature_cache.Dim();
                     ++col_idx) {
                    features(cur_idx, col_idx) =
                        feat(row_idx * feature_cache.Dim() + col_idx);
                }
                ++cur_idx;
            }
        }
        feat_writer.Write(utt, features);
        feature_cache.Reset();

        if (num_done % 50 == 0 && num_done != 0)
            KALDI_VLOG(2) << "Processed " << num_done << " utterances";
        num_done++;
    }
    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
@@ -1,14 +0,0 @@
# This contains the locations of the binaries built for running the examples.

SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project builds successfully"; }

export LC_ALL=C

SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
@@ -1,57 +0,0 @@
#!/bin/bash
set +x
set -e

. ./path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# 2. download model
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
    mkdir -p data/model
    pushd data/model
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

# produce wav scp
if [ ! -f data/wav.scp ]; then
    mkdir -p data
    pushd data
    wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
    echo "utt1 " $PWD/zh.wav > wav.scp
    popd
fi

# input
data_dir=./data
exp_dir=./exp
model_dir=$data_dir/model/

mkdir -p $exp_dir

# 3. run feat
export GLOG_logtostderr=1

cmvn-json2kaldi \
    --json_file $model_dir/data/mean_std.json \
    --cmvn_write_path $exp_dir/cmvn.ark \
    --binary=false
echo "convert json cmvn to kaldi ark."

linear-spectrogram-wo-db-norm-ol \
    --wav_rspecifier=scp:$data_dir/wav.scp \
    --feature_wspecifier=ark,t:$exp_dir/feats.ark \
    --cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
@@ -1,24 +0,0 @@
#!/bin/bash

# this script is for memory checking, so please run ./run.sh first.

set +x
set -e

. ./path.sh

if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
    echo "please install valgrind in the speechx tools dir."
    exit 1
fi

model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark

valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
    linear_spectrogram_main \
    --wav_rspecifier=scp:$model_dir/wav.scp \
    --feature_wspecifier=ark,t:$feat_wspecifier \
    --cmvn_write_path=$cmvn
@@ -1,2 +0,0 @@
data
exp
@@ -1,6 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

set(bin_name ds2-model-ol-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
@@ -1,3 +0,0 @@
# Deepspeech2 Streaming NNet Test

Used for DeepSpeech2 streaming nnet inference testing.
@@ -1,203 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// deepspeech2 online model info

#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h"

using std::cout;
using std::endl;

DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
DEFINE_int32(chunk_size, 35, "feature chunk size, unit: frame");
DEFINE_int32(feat_dim, 161, "feature dim");

void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();

void produce_data(std::vector<std::vector<float>>* data) {
    int chunk_size = FLAGS_chunk_size;  // chunk size in frames
    int col_size = FLAGS_feat_dim;      // feat dim
    cout << "chunk size: " << chunk_size << endl;
    cout << "feat dim: " << col_size << endl;

    data->reserve(chunk_size);
    for (int row = 0; row < chunk_size; ++row) {
        data->push_back(std::vector<float>());
        data->back().reserve(col_size);
        for (int col_idx = 0; col_idx < col_size; ++col_idx) {
            data->back().push_back(0.201);
        }
    }
}

void model_forward_test() {
    std::cout << "1. read the data" << std::endl;
    std::vector<std::vector<float>> feats;
    produce_data(&feats);

    std::cout << "2. load the model" << std::endl;
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    CHECK(model_graph != "");
    CHECK(model_params != "");
    cout << "model path: " << model_graph << endl;
    cout << "model param path: " << model_params << endl;

    paddle_infer::Config config;
    config.SetModel(model_graph, model_params);
    config.SwitchIrOptim(false);
    cout << "SwitchIrOptim: " << false << endl;
    config.DisableFCPadding();
    cout << "DisableFCPadding: " << endl;
    auto predictor = paddle_infer::CreatePredictor(config);

    std::cout << "3. feat shape, row=" << feats.size()
              << ", col=" << feats[0].size() << std::endl;
    std::vector<float> pp_input_mat;
    for (const auto& item : feats) {
        pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
    }

    std::cout << "4. feed the data to the model" << std::endl;
    int row = feats.size();
    int col = feats[0].size();
    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();
    for (auto name : input_names) {
        cout << "model input names: " << name << endl;
    }
    for (auto name : output_names) {
        cout << "model output names: " << name << endl;
    }

    // input
    std::unique_ptr<paddle_infer::Tensor> input_tensor =
        predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, row, col};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(pp_input_mat.data());

    // input length
    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(row);
    input_len->CopyFromCpu(audio_len.data());

    // state_h
    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
        predictor->GetInputHandle(input_names[2]);
    std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
    int chunk_state_h_box_size =
        std::accumulate(chunk_state_h_box_shape.begin(),
                        chunk_state_h_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());

    // state_c
    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
        predictor->GetInputHandle(input_names[3]);
    std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
    int chunk_state_c_box_size =
        std::accumulate(chunk_state_c_box_shape.begin(),
                        chunk_state_c_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());

    // run
    CHECK(predictor->Run());

    // state_h out
    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    std::vector<int> h_out_shape = h_out->shape();
    int h_out_size = std::accumulate(
        h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> h_out_data(h_out_size);
    h_out->CopyToCpu(h_out_data.data());

    // state_c out
    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    std::vector<int> c_out_shape = c_out->shape();
    int c_out_size = std::accumulate(
        c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> c_out_data(c_out_size);
    c_out->CopyToCpu(c_out_data.data());

    // output tensor
    std::unique_ptr<paddle_infer::Tensor> output_tensor =
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    std::vector<float> output_probs;
    int output_size = std::accumulate(
        output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
    output_probs.resize(output_size);
    output_tensor->CopyToCpu(output_probs.data());
    row = output_shape[1];
    col = output_shape[2];

    // probs
    std::vector<std::vector<float>> probs;
    probs.reserve(row);
    for (int i = 0; i < row; i++) {
        probs.push_back(std::vector<float>());
        probs.back().reserve(col);

        for (int j = 0; j < col; j++) {
            probs.back().push_back(output_probs[i * col + j]);
        }
    }

    std::vector<std::vector<float>> log_feat = probs;
    std::cout << "probs, row: " << log_feat.size()
              << " col: " << log_feat[0].size() << std::endl;
    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
             ++col_idx) {
            std::cout << log_feat[row_idx][col_idx] << " ";
        }
        std::cout << std::endl;
    }
}

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    model_forward_test();
    return 0;
}
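The Paddle Inference usage above reduces to: build a Config, create a Predictor, copy inputs in, Run, copy outputs out. The same call pattern in condensed form; the toy 1 x 4 input and single-input model are assumptions for brevity, and tensor names depend on the exported model:

    #include <string>
    #include <vector>
    #include "paddle_inference_api.h"

    // Sketch only: one forward pass through an exported Paddle model.
    void MinimalForward(const std::string& model, const std::string& params) {
        paddle_infer::Config config;
        config.SetModel(model, params);
        auto predictor = paddle_infer::CreatePredictor(config);

        auto in = predictor->GetInputHandle(predictor->GetInputNames()[0]);
        std::vector<float> data = {0.1f, 0.2f, 0.3f, 0.4f};
        in->Reshape({1, 4});          // toy shape; match your model's input
        in->CopyFromCpu(data.data());

        predictor->Run();

        auto out = predictor->GetOutputHandle(predictor->GetOutputNames()[0]);
        std::vector<int> shape = out->shape();  // inspect output dims
    }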
@@ -1,14 +0,0 @@
# This contains the locations of the binaries built for running the examples.

SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project builds successfully"; }

export LC_ALL=C

SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
@@ -1,38 +0,0 @@
#!/bin/bash
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# 2. download model
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
    mkdir -p data/model
    pushd data/model
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

# produce wav scp
if [ ! -f data/wav.scp ]; then
    mkdir -p data
    pushd data
    wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
    echo "utt1 " $PWD/zh.wav > wav.scp
    popd
fi

ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/

ds2-model-ol-test \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdiparams
@@ -1,20 +0,0 @@
#!/bin/bash

# this script is for memory checking, so please run ./run.sh first.

set +x
set -e

. ./path.sh

if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
    echo "please install valgrind in the speechx tools dir."
    exit 1
fi

model_dir=../paddle_asr_model

valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
    pp-model-test \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdparams
@@ -1,4 +1,4 @@
 # Utils
 
 * [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
-* [espnet utils)(https://github.com/espnet/espnet/tree/master/utils)
+* [espnet utils](https://github.com/espnet/espnet/tree/master/utils)