parent
f852514a3e
commit
418cc37ffb
@ -0,0 +1,8 @@
|
||||
# Codelab
|
||||
|
||||
## introduction
|
||||
|
||||
> The following is for development and offline testing. Do not run it unless you know what it is.
|
||||
* nnet
|
||||
* feat
|
||||
* decoder
|
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1,12 @@
|
||||
# ASR Decoder
|
||||
|
||||
ASR decoder test bins. We use these bins to test the CTC beam search decoder and the WFST decoder.
|
||||
|
||||
* decoder_test_main.cc
|
||||
feed nnet output logprob, and only test decoder
|
||||
|
||||
* offline_decoder_sliding_chunk_main.cc
|
||||
feed streaming audio feature, decode as streaming manner.
|
||||
|
||||
* offline_wfst_decoder_main.cc
|
||||
feed streaming audio feature, decode using WFST as streaming manner.
|
@ -0,0 +1,14 @@
|
||||
# This contains the locations of the binaries required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_ROOT/build/speechx/decoder:$SPEECHX_ROOT/build/speechx/frontend/audio
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
@ -0,0 +1,78 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# input
|
||||
mkdir -p data
|
||||
data=$PWD/data
|
||||
ckpt_dir=$data/model
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
vocb_dir=$ckpt_dir/data/lang_char/
|
||||
|
||||
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
|
||||
|
||||
# output
|
||||
exp_dir=./exp
|
||||
mkdir -p $exp_dir
|
||||
|
||||
# 2. download model
|
||||
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
|
||||
mkdir -p data/model
|
||||
pushd data/model
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
# produce wav scp
|
||||
if [ ! -f data/wav.scp ]; then
|
||||
pushd data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||
echo "utt1 " $PWD/zh.wav > wav.scp
|
||||
popd
|
||||
fi
|
||||
|
||||
# download lm
|
||||
if [ ! -f $lm ]; then
|
||||
pushd data
|
||||
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
|
||||
popd
|
||||
fi
|
||||
|
||||
feat_wspecifier=$exp_dir/feats.ark
|
||||
cmvn=$exp_dir/cmvn.ark
|
||||
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
# dump json cmvn to kaldi
|
||||
cmvn_json2kaldi_main \
|
||||
--json_file $ckpt_dir/data/mean_std.json \
|
||||
--cmvn_write_path $cmvn \
|
||||
--binary=false
|
||||
echo "convert json cmvn to kaldi ark."
|
||||
|
||||
|
||||
# generate linear feature as streaming
|
||||
compute_linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$data/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_file=$cmvn
|
||||
echo "compute linear spectrogram feature."
|
||||
|
||||
# run ctc beam search decoder as streaming
|
||||
ctc_prefix_beam_search_decoder_main \
|
||||
--result_wspecifier=ark,t:$exp_dir/result.txt \
|
||||
--feature_rspecifier=ark:$feat_wspecifier \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams \
|
||||
--dict_file=$vocb_dir/vocab.txt \
|
||||
--lm_path=$lm
|
@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
|
||||
# this script is for memory check, so please run ./run.sh first.
|
||||
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
|
||||
echo "please install valgrind in the speechx tools dir.\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
|
||||
offline_decoder_main \
|
||||
--feature_respecifier=ark:$feat_wspecifier \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams \
|
||||
--dict_file=$model_dir/vocab.txt \
|
||||
--lm_path=$model_dir/avg_1.jit.klm
|
||||
|
@ -0,0 +1,7 @@
|
||||
# Deepspeech2 Streaming Audio Feature
|
||||
|
||||
ASR audio feature test bins. We use these bins to test linear/fbank/mfcc ASR features in a streaming manner.
|
||||
|
||||
* compute_linear_spectrogram_main.cc
|
||||
|
||||
compute linear spectrogram without db norm in streaming manner.
|
@ -0,0 +1,14 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_ROOT/build/speechx/decoder:$SPEECHX_ROOT/build/speechx/frontend/audio
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
|
||||
mkdir -p data/model
|
||||
pushd data/model
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
# produce wav scp
|
||||
if [ ! -f data/wav.scp ]; then
|
||||
mkdir -p data
|
||||
pushd data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||
echo "utt1 " $PWD/zh.wav > wav.scp
|
||||
popd
|
||||
fi
|
||||
|
||||
|
||||
# input
|
||||
data_dir=./data
|
||||
exp_dir=./exp
|
||||
model_dir=$data_dir/model/
|
||||
|
||||
mkdir -p $exp_dir
|
||||
|
||||
|
||||
# 3. run feat
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
cmvn_json2kaldi_main \
|
||||
--json_file $model_dir/data/mean_std.json \
|
||||
--cmvn_write_path $exp_dir/cmvn.ark \
|
||||
--binary=false
|
||||
echo "convert json cmvn to kaldi ark."
|
||||
|
||||
|
||||
compute_linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$data_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$exp_dir/feats.ark \
|
||||
--cmvn_file=$exp_dir/cmvn.ark
|
||||
echo "compute linear spectrogram feature."
|
||||
|
||||
|
@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
|
||||
# this script is for memory check, so please run ./run.sh first.
|
||||
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
|
||||
echo "please install valgrind in the speechx tools dir.\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
|
||||
compute_linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$model_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_write_path=$cmvn
|
||||
|
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1,3 @@
|
||||
# Deepspeech2 Streaming NNet Test
|
||||
|
||||
Using for ds2 streaming nnet inference test.
|
@ -0,0 +1,14 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_BUILD/codelab/nnet
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
|
||||
mkdir -p data/model
|
||||
pushd data/model
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
|
||||
ds2_model_test_main \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams
|
||||
|
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
# this script is for memory check, so please run ./run.sh first.
|
||||
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
|
||||
echo "please install valgrind in the speechx tools dir.\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
|
||||
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
|
||||
ds2_model_test_main \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams
|
@ -1,9 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
|
||||
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
|
||||
|
||||
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
|
||||
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
|
@ -1,14 +1,14 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../../..
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/websocket:$SPEECHX_EXAMPLES/ds2_ol/feat
|
||||
SPEECHX_BIN=$SPEECHX_BUILD/websocket
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
||||
|
@ -1,4 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(ds2_ol)
|
||||
add_subdirectory(dev)
|
||||
add_subdirectory(glog)
|
||||
add_subdirectory(nnet)
|
@ -0,0 +1,7 @@
|
||||
|
||||
## For Developer
|
||||
|
||||
> Reminder: Only for developer.
|
||||
|
||||
* codelab - for speechx developer, using for test.
|
||||
|
@ -0,0 +1,8 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
|
||||
target_link_libraries(glog_main glog)
|
||||
|
||||
|
||||
add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
|
||||
target_link_libraries(glog_logtostderr_main glog)
|
@ -0,0 +1,38 @@
|
||||
# [GLOG](https://rpg.ifi.uzh.ch/docs/glog.html)
|
||||
|
||||
Unless otherwise specified, glog writes to the filename `/tmp/<program name>.<hostname>.<user name>.log.<severity level>.<date>.<time>.<pid>` (e.g., "/tmp/hello_world.example.com.hamaji.log.INFO.20080709-222411.10474"). By default, glog copies the log messages of severity level ERROR or FATAL to standard error (stderr) in addition to log files.
|
||||
|
||||
Several flags influence glog's output behavior. If the Google gflags library is installed on your machine, the configure script (see the INSTALL file in the package for detail of this script) will automatically detect and use it, allowing you to pass flags on the command line. For example, if you want to turn the flag --logtostderr on, you can start your application with the following command line:
|
||||
|
||||
`./your_application --logtostderr=1`
|
||||
|
||||
If the Google gflags library isn't installed, you set flags via environment variables, prefixing the flag name with "GLOG_", e.g.
|
||||
|
||||
`GLOG_logtostderr=1 ./your_application`
|
||||
|
||||
You can also modify flag values in your program by modifying global variables `FLAGS_*` . Most settings start working immediately after you update `FLAGS_*` . The exceptions are the flags related to destination files. For example, you might want to set `FLAGS_log_dir` before calling `google::InitGoogleLogging` . Here is an example:
|
||||
|
||||
```c++
|
||||
LOG(INFO) << "file";
|
||||
// Most flags work immediately after updating values.
|
||||
FLAGS_logtostderr = 1;
|
||||
LOG(INFO) << "stderr";
|
||||
FLAGS_logtostderr = 0;
|
||||
// This won't change the log destination. If you want to set this
|
||||
// value, you should do this before google::InitGoogleLogging .
|
||||
FLAGS_log_dir = "/some/log/directory";
|
||||
LOG(INFO) << "the same file";
|
||||
```
|
||||
|
||||
* this is the test script:
|
||||
```
|
||||
# run
|
||||
glog_test
|
||||
|
||||
echo "------"
|
||||
export FLAGS_logtostderr=1
|
||||
glog_test
|
||||
|
||||
echo "------"
|
||||
glog_logtostderr_test
|
||||
```
|
@ -0,0 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
set(bin_name ds2_model_test_main)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
|
@ -0,0 +1,203 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// deepspeech2 online model info
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "paddle_inference_api.h"
|
||||
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
|
||||
DEFINE_string(model_path, "", "xxx.pdmodel");
|
||||
DEFINE_string(param_path, "", "xxx.pdiparams");
|
||||
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
|
||||
DEFINE_int32(feat_dim, 161, "feature dim");
|
||||
|
||||
|
||||
void produce_data(std::vector<std::vector<float>>* data);
|
||||
void model_forward_test();
|
||||
|
||||
void produce_data(std::vector<std::vector<float>>* data) {
|
||||
int chunk_size = FLAGS_chunk_size; // chunk_size in frame
|
||||
int col_size = FLAGS_feat_dim; // feat dim
|
||||
cout << "chunk size: " << chunk_size << endl;
|
||||
cout << "feat dim: " << col_size << endl;
|
||||
|
||||
data->reserve(chunk_size);
|
||||
data->back().reserve(col_size);
|
||||
for (int row = 0; row < chunk_size; ++row) {
|
||||
data->push_back(std::vector<float>());
|
||||
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
|
||||
data->back().push_back(0.201);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void model_forward_test() {
|
||||
std::cout << "1. read the data" << std::endl;
|
||||
std::vector<std::vector<float>> feats;
|
||||
produce_data(&feats);
|
||||
|
||||
std::cout << "2. load the model" << std::endl;
|
||||
;
|
||||
std::string model_graph = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
CHECK(model_graph != "");
|
||||
CHECK(model_params != "");
|
||||
cout << "model path: " << model_graph << endl;
|
||||
cout << "model param path : " << model_params << endl;
|
||||
|
||||
paddle_infer::Config config;
|
||||
config.SetModel(model_graph, model_params);
|
||||
config.SwitchIrOptim(false);
|
||||
cout << "SwitchIrOptim: " << false << endl;
|
||||
config.DisableFCPadding();
|
||||
cout << "DisableFCPadding: " << endl;
|
||||
auto predictor = paddle_infer::CreatePredictor(config);
|
||||
|
||||
std::cout << "3. feat shape, row=" << feats.size()
|
||||
<< ",col=" << feats[0].size() << std::endl;
|
||||
std::vector<float> pp_input_mat;
|
||||
for (const auto& item : feats) {
|
||||
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
|
||||
}
|
||||
|
||||
std::cout << "4. fead the data to model" << std::endl;
|
||||
int row = feats.size();
|
||||
int col = feats[0].size();
|
||||
std::vector<std::string> input_names = predictor->GetInputNames();
|
||||
std::vector<std::string> output_names = predictor->GetOutputNames();
|
||||
for (auto name : input_names) {
|
||||
cout << "model input names: " << name << endl;
|
||||
}
|
||||
for (auto name : output_names) {
|
||||
cout << "model output names: " << name << endl;
|
||||
}
|
||||
|
||||
// input
|
||||
std::unique_ptr<paddle_infer::Tensor> input_tensor =
|
||||
predictor->GetInputHandle(input_names[0]);
|
||||
std::vector<int> INPUT_SHAPE = {1, row, col};
|
||||
input_tensor->Reshape(INPUT_SHAPE);
|
||||
input_tensor->CopyFromCpu(pp_input_mat.data());
|
||||
|
||||
// input length
|
||||
std::unique_ptr<paddle_infer::Tensor> input_len =
|
||||
predictor->GetInputHandle(input_names[1]);
|
||||
std::vector<int> input_len_size = {1};
|
||||
input_len->Reshape(input_len_size);
|
||||
std::vector<int64_t> audio_len;
|
||||
audio_len.push_back(row);
|
||||
input_len->CopyFromCpu(audio_len.data());
|
||||
|
||||
// state_h
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
|
||||
predictor->GetInputHandle(input_names[2]);
|
||||
std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
|
||||
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
|
||||
int chunk_state_h_box_size =
|
||||
std::accumulate(chunk_state_h_box_shape.begin(),
|
||||
chunk_state_h_box_shape.end(),
|
||||
1,
|
||||
std::multiplies<int>());
|
||||
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
|
||||
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
|
||||
|
||||
// state_c
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
|
||||
predictor->GetInputHandle(input_names[3]);
|
||||
std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
|
||||
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
|
||||
int chunk_state_c_box_size =
|
||||
std::accumulate(chunk_state_c_box_shape.begin(),
|
||||
chunk_state_c_box_shape.end(),
|
||||
1,
|
||||
std::multiplies<int>());
|
||||
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
|
||||
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
|
||||
|
||||
// run
|
||||
bool success = predictor->Run();
|
||||
|
||||
// state_h out
|
||||
std::unique_ptr<paddle_infer::Tensor> h_out =
|
||||
predictor->GetOutputHandle(output_names[2]);
|
||||
std::vector<int> h_out_shape = h_out->shape();
|
||||
int h_out_size = std::accumulate(
|
||||
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
|
||||
std::vector<float> h_out_data(h_out_size);
|
||||
h_out->CopyToCpu(h_out_data.data());
|
||||
|
||||
// stage_c out
|
||||
std::unique_ptr<paddle_infer::Tensor> c_out =
|
||||
predictor->GetOutputHandle(output_names[3]);
|
||||
std::vector<int> c_out_shape = c_out->shape();
|
||||
int c_out_size = std::accumulate(
|
||||
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
|
||||
std::vector<float> c_out_data(c_out_size);
|
||||
c_out->CopyToCpu(c_out_data.data());
|
||||
|
||||
// output tensor
|
||||
std::unique_ptr<paddle_infer::Tensor> output_tensor =
|
||||
predictor->GetOutputHandle(output_names[0]);
|
||||
std::vector<int> output_shape = output_tensor->shape();
|
||||
std::vector<float> output_probs;
|
||||
int output_size = std::accumulate(
|
||||
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
|
||||
output_probs.resize(output_size);
|
||||
output_tensor->CopyToCpu(output_probs.data());
|
||||
row = output_shape[1];
|
||||
col = output_shape[2];
|
||||
|
||||
// probs
|
||||
std::vector<std::vector<float>> probs;
|
||||
probs.reserve(row);
|
||||
for (int i = 0; i < row; i++) {
|
||||
probs.push_back(std::vector<float>());
|
||||
probs.back().reserve(col);
|
||||
|
||||
for (int j = 0; j < col; j++) {
|
||||
probs.back().push_back(output_probs[i * col + j]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> log_feat = probs;
|
||||
std::cout << "probs, row: " << log_feat.size()
|
||||
<< " col: " << log_feat[0].size() << std::endl;
|
||||
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
|
||||
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
|
||||
++col_idx) {
|
||||
std::cout << log_feat[row_idx][col_idx] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
model_forward_test();
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,167 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// TODO: refactor, replace with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/paddle_nnet.h"
|
||||
|
||||
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
|
||||
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
|
||||
DEFINE_string(lm_path, "", "language model");
|
||||
DEFINE_int32(receptive_field_length,
|
||||
7,
|
||||
"receptive field of two CNN(kernel=5) downsampling module.");
|
||||
DEFINE_int32(downsampling_rate,
|
||||
4,
|
||||
"two CNN(kernel=5) module downsampling rate.");
|
||||
DEFINE_string(
|
||||
model_input_names,
|
||||
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
|
||||
"model input names");
|
||||
DEFINE_string(model_output_names,
|
||||
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
|
||||
"model output names");
|
||||
DEFINE_string(model_cache_names,
|
||||
"chunk_state_h_box,chunk_state_c_box",
|
||||
"model cache names");
|
||||
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test ds2 online decoder by feeding speech feature
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
CHECK(FLAGS_result_wspecifier != "");
|
||||
CHECK(FLAGS_feature_rspecifier != "");
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(
|
||||
FLAGS_feature_rspecifier);
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
std::string model_path = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
std::string dict_file = FLAGS_dict_file;
|
||||
std::string lm_path = FLAGS_lm_path;
|
||||
LOG(INFO) << "model path: " << model_path;
|
||||
LOG(INFO) << "model param: " << model_params;
|
||||
LOG(INFO) << "dict path: " << dict_file;
|
||||
LOG(INFO) << "lm path: " << lm_path;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
ppspeech::CTCBeamSearchOptions opts;
|
||||
opts.dict_file = dict_file;
|
||||
opts.lm_path = lm_path;
|
||||
ppspeech::CTCBeamSearch decoder(opts);
|
||||
|
||||
ppspeech::ModelOptions model_opts;
|
||||
model_opts.model_path = model_path;
|
||||
model_opts.param_path = model_params;
|
||||
model_opts.cache_names = FLAGS_model_cache_names;
|
||||
model_opts.cache_shape = FLAGS_model_cache_shapes;
|
||||
model_opts.input_names = FLAGS_model_input_names;
|
||||
model_opts.output_names = FLAGS_model_output_names;
|
||||
std::shared_ptr<ppspeech::PaddleNnet> nnet(
|
||||
new ppspeech::PaddleNnet(model_opts));
|
||||
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
|
||||
std::shared_ptr<ppspeech::Decodable> decodable(
|
||||
new ppspeech::Decodable(nnet, raw_data));
|
||||
|
||||
int32 chunk_size = FLAGS_receptive_field_length;
|
||||
int32 chunk_stride = FLAGS_downsampling_rate;
|
||||
int32 receptive_field_length = FLAGS_receptive_field_length;
|
||||
LOG(INFO) << "chunk size (frame): " << chunk_size;
|
||||
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
|
||||
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
|
||||
decoder.InitDecoder();
|
||||
|
||||
kaldi::Timer timer;
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
raw_data->SetDim(feature.NumCols());
|
||||
LOG(INFO) << "process utt: " << utt;
|
||||
LOG(INFO) << "rows: " << feature.NumRows();
|
||||
LOG(INFO) << "cols: " << feature.NumCols();
|
||||
|
||||
int32 row_idx = 0;
|
||||
int32 padding_len = 0;
|
||||
int32 ori_feature_len = feature.NumRows();
|
||||
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
|
||||
padding_len =
|
||||
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
|
||||
feature.Resize(feature.NumRows() + padding_len,
|
||||
feature.NumCols(),
|
||||
kaldi::kCopyData);
|
||||
}
|
||||
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
|
||||
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
|
||||
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
|
||||
feature.NumCols());
|
||||
int32 feature_chunk_size = 0;
|
||||
if (ori_feature_len > chunk_idx * chunk_stride) {
|
||||
feature_chunk_size = std::min(
|
||||
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
|
||||
}
|
||||
if (feature_chunk_size < receptive_field_length) break;
|
||||
|
||||
int32 start = chunk_idx * chunk_stride;
|
||||
|
||||
for (int row_id = 0; row_id < chunk_size; ++row_id) {
|
||||
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
|
||||
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
|
||||
feature_chunk.Data() + row_id * feature.NumCols(),
|
||||
feature.NumCols());
|
||||
f_chunk_tmp.CopyFromVec(tmp);
|
||||
++start;
|
||||
}
|
||||
raw_data->Accept(feature_chunk);
|
||||
if (chunk_idx == num_chunks - 1) {
|
||||
raw_data->SetFinished();
|
||||
}
|
||||
decoder.AdvanceDecode(decodable);
|
||||
}
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
decodable->Reset();
|
||||
decoder.Reset();
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
KALDI_LOG << " the result of " << utt << " is empty";
|
||||
continue;
|
||||
}
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
result_writer.Write(utt, result);
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
double elapsed = timer.Elapsed();
|
||||
KALDI_LOG << " cost:" << elapsed << " s";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -0,0 +1,74 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// TODO: refactor, replace with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
|
||||
DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
|
||||
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
|
||||
DEFINE_string(lm_path, "lm.klm", "language model");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test decoder by feeding nnet posterior probability
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
|
||||
FLAGS_nnet_prob_respecifier);
|
||||
std::string dict_file = FLAGS_dict_file;
|
||||
std::string lm_path = FLAGS_lm_path;
|
||||
LOG(INFO) << "dict path: " << dict_file;
|
||||
LOG(INFO) << "lm path: " << lm_path;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
ppspeech::CTCBeamSearchOptions opts;
|
||||
opts.dict_file = dict_file;
|
||||
opts.lm_path = lm_path;
|
||||
ppspeech::CTCBeamSearch decoder(opts);
|
||||
|
||||
std::shared_ptr<ppspeech::Decodable> decodable(
|
||||
new ppspeech::Decodable(nullptr, nullptr));
|
||||
|
||||
decoder.InitDecoder();
|
||||
|
||||
for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
|
||||
string utt = likelihood_reader.Key();
|
||||
const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
|
||||
LOG(INFO) << "process utt: " << utt;
|
||||
LOG(INFO) << "rows: " << likelihood.NumRows();
|
||||
LOG(INFO) << "cols: " << likelihood.NumCols();
|
||||
decodable->Acceptlikelihood(likelihood);
|
||||
decoder.AdvanceDecode(decodable);
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
decodable->Reset();
|
||||
decoder.Reset();
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decoder/recognizer.h"
|
||||
#include "decoder/param.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
|
||||
ppspeech::Recognizer recognizer(resource);
|
||||
|
||||
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||
FLAGS_wav_rspecifier);
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
|
||||
int sample_rate = FLAGS_sample_rate;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
LOG(INFO) << "sr: " << sample_rate;
|
||||
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
double tot_wav_duration = 0.0;
|
||||
|
||||
kaldi::Timer timer;
|
||||
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
std::string utt = wav_reader.Key();
|
||||
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
tot_wav_duration += tot_samples * 1.0 / sample_rate;
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
std::vector<kaldi::Vector<BaseFloat>> feats;
|
||||
int feature_rows = 0;
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk(i) = waveform(sample_offset + i);
|
||||
}
|
||||
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
|
||||
|
||||
recognizer.Accept(wav_chunk);
|
||||
if (cur_chunk_size < chunk_sample_size) {
|
||||
recognizer.SetFinished();
|
||||
}
|
||||
recognizer.Decode();
|
||||
|
||||
// no overlap
|
||||
sample_offset += cur_chunk_size;
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result = recognizer.GetFinalResult();
|
||||
recognizer.Reset();
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
KALDI_LOG << " the result of " << utt << " is empty";
|
||||
continue;
|
||||
}
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
result_writer.Write(utt, result);
|
||||
++num_done;
|
||||
}
|
||||
double elapsed = timer.Elapsed();
|
||||
KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done);
|
||||
KALDI_LOG << " cost:" << elapsed << " s";
|
||||
KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s";
|
||||
KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration;
|
||||
}
|
@ -0,0 +1,169 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_tlg_decoder.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/paddle_nnet.h"
|
||||
|
||||
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
|
||||
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
|
||||
DEFINE_string(graph_path, "TLG", "decoder graph");
|
||||
|
||||
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
|
||||
DEFINE_int32(max_active, 7500, "decoder graph");
|
||||
DEFINE_int32(receptive_field_length,
|
||||
7,
|
||||
"receptive field of two CNN(kernel=5) downsampling module.");
|
||||
DEFINE_int32(downsampling_rate,
|
||||
4,
|
||||
"two CNN(kernel=5) module downsampling rate.");
|
||||
DEFINE_string(
|
||||
model_input_names,
|
||||
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
|
||||
"model input names");
|
||||
DEFINE_string(model_output_names,
|
||||
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
|
||||
"model output names");
|
||||
DEFINE_string(model_cache_names,
|
||||
"chunk_state_h_box,chunk_state_c_box",
|
||||
"model cache names");
|
||||
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test TLG decoder by feeding speech feature.
|
||||
// Tests the WFST (TLG) decoder by feeding precomputed speech features in a
// streaming, sliding-chunk manner through the paddle nnet.
// Exit status: 0 if at least one utterance produced a result, 1 otherwise.
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_rspecifier);
    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    std::string word_symbol_table = FLAGS_word_symbol_table;
    std::string graph_path = FLAGS_graph_path;
    LOG(INFO) << "model path: " << model_graph;
    LOG(INFO) << "model param: " << model_params;
    LOG(INFO) << "word symbol path: " << word_symbol_table;
    LOG(INFO) << "graph path: " << graph_path;

    int32 num_done = 0, num_err = 0;

    // TLG decoder setup; beam/lattice_beam are fixed here (test binary).
    ppspeech::TLGDecoderOptions opts;
    opts.word_symbol_table = word_symbol_table;
    opts.fst_path = graph_path;
    opts.opts.max_active = FLAGS_max_active;
    opts.opts.beam = 15.0;
    opts.opts.lattice_beam = 7.5;
    ppspeech::TLGDecoder decoder(opts);

    // Acoustic model setup: names/shapes must match the exported jit model.
    ppspeech::ModelOptions model_opts;
    model_opts.model_path = model_graph;
    model_opts.param_path = model_params;
    model_opts.cache_names = FLAGS_model_cache_names;
    model_opts.cache_shape = FLAGS_model_cache_shapes;
    model_opts.input_names = FLAGS_model_input_names;
    model_opts.output_names = FLAGS_model_output_names;
    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    // DataCache feeds already-extracted features straight to the nnet.
    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));

    // chunk_size/stride are in frames, matched to the model's CNN
    // downsampling (receptive field 7, stride 4 by default).
    int32 chunk_size = FLAGS_receptive_field_length;
    int32 chunk_stride = FLAGS_downsampling_rate;
    int32 receptive_field_length = FLAGS_receptive_field_length;
    LOG(INFO) << "chunk size (frame): " << chunk_size;
    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
    decoder.InitDecoder();
    kaldi::Timer timer;
    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
        kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
        raw_data->SetDim(feature.NumCols());
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << feature.NumRows();
        LOG(INFO) << "cols: " << feature.NumCols();

        int32 row_idx = 0;
        int32 padding_len = 0;
        int32 ori_feature_len = feature.NumRows();
        // Pad the feature matrix so (rows - chunk_size) is a multiple of
        // chunk_stride; padded rows are zero (Resize with kCopyData).
        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
            padding_len =
                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
            feature.Resize(feature.NumRows() + padding_len,
                           feature.NumCols(),
                           kaldi::kCopyData);
        }
        int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
            // Flattened row-major buffer holding one chunk of frames.
            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
                                                          feature.NumCols());
            // Number of real (non-padded) frames available in this chunk.
            int32 feature_chunk_size = 0;
            if (ori_feature_len > chunk_idx * chunk_stride) {
                feature_chunk_size = std::min(
                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
            }
            // A chunk shorter than the receptive field cannot produce
            // any model output; stop here.
            if (feature_chunk_size < receptive_field_length) break;

            // Copy chunk_size consecutive rows into the flat buffer.
            int32 start = chunk_idx * chunk_stride;
            for (int row_id = 0; row_id < chunk_size; ++row_id) {
                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
                    feature_chunk.Data() + row_id * feature.NumCols(),
                    feature.NumCols());
                f_chunk_tmp.CopyFromVec(tmp);
                ++start;
            }
            raw_data->Accept(feature_chunk);
            if (chunk_idx == num_chunks - 1) {
                raw_data->SetFinished();
            }
            decoder.AdvanceDecode(decodable);
        }
        std::string result;
        result = decoder.GetFinalBestPath();
        // Clear per-utterance state before the next feature matrix.
        decodable->Reset();
        decoder.Reset();
        if (result.empty()) {
            // the TokenWriter can not write empty string.
            ++num_err;
            KALDI_LOG << " the result of " << utt << " is empty";
            continue;
        }
        KALDI_LOG << " the result of " << utt << " is " << result;
        result_writer.Write(utt, result);
        ++num_done;
    }

    double elapsed = timer.Elapsed();
    KALDI_LOG << " cost:" << elapsed << " s";

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
|
@ -0,0 +1,85 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Note: Do not print/log ondemand object.
|
||||
|
||||
#include "base/common.h"
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "utils/file_utils.h"
|
||||
// #include "boost/json.hpp"
|
||||
#include <boost/json/src.hpp>
|
||||
|
||||
DEFINE_string(json_file, "", "cmvn json file");
|
||||
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
|
||||
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
|
||||
|
||||
using namespace boost::json; // from <boost/json.hpp>
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
|
||||
|
||||
auto ifs = std::ifstream(FLAGS_json_file);
|
||||
std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
|
||||
auto value = boost::json::parse(json_str);
|
||||
if (!value.is_object()) {
|
||||
LOG(ERROR) << "Input json file format error.";
|
||||
}
|
||||
|
||||
for (auto obj : value.as_object()) {
|
||||
if (obj.key() == "mean_stat") {
|
||||
LOG(INFO) << "mean_stat:" << obj.value();
|
||||
}
|
||||
if (obj.key() == "var_stat") {
|
||||
LOG(INFO) << "var_stat: " << obj.value();
|
||||
}
|
||||
if (obj.key() == "frame_num") {
|
||||
LOG(INFO) << "frame_num: " << obj.value();
|
||||
}
|
||||
}
|
||||
|
||||
boost::json::array mean_stat = value.at("mean_stat").as_array();
|
||||
std::vector<kaldi::BaseFloat> mean_stat_vec;
|
||||
for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
|
||||
mean_stat_vec.push_back(it->as_double());
|
||||
}
|
||||
|
||||
boost::json::array var_stat = value.at("var_stat").as_array();
|
||||
std::vector<kaldi::BaseFloat> var_stat_vec;
|
||||
for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
|
||||
var_stat_vec.push_back(it->as_double());
|
||||
}
|
||||
|
||||
kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64());
|
||||
LOG(INFO) << "nframe: " << frame_num;
|
||||
|
||||
size_t mean_size = mean_stat_vec.size();
|
||||
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
|
||||
for (size_t idx = 0; idx < mean_size; ++idx) {
|
||||
cmvn_stats(0, idx) = mean_stat_vec[idx];
|
||||
cmvn_stats(1, idx) = var_stat_vec[idx];
|
||||
}
|
||||
cmvn_stats(0, mean_size) = frame_num;
|
||||
LOG(INFO) << cmvn_stats;
|
||||
|
||||
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
|
||||
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
|
||||
LOG(INFO) << "Binary: " << FLAGS_binary;
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,143 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
|
||||
#include "frontend/audio/audio_cache.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "frontend/audio/fbank.h"
|
||||
#include "frontend/audio/feature_cache.h"
|
||||
#include "frontend/audio/frontend_itf.h"
|
||||
#include "frontend/audio/normalizer.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test wav scp path");
|
||||
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
|
||||
DEFINE_string(cmvn_file, "", "read cmvn");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
DEFINE_int32(num_bins, 161, "fbank num bins");
|
||||
|
||||
// Tests the streaming fbank frontend: reads wavs, pushes them through the
// AudioCache -> Fbank -> CMVN -> FeatureCache pipeline chunk by chunk,
// then writes the assembled per-utterance feature matrix.
// Exit status: 0 if at least one utterance was processed, 1 otherwise.
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);

    int32 num_done = 0, num_err = 0;

    // feature pipeline: wave cache --> povey window
    // -->fbank --> global cmvn -> feat cache

    std::unique_ptr<ppspeech::FrontendInterface> data_source(
        new ppspeech::AudioCache(3600 * 1600, false));

    // 25ms frame / 10ms shift; dither disabled for reproducible output.
    ppspeech::FbankOptions opt;
    opt.fbank_opts.frame_opts.frame_length_ms = 25;
    opt.fbank_opts.frame_opts.frame_shift_ms = 10;
    opt.streaming_chunk = FLAGS_streaming_chunk;
    opt.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
    opt.fbank_opts.frame_opts.dither = 0.0;

    std::unique_ptr<ppspeech::FrontendInterface> fbank(
        new ppspeech::Fbank(opt, std::move(data_source)));

    std::unique_ptr<ppspeech::FrontendInterface> cmvn(
        new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));

    ppspeech::FeatureCacheOptions feat_cache_opts;
    // the feature cache output feature chunk by chunk.
    // frame_chunk_size : num frame of a chunk.
    // frame_chunk_stride: chunk sliding window stride.
    feat_cache_opts.frame_chunk_stride = 1;
    feat_cache_opts.frame_chunk_size = 1;
    ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
    LOG(INFO) << "fbank: " << true;
    LOG(INFO) << "feat dim: " << feature_cache.Dim();

    // NOTE(review): sample rate is hard-coded to 16 kHz here (no flag).
    int sample_rate = 16000;
    float streaming_chunk = FLAGS_streaming_chunk;
    int chunk_sample_size = streaming_chunk * sample_rate;
    LOG(INFO) << "sr: " << sample_rate;
    LOG(INFO) << "chunk size (s): " << streaming_chunk;
    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;

    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();
        LOG(INFO) << "process utt: " << utt;

        // Only the first channel is used.
        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        // Collected flattened feature chunks and total row count for the
        // final matrix assembly below.
        std::vector<kaldi::Vector<BaseFloat>> feats;
        int feature_rows = 0;
        while (sample_offset < tot_samples) {
            // Last chunk may be shorter than chunk_sample_size.
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
            }

            kaldi::Vector<BaseFloat> features;
            feature_cache.Accept(wav_chunk);
            // A short chunk marks end of utterance: flush the pipeline.
            if (cur_chunk_size < chunk_sample_size) {
                feature_cache.SetFinished();
            }
            // Drain all features currently available from the cache.
            bool flag = true;
            do {
                flag = feature_cache.Read(&features);
                feats.push_back(features);
                feature_rows += features.Dim() / feature_cache.Dim();
            } while (flag == true && features.Dim() != 0);
            sample_offset += cur_chunk_size;
        }

        // Reassemble the flattened chunks into a (rows x dim) matrix.
        int cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
                                                 feature_cache.Dim());
        for (auto feat : feats) {
            int num_rows = feat.Dim() / feature_cache.Dim();
            for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
                for (size_t col_idx = 0; col_idx < feature_cache.Dim();
                     ++col_idx) {
                    features(cur_idx, col_idx) =
                        feat(row_idx * feature_cache.Dim() + col_idx);
                }
                ++cur_idx;
            }
        }
        feat_writer.Write(utt, features);
        feature_cache.Reset();

        if (num_done % 50 == 0 && num_done != 0)
            KALDI_VLOG(2) << "Processed " << num_done << " utterances";
        num_done++;
    }
    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
|
@ -0,0 +1,145 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
|
||||
#include "frontend/audio/audio_cache.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "frontend/audio/feature_cache.h"
|
||||
#include "frontend/audio/frontend_itf.h"
|
||||
#include "frontend/audio/linear_spectrogram.h"
|
||||
#include "frontend/audio/normalizer.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test wav scp path");
|
||||
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
|
||||
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
|
||||
// Tests the streaming linear-spectrogram frontend: reads wavs, pushes them
// through AudioCache -> LinearSpectrogram -> CMVN -> FeatureCache chunk by
// chunk, then writes the assembled per-utterance feature matrix.
// Exit status: 0 if at least one utterance was processed, 1 otherwise.
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);

    int32 num_done = 0, num_err = 0;

    // feature pipeline: wave cache --> hanning window
    // -->linear_spectrogram --> global cmvn -> feat cache

    std::unique_ptr<ppspeech::FrontendInterface> data_source(
        new ppspeech::AudioCache(3600 * 1600, true));

    // 20ms frame / 10ms shift, hanning window; dither, DC removal and
    // preemphasis disabled for reproducible output.
    ppspeech::LinearSpectrogramOptions opt;
    opt.frame_opts.frame_length_ms = 20;
    opt.frame_opts.frame_shift_ms = 10;
    opt.streaming_chunk = FLAGS_streaming_chunk;
    opt.frame_opts.dither = 0.0;
    opt.frame_opts.remove_dc_offset = false;
    opt.frame_opts.window_type = "hanning";
    opt.frame_opts.preemph_coeff = 0.0;
    LOG(INFO) << "linear feature: " << true;
    LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
    LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;

    std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
        new ppspeech::LinearSpectrogram(opt, std::move(data_source)));

    std::unique_ptr<ppspeech::FrontendInterface> cmvn(
        new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));

    ppspeech::FeatureCacheOptions feat_cache_opts;
    // the feature cache output feature chunk by chunk.
    // frame_chunk_size : num frame of a chunk.
    // frame_chunk_stride: chunk sliding window stride.
    feat_cache_opts.frame_chunk_stride = 1;
    feat_cache_opts.frame_chunk_size = 1;
    ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
    LOG(INFO) << "feat dim: " << feature_cache.Dim();

    // NOTE(review): sample rate is hard-coded to 16 kHz here (no flag).
    int sample_rate = 16000;
    float streaming_chunk = FLAGS_streaming_chunk;
    int chunk_sample_size = streaming_chunk * sample_rate;
    LOG(INFO) << "sample rate: " << sample_rate;
    LOG(INFO) << "chunk size (s): " << streaming_chunk;
    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;


    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();
        LOG(INFO) << "process utt: " << utt;

        // Only the first channel is used.
        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        // Collected flattened feature chunks and total row count for the
        // final matrix assembly below.
        std::vector<kaldi::Vector<BaseFloat>> feats;
        int feature_rows = 0;
        while (sample_offset < tot_samples) {
            // Last chunk may be shorter than chunk_sample_size.
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
            }

            kaldi::Vector<BaseFloat> features;
            feature_cache.Accept(wav_chunk);
            // A short chunk marks end of utterance: flush the pipeline.
            if (cur_chunk_size < chunk_sample_size) {
                feature_cache.SetFinished();
            }
            // Drain all features currently available from the cache.
            bool flag = true;
            do {
                flag = feature_cache.Read(&features);
                feats.push_back(features);
                feature_rows += features.Dim() / feature_cache.Dim();
            } while (flag == true && features.Dim() != 0);
            sample_offset += cur_chunk_size;
        }

        // Reassemble the flattened chunks into a (rows x dim) matrix.
        int cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
                                                 feature_cache.Dim());
        for (auto feat : feats) {
            int num_rows = feat.Dim() / feature_cache.Dim();
            for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
                for (size_t col_idx = 0; col_idx < feature_cache.Dim();
                     ++col_idx) {
                    features(cur_idx, col_idx) =
                        feat(row_idx * feature_cache.Dim() + col_idx);
                }
                ++cur_idx;
            }
        }
        feat_writer.Write(utt, features);
        feature_cache.Reset();

        if (num_done % 50 == 0 && num_done != 0)
            KALDI_VLOG(2) << "Processed " << num_done << " utterances";
        num_done++;
    }
    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
|
Loading…
Reference in new issue