commit
e6e5d86a5a
@ -1 +0,0 @@
|
||||
tmp
|
@ -1,19 +0,0 @@
|
||||
# 1xt2x
|
||||
|
||||
Convert Deepspeech 1.8 released model to 2.x.
|
||||
|
||||
## Model source directory
|
||||
* Deepspeech2x
|
||||
|
||||
## Experiment directory
|
||||
* aishell
|
||||
* librispeech
|
||||
* baidu_en8k
|
||||
|
||||
# The released model
|
||||
|
||||
Acoustic Model | Training Data | Hours of Speech | Token-based | CER | WER
|
||||
:-------------:| :------------:| :---------------: | :---------: | :---: | :----:
|
||||
Ds2 Offline Aishell 1xt2x model| Aishell Dataset | 151 h | Char-based | 0.080447 |
|
||||
Ds2 Offline Librispeech 1xt2x model | Librispeech Dataset | 960 h | Word-based | | 0.068548
|
||||
Ds2 Offline Baidu en8k 1xt2x model | Baidu Internal English Dataset | 8628 h | Word-based | | 0.054112
|
@ -1,5 +0,0 @@
|
||||
exp
|
||||
data
|
||||
*log
|
||||
tmp
|
||||
nohup*
|
@ -1 +0,0 @@
|
||||
[]
|
@ -1,65 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
min_input_len: 0.0
|
||||
max_input_len: 27.0 # second
|
||||
min_output_len: 0.0
|
||||
max_output_len: .inf
|
||||
min_output_input_ratio: 0.00
|
||||
max_output_input_ratio: .inf
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
batch_size: 64 # one gpu
|
||||
mean_std_filepath: data/mean_std.npz
|
||||
unit_type: char
|
||||
vocab_filepath: data/vocab.txt
|
||||
augmentation_config: conf/augmentation.json
|
||||
random_seed: 0
|
||||
spm_model_prefix:
|
||||
spectrum_type: linear
|
||||
feat_dim:
|
||||
delta_delta: False
|
||||
stride_ms: 10.0
|
||||
window_ms: 20.0
|
||||
n_fft: None
|
||||
max_freq: None
|
||||
target_sample_rate: 16000
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
dither: 1.0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
num_conv_layers: 2
|
||||
num_rnn_layers: 3
|
||||
rnn_layer_size: 1024
|
||||
use_gru: True
|
||||
share_rnn_weights: False
|
||||
blank_id: 4333
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 80
|
||||
accum_grad: 1
|
||||
lr: 2e-3
|
||||
lr_decay: 0.83
|
||||
weight_decay: 1e-06
|
||||
global_grad_clip: 3.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
||||
|
@ -1,10 +0,0 @@
|
||||
decode_batch_size: 32
|
||||
error_rate_type: cer
|
||||
decoding_method: ctc_beam_search
|
||||
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
|
||||
alpha: 2.6
|
||||
beta: 5.0
|
||||
beam_size: 300
|
||||
cutoff_prob: 0.99
|
||||
cutoff_top_n: 40
|
||||
num_proc_bsearch: 8
|
@ -1,69 +0,0 @@
|
||||
#!/bin/bash
# Prepare the data needed to evaluate the converted Aishell model:
#   * download and unpack the converted checkpoint into ckpt_dir, then move
#     its mean_std.npz and vocab.txt into data/;
#   * stage -1: run the Aishell dataset script and keep its manifests as *.raw;
#   * stage  2: run format_data.py on every raw manifest.
#
# Usage: data.sh ckpt_dir
if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

stage=-1
stop_stage=100

source "${MAIN_ROOT}/utils/parse_options.sh"

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p "${TARGET_DIR}"

bash local/download_model.sh "${ckpt_dir}"
if [ $? -ne 0 ]; then
   exit 1
fi

# Extract inside a subshell so that a bad ckpt_dir can neither leave this
# script in the wrong directory nor let tar unpack into the current one.
(cd "${ckpt_dir}" && tar xzvf aishell_model_v1.8_to_v2.x.tar.gz)
if [ $? -ne 0 ]; then
    echo "Extract Aishell model failed. Terminated."
    exit 1
fi
mv "${ckpt_dir}/mean_std.npz" data/
mv "${ckpt_dir}/vocab.txt" data/


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # download data, generate manifests
    python3 "${TARGET_DIR}/aishell/aishell.py" \
    --manifest_prefix="data/manifest" \
    --target_dir="${TARGET_DIR}/aishell"

    if [ $? -ne 0 ]; then
        echo "Prepare Aishell failed. Terminated."
        exit 1
    fi

    for dataset in train dev test; do
        mv "data/manifest.${dataset}" "data/manifest.${dataset}.raw"
    done
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    pids=()
    for dataset in train dev test; do
    {
        python3 "${MAIN_ROOT}/utils/format_data.py" \
        --cmvn_path "data/mean_std.npz" \
        --unit_type "char" \
        --vocab_path="data/vocab.txt" \
        --manifest_path="data/manifest.${dataset}.raw" \
        --output_path="data/manifest.${dataset}"

        if [ $? -ne 0 ]; then
            echo "Formt mnaifest failed. Terminated."
            # Only leaves this background job; the status is collected below.
            exit 1
        fi
    } &
    pids+=($!)
    done

    # A bare `wait` always succeeds, so wait on each job to see its status.
    failed=0
    for pid in "${pids[@]}"; do
        wait "${pid}" || failed=1
    done
    if [ ${failed} -ne 0 ]; then
        exit 1
    fi
fi

echo "Aishell data preparation done."
exit 0
|
@ -1,23 +0,0 @@
|
||||
#!/bin/bash
# Fetch the zh_giga language model into data/lm using the `download` helper
# that utility.sh provides (called with URL, MD5 and target path).

. ${MAIN_ROOT}/utils/utility.sh

DIR=data/lm
mkdir -p ${DIR}

URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm


echo "Start downloading the language model. The language model is large, please wait for a moment ..."
# The helper's own output is discarded; only its exit status is used.
if ! download $URL $MD5 $TARGET > /dev/null 2>&1; then
    echo "Fail to download the language model!"
    exit 1
fi
echo "Download the language model sucessfully"


exit 0
|
@ -1,25 +0,0 @@
|
||||
#! /usr/bin/env bash
# Download the converted Aishell checkpoint archive into ckpt_dir using the
# `download` helper from utility.sh (called with URL, MD5 and target path).
#
# Usage: download_model.sh ckpt_dir

if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

. ${MAIN_ROOT}/utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
MD5=87e7577d4bea737dbf3e8daab37aa808
TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz


echo "Download Aishell model ..."
if ! download $URL $MD5 $TARGET; then
    echo "Fail to download Aishell model!"
    exit 1
fi


exit 0
|
@ -1,36 +0,0 @@
|
||||
#!/bin/bash
# Run ${BIN_DIR}/test.py on a checkpoint after local/download_lm_ch.sh has
# fetched the language model.
#
# Usage: test.sh config_path decode_config_path ckpt_path_prefix model_type

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
    exit -1
fi

# Count the comma-separated entries of CUDA_VISIBLE_DEVICES.
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4

# download language model
bash local/download_lm_ch.sh || exit 1

# The result file is written next to the checkpoint prefix as <prefix>.rsl.
if ! python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
|
@ -1,17 +0,0 @@
|
||||
# Environment setup for this recipe. Both roots are resolved relative to
# ${PWD}, so the file has to be sourced from the recipe directory.
export MAIN_ROOT=$(realpath ${PWD}/../../../../)
export LOCAL_DEEPSPEECH2=$(realpath ${PWD}/../)

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# LOCAL_DEEPSPEECH2 comes first, ahead of MAIN_ROOT and any existing entries.
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
|
@ -1,29 +0,0 @@
|
||||
#!/bin/bash
# Recipe entry point for the converted Aishell model:
#   stage 0 - local/data.sh (checkpoint download and data preparation)
#   stage 1 - local/test.sh (evaluation of the converted checkpoint)
set -e
source path.sh

stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=2

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

v18_ckpt=aishell_v1.8
# Experiment name = config file name up to its first dot.
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

ckpt_dir=exp/${ckpt}/checkpoints

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    mkdir -p ${ckpt_dir}
    bash ./local/data.sh ${ckpt_dir} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} ${ckpt_dir}/${v18_ckpt} ${model_type}|| exit -1
fi
|
||||
|
@ -1,5 +0,0 @@
|
||||
exp
|
||||
data
|
||||
*log
|
||||
tmp
|
||||
nohup*
|
@ -1 +0,0 @@
|
||||
[]
|
@ -1,64 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test-clean
|
||||
min_input_len: 0.0
|
||||
max_input_len: .inf # second
|
||||
min_output_len: 0.0
|
||||
max_output_len: .inf
|
||||
min_output_input_ratio: 0.00
|
||||
max_output_input_ratio: .inf
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
batch_size: 64 # one gpu
|
||||
mean_std_filepath: data/mean_std.npz
|
||||
unit_type: char
|
||||
vocab_filepath: data/vocab.txt
|
||||
augmentation_config: conf/augmentation.json
|
||||
random_seed: 0
|
||||
spm_model_prefix:
|
||||
spectrum_type: linear
|
||||
feat_dim:
|
||||
delta_delta: False
|
||||
stride_ms: 10.0
|
||||
window_ms: 20.0
|
||||
n_fft: None
|
||||
max_freq: None
|
||||
target_sample_rate: 16000
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
dither: 1.0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
num_conv_layers: 2
|
||||
num_rnn_layers: 3
|
||||
rnn_layer_size: 1024
|
||||
use_gru: True
|
||||
share_rnn_weights: False
|
||||
blank_id: 28
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 80
|
||||
accum_grad: 1
|
||||
lr: 2e-3
|
||||
lr_decay: 0.83
|
||||
weight_decay: 1e-06
|
||||
global_grad_clip: 3.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
@ -1,10 +0,0 @@
|
||||
decode_batch_size: 32
|
||||
error_rate_type: wer
|
||||
decoding_method: ctc_beam_search
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 1.4
|
||||
beta: 0.35
|
||||
beam_size: 500
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 40
|
||||
num_proc_bsearch: 8
|
@ -1,85 +0,0 @@
|
||||
#!/bin/bash
# Prepare the data needed to evaluate the converted Baidu en8k model:
#   * download and unpack the converted checkpoint into ckpt_dir, then move
#     its mean_std.npz and vocab.txt into data/;
#   * stage -1: run the LibriSpeech dataset script, keep its manifests as
#     *.raw and concatenate them into train/dev/test raw manifests;
#   * stage  2: run format_data.py on every raw manifest.
#
# Usage: data.sh ckpt_dir
if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

stage=-1
stop_stage=100
unit_type=char

source "${MAIN_ROOT}/utils/parse_options.sh"

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p "${TARGET_DIR}"


bash local/download_model.sh "${ckpt_dir}"
if [ $? -ne 0 ]; then
   exit 1
fi

# Extract inside a subshell so that a bad ckpt_dir can neither leave this
# script in the wrong directory nor let tar unpack into the current one.
(cd "${ckpt_dir}" && tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz)
if [ $? -ne 0 ]; then
    echo "Extract BaiduEn8k model failed. Terminated."
    exit 1
fi
mv "${ckpt_dir}/mean_std.npz" data/
mv "${ckpt_dir}/vocab.txt" data/


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # download data, generate manifests
    python3 "${TARGET_DIR}/librispeech/librispeech.py" \
    --manifest_prefix="data/manifest" \
    --target_dir="${TARGET_DIR}/librispeech" \
    --full_download="True"

    if [ $? -ne 0 ]; then
        echo "Prepare LibriSpeech failed. Terminated."
        exit 1
    fi

    for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
        mv "data/manifest.${set}" "data/manifest.${set}.raw"
    done

    # Start from empty merged manifests because the loops below append.
    rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
    for set in train-clean-100 train-clean-360 train-other-500; do
        cat "data/manifest.${set}.raw" >> data/manifest.train.raw
    done

    for set in dev-clean dev-other; do
        cat "data/manifest.${set}.raw" >> data/manifest.dev.raw
    done

    for set in test-clean test-other; do
        cat "data/manifest.${set}.raw" >> data/manifest.test.raw
    done
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    pids=()
    for set in train dev test dev-clean dev-other test-clean test-other; do
    {
        python3 "${MAIN_ROOT}/utils/format_data.py" \
        --cmvn_path "data/mean_std.npz" \
        --unit_type "${unit_type}" \
        --vocab_path="data/vocab.txt" \
        --manifest_path="data/manifest.${set}.raw" \
        --output_path="data/manifest.${set}"

        if [ $? -ne 0 ]; then
            echo "Formt mnaifest.${set} failed. Terminated."
            # Only leaves this background job; the status is collected below.
            exit 1
        fi
    }&
    pids+=($!)
    done

    # A bare `wait` always succeeds, so wait on each job to see its status.
    failed=0
    for pid in "${pids[@]}"; do
        wait "${pid}" || failed=1
    done
    if [ ${failed} -ne 0 ]; then
        exit 1
    fi
fi

echo "LibriSpeech Data preparation done."
exit 0
|
||||
|
@ -1,22 +0,0 @@
|
||||
#!/bin/bash
# Fetch the common_crawl language model into data/lm using the `download`
# helper that utility.sh provides (called with URL, MD5 and target path).

. ${MAIN_ROOT}/utils/utility.sh

DIR=data/lm
mkdir -p ${DIR}

URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm

echo "Start downloading the language model. The language model is large, please wait for a moment ..."
# The helper's own output is discarded; only its exit status is used.
if ! download $URL $MD5 $TARGET > /dev/null 2>&1; then
    echo "Fail to download the language model!"
    exit 1
fi
echo "Download the language model sucessfully"


exit 0
|
@ -1,25 +0,0 @@
|
||||
#! /usr/bin/env bash
# Download the converted Baidu en8k checkpoint archive into ckpt_dir using the
# `download` helper from utility.sh (called with URL, MD5 and target path).
#
# Usage: download_model.sh ckpt_dir
if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1


. ${MAIN_ROOT}/utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
MD5=c1676be8505cee436e6f312823e9008c
TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz


echo "Download BaiduEn8k model ..."
if ! download $URL $MD5 $TARGET; then
    echo "Fail to download BaiduEn8k model!"
    exit 1
fi


exit 0
|
@ -1,36 +0,0 @@
|
||||
#!/bin/bash
# Run ${BIN_DIR}/test.py on a checkpoint after local/download_lm_en.sh has
# fetched the language model.
#
# Usage: test.sh config_path decode_config_path ckpt_path_prefix model_type

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
    exit -1
fi

# Count the comma-separated entries of CUDA_VISIBLE_DEVICES.
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4

# download language model
bash local/download_lm_en.sh || exit 1

# The result file is written next to the checkpoint prefix as <prefix>.rsl.
if ! python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
|
@ -1,17 +0,0 @@
|
||||
# Environment setup for this recipe. Both roots are resolved relative to
# ${PWD}, so the file has to be sourced from the recipe directory.
export MAIN_ROOT=$(realpath ${PWD}/../../../../)
export LOCAL_DEEPSPEECH2=$(realpath ${PWD}/../)

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# LOCAL_DEEPSPEECH2 comes first, ahead of MAIN_ROOT and any existing entries.
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
|
@ -1,29 +0,0 @@
|
||||
#!/bin/bash
# Recipe entry point for the converted Baidu en8k model:
#   stage 0 - local/data.sh (checkpoint download and data preparation)
#   stage 1 - local/test.sh (evaluation of the converted checkpoint)
set -e
source path.sh

stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=0

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

v18_ckpt=baidu_en8k_v1.8
# Experiment name = config file name up to its first dot.
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

ckpt_dir=exp/${ckpt}/checkpoints

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    mkdir -p ${ckpt_dir}
    bash ./local/data.sh ${ckpt_dir} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} ${ckpt_dir}/${v18_ckpt} ${model_type}|| exit -1
fi
|
||||
|
@ -1,5 +0,0 @@
|
||||
exp
|
||||
data
|
||||
*log
|
||||
tmp
|
||||
nohup*
|
@ -1 +0,0 @@
|
||||
[]
|
@ -1,64 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test-clean
|
||||
min_input_len: 0.0
|
||||
max_input_len: 1000.0 # second
|
||||
min_output_len: 0.0
|
||||
max_output_len: .inf
|
||||
min_output_input_ratio: 0.00
|
||||
max_output_input_ratio: .inf
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
batch_size: 64 # one gpu
|
||||
mean_std_filepath: data/mean_std.npz
|
||||
unit_type: char
|
||||
vocab_filepath: data/vocab.txt
|
||||
augmentation_config: conf/augmentation.json
|
||||
random_seed: 0
|
||||
spm_model_prefix:
|
||||
spectrum_type: linear
|
||||
feat_dim:
|
||||
delta_delta: False
|
||||
stride_ms: 10.0
|
||||
window_ms: 20.0
|
||||
n_fft: None
|
||||
max_freq: None
|
||||
target_sample_rate: 16000
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
dither: 1.0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
num_conv_layers: 2
|
||||
num_rnn_layers: 3
|
||||
rnn_layer_size: 2048
|
||||
use_gru: False
|
||||
share_rnn_weights: True
|
||||
blank_id: 28
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 80
|
||||
accum_grad: 1
|
||||
lr: 2e-3
|
||||
lr_decay: 0.83
|
||||
weight_decay: 1e-06
|
||||
global_grad_clip: 3.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
@ -1,10 +0,0 @@
|
||||
decode_batch_size: 32
|
||||
error_rate_type: wer
|
||||
decoding_method: ctc_beam_search
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 500
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 40
|
||||
num_proc_bsearch: 8
|
@ -1,83 +0,0 @@
|
||||
#!/bin/bash
# Prepare the data needed to evaluate the converted LibriSpeech model:
#   * download and unpack the converted checkpoint into ckpt_dir, then move
#     its mean_std.npz and vocab.txt into data/;
#   * stage -1: run the LibriSpeech dataset script, keep its manifests as
#     *.raw and concatenate them into train/dev/test raw manifests;
#   * stage  2: run format_data.py on every raw manifest.
#
# Usage: data.sh ckpt_dir

if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

stage=-1
stop_stage=100
unit_type=char

source "${MAIN_ROOT}/utils/parse_options.sh"

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p "${TARGET_DIR}"

bash local/download_model.sh "${ckpt_dir}"
if [ $? -ne 0 ]; then
   exit 1
fi

# Extract inside a subshell so that a bad ckpt_dir can neither leave this
# script in the wrong directory nor let tar unpack into the current one.
(cd "${ckpt_dir}" && tar xzvf librispeech_v1.8_to_v2.x.tar.gz)
if [ $? -ne 0 ]; then
    echo "Extract LibriSpeech model failed. Terminated."
    exit 1
fi
mv "${ckpt_dir}/mean_std.npz" data/
mv "${ckpt_dir}/vocab.txt" data/

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # download data, generate manifests
    python3 "${TARGET_DIR}/librispeech/librispeech.py" \
    --manifest_prefix="data/manifest" \
    --target_dir="${TARGET_DIR}/librispeech" \
    --full_download="True"

    if [ $? -ne 0 ]; then
        echo "Prepare LibriSpeech failed. Terminated."
        exit 1
    fi

    for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
        mv "data/manifest.${set}" "data/manifest.${set}.raw"
    done

    # Start from empty merged manifests because the loops below append.
    rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
    for set in train-clean-100 train-clean-360 train-other-500; do
        cat "data/manifest.${set}.raw" >> data/manifest.train.raw
    done

    for set in dev-clean dev-other; do
        cat "data/manifest.${set}.raw" >> data/manifest.dev.raw
    done

    for set in test-clean test-other; do
        cat "data/manifest.${set}.raw" >> data/manifest.test.raw
    done
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    pids=()
    for set in train dev test dev-clean dev-other test-clean test-other; do
    {
        python3 "${MAIN_ROOT}/utils/format_data.py" \
        --cmvn_path "data/mean_std.npz" \
        --unit_type "${unit_type}" \
        --vocab_path="data/vocab.txt" \
        --manifest_path="data/manifest.${set}.raw" \
        --output_path="data/manifest.${set}"

        if [ $? -ne 0 ]; then
            echo "Formt mnaifest.${set} failed. Terminated."
            # Only leaves this background job; the status is collected below.
            exit 1
        fi
    }&
    pids+=($!)
    done

    # A bare `wait` always succeeds, so wait on each job to see its status.
    failed=0
    for pid in "${pids[@]}"; do
        wait "${pid}" || failed=1
    done
    if [ ${failed} -ne 0 ]; then
        exit 1
    fi
fi

echo "LibriSpeech Data preparation done."
exit 0
|
||||
|
@ -1,22 +0,0 @@
|
||||
#!/bin/bash
# Fetch the common_crawl language model into data/lm using the `download`
# helper that utility.sh provides (called with URL, MD5 and target path).

. ${MAIN_ROOT}/utils/utility.sh

DIR=data/lm
mkdir -p ${DIR}

URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm

echo "Start downloading the language model. The language model is large, please wait for a moment ..."
# The helper's own output is discarded; only its exit status is used.
if ! download $URL $MD5 $TARGET > /dev/null 2>&1; then
    echo "Fail to download the language model!"
    exit 1
fi
echo "Download the language model sucessfully"


exit 0
|
@ -1,25 +0,0 @@
|
||||
#! /usr/bin/env bash
# Download the converted LibriSpeech checkpoint archive into ckpt_dir using the
# `download` helper from utility.sh (called with URL, MD5 and target path).
#
# Usage: download_model.sh ckpt_dir

if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

. ${MAIN_ROOT}/utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
MD5=a06d9aadb560ea113984dc98d67232c8
TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz


echo "Download LibriSpeech model ..."
if ! download $URL $MD5 $TARGET; then
    echo "Fail to download LibriSpeech model!"
    exit 1
fi


exit 0
|
@ -1,36 +0,0 @@
|
||||
#!/bin/bash
# Run ${BIN_DIR}/test.py on a checkpoint after local/download_lm_en.sh has
# fetched the language model.
#
# Usage: test.sh config_path decode_config_path ckpt_path_prefix model_type

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
    exit -1
fi

# Count the comma-separated entries of CUDA_VISIBLE_DEVICES.
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4

# download language model
bash local/download_lm_en.sh || exit 1

# The result file is written next to the checkpoint prefix as <prefix>.rsl.
if ! python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
|
@ -1,16 +0,0 @@
|
||||
# Environment setup for this recipe. Both roots are resolved relative to
# ${PWD}, so the file has to be sourced from the recipe directory.
export MAIN_ROOT=$(realpath ${PWD}/../../../../)
export LOCAL_DEEPSPEECH2=$(realpath ${PWD}/../)

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# LOCAL_DEEPSPEECH2 comes first, ahead of MAIN_ROOT and any existing entries.
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
|
@ -1,28 +0,0 @@
|
||||
#!/bin/bash
# Recipe entry point for the converted LibriSpeech model:
#   stage 0 - local/data.sh (checkpoint download and data preparation)
#   stage 1 - local/test.sh (evaluation of the converted checkpoint)
set -e
source path.sh

stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=1

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

v18_ckpt=librispeech_v1.8
# Experiment name = config file name up to its first dot.
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

ckpt_dir=exp/${ckpt}/checkpoints

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    mkdir -p ${ckpt_dir}
    bash ./local/data.sh ${ckpt_dir} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} ${ckpt_dir}/${v18_ckpt} ${model_type}|| exit -1
fi
|
@ -1,370 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.fluid import core
|
||||
from paddle.nn import functional as F
|
||||
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
|
||||
#TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog()

########### hack logging #############
# Make the `warn` name used throughout this module call `warning`.
logger.warn = logger.warning

########### hack paddle #############
# Attach short dtype aliases to the paddle module; each alias is bound to the
# dtype's string name (presumably mirroring torch's dtype names -- confirm).
for _alias, _dtype_name in (
        ('half', 'float16'),
        ('float', 'float32'),
        ('double', 'float64'),
        ('short', 'int16'),
        ('int', 'int32'),
        ('long', 'int64'),
        ('uint16', 'uint16'),
        ('cdouble', 'complex128'), ):
    setattr(paddle, _alias, _dtype_name)
del _alias, _dtype_name
|
||||
|
||||
|
||||
def convert_dtype_to_string(tensor_dtype):
    """Map a ``core.VarDesc.VarType`` value to the matching ``paddle`` dtype.

    Args:
        tensor_dtype(core.VarDesc.VarType): dtype value to translate,
            e.g. the ``dtype`` attribute of a tensor.

    Returns:
        The ``paddle`` module attribute for that dtype (``paddle.float32``,
        ``paddle.int64``, ...). ``BF16`` maps to ``paddle.uint16``.

    Raises:
        ValueError: if ``tensor_dtype`` matches none of the known types.
    """
    # (VarType member name, paddle attribute name), checked in this order.
    # Names are resolved one at a time inside the loop so that a member is
    # only looked up when every earlier candidate has failed to match.
    candidates = (
        ('FP32', 'float32'),
        ('FP64', 'float64'),
        ('FP16', 'float16'),
        ('INT32', 'int32'),
        ('INT16', 'int16'),
        ('INT64', 'int64'),
        ('BOOL', 'bool'),
        # since there is still no support for bfloat16 in NumPy,
        # uint16 is used for casting bfloat16
        ('BF16', 'uint16'),
        ('UINT8', 'uint8'),
        ('INT8', 'int8'),
        ('COMPLEX64', 'complex64'),
        ('COMPLEX128', 'complex128'), )
    for var_type_name, paddle_name in candidates:
        if tensor_dtype == getattr(core.VarDesc.VarType, var_type_name):
            return getattr(paddle, paddle_name)
    raise ValueError("Not supported tensor dtype %s" % tensor_dtype)
|
||||
|
||||
|
||||
# Expose selected paddle.nn.functional ops directly on the paddle module when
# the installed paddle does not already provide an attribute of that name.
for _fn_name in ('softmax', 'log_softmax', 'sigmoid', 'log_sigmoid', 'relu'):
    if not hasattr(paddle, _fn_name):
        logger.warn("register user %s to paddle, remove this when fixed!" %
                    _fn_name)
        setattr(paddle, _fn_name, getattr(paddle.nn.functional, _fn_name))
del _fn_name
|
||||
|
||||
|
||||
def cat(xs, dim=0):
    """Concatenate ``xs`` along ``dim`` by delegating to ``paddle.concat``."""
    joined = paddle.concat(xs, axis=dim)
    return joined


# Only installed when paddle has no `cat` attribute of its own.
if not hasattr(paddle, 'cat'):
    logger.warn(
        "override cat of paddle if exists or register, remove this when fixed!")
    setattr(paddle, 'cat', cat)
|
||||
|
||||
|
||||
########### hack paddle.Tensor #############
def item(x: paddle.Tensor):
    """Return ``x.numpy().item()``, i.e. the tensor's value as a Python scalar."""
    as_numpy = x.numpy()
    return as_numpy.item()


# Only installed when paddle.Tensor has no `item` attribute of its own.
if not hasattr(paddle.Tensor, 'item'):
    logger.warn(
        "override item of paddle.Tensor if exists or register, remove this when fixed!"
    )
    setattr(paddle.Tensor, 'item', item)
|
||||
|
||||
|
||||
def func_long(x: paddle.Tensor):
    """Cast ``x`` to the dtype stored in ``paddle.long``."""
    target_dtype = paddle.long
    return paddle.cast(x, target_dtype)


# Only installed when paddle.Tensor has no `long` attribute of its own.
if not hasattr(paddle.Tensor, 'long'):
    logger.warn(
        "override long of paddle.Tensor if exists or register, remove this when fixed!"
    )
    setattr(paddle.Tensor, 'long', func_long)

# Only installed when paddle.Tensor has no `numel` attribute of its own.
if not hasattr(paddle.Tensor, 'numel'):
    logger.warn(
        "override numel of paddle.Tensor if exists or register, remove this when fixed!"
    )
    setattr(paddle.Tensor, 'numel', paddle.numel)
|
||||
|
||||
|
||||
def new_full(x: paddle.Tensor,
             size: Union[List[int], Tuple[int], paddle.Tensor],
             fill_value: Union[float, int, bool, paddle.Tensor],
             dtype=None):
    """Create a tensor of shape ``size`` filled with ``fill_value``.

    Args:
        x: tensor whose dtype is used when ``dtype`` is not given.
        size: shape of the new tensor, passed straight to ``paddle.full``.
        fill_value: value written to every element.
        dtype: dtype of the new tensor; ``None`` (the default) means
            ``x.dtype``.

    Returns:
        The result of ``paddle.full(size, fill_value, dtype=...)``.
    """
    # Honour an explicit dtype instead of silently discarding it.
    target_dtype = x.dtype if dtype is None else dtype
    return paddle.full(size, fill_value, dtype=target_dtype)


# Only installed when paddle.Tensor has no `new_full` attribute of its own.
if not hasattr(paddle.Tensor, 'new_full'):
    logger.warn(
        "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.new_full = new_full
|
||||
|
||||
|
||||
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
    """Compare ``xs`` with ``ys`` through ``xs.equal``.

    A bool ``xs`` is first cast to ``paddle.int``. ``ys`` is converted with
    ``paddle.to_tensor`` using the (possibly updated) dtype of ``xs`` and the
    place of ``xs`` before the comparison.
    """
    xs_is_bool = convert_dtype_to_string(xs.dtype) == paddle.bool
    if xs_is_bool:
        xs = xs.astype(paddle.int)
    other = paddle.to_tensor(
        ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place)
    return xs.equal(other)


# Only installed when paddle.Tensor has no `eq` attribute of its own.
if not hasattr(paddle.Tensor, 'eq'):
    logger.warn(
        "override eq of paddle.Tensor if exists or register, remove this when fixed!"
    )
    setattr(paddle.Tensor, 'eq', eq)

# Only installed when paddle has no `eq` attribute of its own.
if not hasattr(paddle, 'eq'):
    logger.warn(
        "override eq of paddle if exists or register, remove this when fixed!")
    setattr(paddle, 'eq', eq)
|
||||
|
||||
|
||||
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
    """No-op stand-in for ``contiguous``: hand back ``xs`` itself."""
    return xs


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'contiguous'):
    logger.warn(
        "override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
    )
    setattr(paddle.Tensor, 'contiguous', contiguous)
|
||||
|
||||
|
||||
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
    """Return ``paddle.shape(xs)``, or one entry of it.

    With no positional argument the whole shape is returned; with a single
    index ``i`` the ``i``-th entry of the shape is returned. More than one
    argument trips the assertion.
    """
    count = len(args)
    assert count <= 1
    shape = paddle.shape(xs)
    if count == 1:
        return shape[args[0]]
    return shape


# `to_static` does not process the `size` property, and some `paddle` APIs may
# depend on it, so this override is installed unconditionally.
logger.warn(
    "override size of paddle.Tensor "
    "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
setattr(paddle.Tensor, 'size', size)
|
||||
|
||||
|
||||
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
    """Reshape ``xs`` to the dimensions given as positional arguments."""
    target_shape = args
    return xs.reshape(target_shape)


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'view'):
    logger.warn("register user view to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'view', view)
|
||||
|
||||
|
||||
def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
    """Reshape ``xs`` to whatever ``ys.size()`` reports."""
    target_shape = ys.size()
    return xs.reshape(target_shape)


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'view_as'):
    logger.warn(
        "register user view_as to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'view_as', view_as)
|
||||
|
||||
|
||||
def is_broadcastable(shp1, shp2):
    """Tell whether two shapes are broadcast-compatible.

    The shapes are walked from their trailing dimension; each aligned pair
    must be equal or contain a 1. Leading dimensions present in only one of
    the shapes are not inspected.
    """
    return all(dim1 == 1 or dim2 == 1 or dim1 == dim2
               for dim1, dim2 in zip(shp1[::-1], shp2[::-1]))
|
||||
|
||||
|
||||
def masked_fill(xs: paddle.Tensor,
                mask: paddle.Tensor,
                value: Union[float, int]):
    """Return a tensor equal to ``xs`` with ``value`` where ``mask`` is set.

    ``mask`` is broadcast to the common broadcast shape of ``xs`` and
    ``mask`` before ``paddle.where`` selects between ``value`` and ``xs``.
    """
    assert is_broadcastable(xs.shape, mask.shape) is True
    full_shape = paddle.broadcast_shape(xs.shape, mask.shape)
    expanded_mask = mask.broadcast_to(full_shape)
    # Tensor shaped like ``xs`` holding ``value`` everywhere.
    filled = paddle.ones_like(xs) * value
    return paddle.where(expanded_mask, filled, xs)


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'masked_fill'):
    logger.warn(
        "register user masked_fill to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'masked_fill', masked_fill)
|
||||
|
||||
|
||||
def masked_fill_(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]) -> paddle.Tensor:
    """Write ``value`` into ``xs`` where ``mask`` is set and return ``xs``.

    The masked result is computed with ``paddle.where`` and then copied back
    into ``xs`` with ``paddle.assign`` (using a detached copy of the result).
    """
    assert is_broadcastable(xs.shape, mask.shape) is True
    full_shape = paddle.broadcast_shape(xs.shape, mask.shape)
    expanded_mask = mask.broadcast_to(full_shape)
    # Tensor shaped like ``xs`` holding ``value`` everywhere.
    filled = paddle.ones_like(xs) * value
    selected = paddle.where(expanded_mask, filled, xs)
    paddle.assign(selected.detach(), output=xs)
    return xs


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'masked_fill_'):
    logger.warn(
        "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'masked_fill_', masked_fill_)
|
||||
|
||||
|
||||
def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
    """Overwrite ``xs`` with ``value`` via ``paddle.assign`` and return it."""
    constant = paddle.full_like(xs, value).detach()
    paddle.assign(constant, output=xs)
    return xs


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'fill_'):
    logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'fill_', fill_)
|
||||
|
||||
|
||||
def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
    """Tile ``xs`` with ``paddle.tile`` using the positional repeat counts."""
    repeat_times = size
    return paddle.tile(xs, repeat_times)


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'repeat'):
    logger.warn(
        "register user repeat to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'repeat', repeat)
|
||||
|
||||
# Expose the functional activations as Tensor methods when they are missing.
if not hasattr(paddle.Tensor, 'softmax'):
    logger.warn(
        "register user softmax to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.softmax = paddle.nn.functional.softmax

if not hasattr(paddle.Tensor, 'sigmoid'):
    logger.warn(
        "register user sigmoid to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.sigmoid = paddle.nn.functional.sigmoid

if not hasattr(paddle.Tensor, 'relu'):
    logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.relu = paddle.nn.functional.relu
|
||||
|
||||
|
||||
def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
    """Cast ``x`` to the dtype of ``other``."""
    target_dtype = other.dtype
    return x.astype(target_dtype)


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'type_as'):
    logger.warn(
        "register user type_as to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.type_as = type_as
|
||||
|
||||
|
||||
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
    """Minimal ``to``: cast by dtype name or by example tensor.

    Exactly one positional argument is expected. A ``str`` is used as the
    dtype to cast to, a tensor donates its dtype, and anything else (e.g. a
    device) leaves ``x`` untouched. Keyword arguments are accepted but unused.
    """
    assert len(args) == 1
    target = args[0]
    if isinstance(target, str):
        # dtype given by name
        return x.astype(target)
    if isinstance(target, paddle.Tensor):
        # dtype taken from another tensor
        return x.astype(target.dtype)
    # device (or anything else): nothing to do
    return x


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'to'):
    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.to = to
|
||||
|
||||
|
||||
def func_float(x: paddle.Tensor) -> paddle.Tensor:
    """Cast ``x`` to ``paddle.float``."""
    float_dtype = paddle.float
    return x.astype(float_dtype)


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'float'):
    logger.warn("register user float to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.float = func_float
|
||||
|
||||
|
||||
def func_int(x: paddle.Tensor) -> paddle.Tensor:
    """Cast ``x`` to ``paddle.int``."""
    int_dtype = paddle.int
    return x.astype(int_dtype)


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'int'):
    logger.warn("register user int to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.int = func_int
|
||||
|
||||
|
||||
def tolist(x: paddle.Tensor) -> List[Any]:
    """Convert ``x`` to NumPy and return ``ndarray.tolist()`` of it."""
    as_numpy = x.numpy()
    return as_numpy.tolist()


# Only patch the method in when this paddle build does not provide it.
if not hasattr(paddle.Tensor, 'tolist'):
    logger.warn(
        "register user tolist to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.tolist = tolist
|
||||
|
||||
|
||||
########### hack paddle.nn #############
|
||||
class GLU(nn.Layer):
    """Layer wrapper around ``F.glu`` (Gated Linear Unit).

    Args:
        dim: axis handed to ``F.glu`` as ``axis``. Defaults to -1.
    """

    def __init__(self, dim: int=-1):
        super().__init__()
        self.dim = dim

    def forward(self, xs):
        """Apply ``F.glu`` to ``xs`` along the configured axis."""
        axis = self.dim
        return F.glu(xs, axis=axis)


# Only add the layer to ``paddle.nn`` when this paddle build lacks it.
if not hasattr(paddle.nn, 'GLU'):
    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
    paddle.nn.GLU = GLU
|
@ -1,59 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluation for DeepSpeech2 model."""
|
||||
from src_deepspeech2x.test_model import DeepSpeech2Tester as Tester
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||
from paddlespeech.s2t.utils.utility import print_arguments
|
||||
|
||||
|
||||
def main_sp(config, args):
    """Build a ``Tester`` from ``config``/``args``, set it up and run the test."""
    tester = Tester(config, args)
    tester.setup()
    tester.run_test()
|
||||
|
||||
|
||||
def main(config, args):
    """Entry point: delegate to ``main_sp`` with the same arguments."""
    main_sp(config=config, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Extend the shared argument parser with the options specific to this
    # evaluation script.
    parser = default_argument_parser()
    parser.add_argument(
        "--model_type", type=str, default='offline', help='offline/online')
    # file the ASR result is saved to
    parser.add_argument(
        "--result_file", type=str, help="path of save the asr result")
    args = parser.parse_args()
    # The module globals are handed to the printer together with the args.
    print_arguments(args, globals())
    print("model_type:{}".format(args.model_type))

    # Float syntax accepted in the YAML config files:
    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    # Merge order: main config file, then decode config, then CLI overrides.
    if args.config:
        config.merge_from_file(args.config)
    if args.decode_cfg:
        # The decode config is loaded separately and attached as `config.decode`.
        decode_confs = CfgNode(new_allowed=True)
        decode_confs.merge_from_file(args.decode_cfg)
        config.decode = decode_confs
    if args.opts:
        config.merge_from_list(args.opts)
    # Freeze after all merges, before the config is printed and used.
    config.freeze()
    print(config)
    if args.dump_config:
        # Write the final, merged config to the requested path.
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
|
@ -1,13 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -1,17 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .deepspeech2 import DeepSpeech2InferModel
|
||||
from .deepspeech2 import DeepSpeech2Model
|
||||
|
||||
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
|
@ -1,275 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Deepspeech2 ASR Model"""
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from src_deepspeech2x.models.ds2.rnn import RNNStack
|
||||
|
||||
from paddlespeech.s2t.models.ds2.conv import ConvStack
|
||||
from paddlespeech.s2t.modules.ctc import CTCDecoder
|
||||
from paddlespeech.s2t.utils import layer_tools
|
||||
from paddlespeech.s2t.utils.checkpoint import Checkpoint
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
|
||||
|
||||
|
||||
class CRNNEncoder(nn.Layer):
    """Encoder made of a ``ConvStack`` followed by an ``RNNStack``.

    Args:
        feat_size (int): input feature dimension, handed to ``ConvStack``.
        dict_size (int): vocabulary size; only stored on the instance here.
        num_conv_layers (int): number of layers requested from ``ConvStack``.
        num_rnn_layers (int): number of stacked layers in ``RNNStack``.
        rnn_size (int): hidden size of each RNN direction.
        use_gru (bool): forwarded to ``RNNStack``.
        share_rnn_weights (bool): forwarded to ``RNNStack``.
    """

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True):
        super().__init__()
        self.rnn_size = rnn_size
        self.feat_size = feat_size  # 161 for linear
        self.dict_size = dict_size

        self.conv = ConvStack(feat_size, num_conv_layers)

        i_size = self.conv.output_height  # H after conv stack
        self.rnn = RNNStack(
            i_size=i_size,
            h_size=rnn_size,
            num_stacks=num_rnn_layers,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)

    @property
    def output_size(self):
        """Size of the last encoder dimension: two directions of ``rnn_size``."""
        return self.rnn_size * 2

    def forward(self, audio, audio_len):
        """Compute Encoder outputs

        Args:
            audio (Tensor): [B, Tmax, D]
            audio_len (Tensor): [B]
        Returns:
            x (Tensor): encoder outputs, [B, T, D]
            x_lens (Tensor): encoder length, [B]
        """
        # [B, T, D] -> [B, D, T]
        audio = audio.transpose([0, 2, 1])
        # [B, D, T] -> [B, C=1, D, T]
        x = audio.unsqueeze(1)
        x_lens = audio_len

        # convolution group
        # (A leftover debug ``x.numpy()`` call was dropped here: its result was
        # never used and it forced a tensor-to-NumPy conversion inside forward.)
        x, x_lens = self.conv(x, x_lens)

        # convert data from convolution feature map to sequence of vectors
        # `B, C, D, T = paddle.shape(x)` does not work under jit, so the
        # reshape below relies on 0 (keep dim) and -1 (infer dim) instead.
        x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
        x = x.reshape([0, 0, -1])  #[B, T, C*D]

        # recurrent group
        x, x_lens = self.rnn(x, x_lens)  #[B, T, D]
        return x, x_lens
|
||||
|
||||
|
||||
class DeepSpeech2Model(nn.Layer):
    """DeepSpeech2 network: a ``CRNNEncoder`` followed by a ``CTCDecoder``.

    Args:
        feat_size (int): input feature dimension, forwarded to the encoder.
        dict_size (int): dictionary size for tokenized transcription; also
            the decoder output dimension.
        num_conv_layers (int): number of stacking convolution layers.
        num_rnn_layers (int): number of stacking RNN layers.
        rnn_size (int): RNN layer size (dimension of RNN cells).
        use_gru (bool): use gru if True, simple rnn if False.
        share_rnn_weights (bool): whether to share input-hidden weights
            between forward and backward direction RNNs; forwarded to the
            encoder.
        blank_id (int): blank token id handed to the CTC decoder.
    """

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True,
                 blank_id=0):
        super().__init__()
        encoder_kwargs = dict(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)
        self.encoder = CRNNEncoder(**encoder_kwargs)
        assert self.encoder.output_size == rnn_size * 2

        self.decoder = CTCDecoder(
            odim=dict_size,  # <blank> is in vocab
            enc_n_units=self.encoder.output_size,
            blank_id=blank_id,  # first token is <blank>
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True)  # sum / batch_size

    def forward(self, audio, audio_len, text, text_len):
        """Compute Model loss

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]
            text (Tensor): [B, U]
            text_len (Tensor): [B]

        Returns:
            loss (Tensor): [1]
        """
        enc_out, enc_len = self.encoder(audio, audio_len)
        return self.decoder(enc_out, enc_len, text, text_len)

    @paddle.no_grad()
    def decode(self, audio, audio_len):
        """Encode ``audio`` and return the decoder's best transcriptions.

        The decoder is reset for the current batch size, fed the softmax
        output of the encoder, and asked to decode; only the first element
        of the decoder's result is returned.
        """
        enc_out, enc_len = self.encoder(audio, audio_len)
        probs = self.decoder.softmax(enc_out)
        # (Re)initialise the decoder state for this batch before feeding it.
        self.decoder.reset_decoder(batch_size=probs.shape[0])
        self.decoder.next(probs, enc_len)
        trans_best, _ = self.decoder.decode()
        return trans_best

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a DeepSpeech2Model model from a pretrained model.

        Note that ``blank_id`` is not forwarded here, so the constructor
        default applies.

        Parameters
        ----------
        dataloader: paddle.io.DataLoader
            its ``collate_fn`` provides ``feature_size`` and ``vocab_list``

        config: yacs.config.CfgNode
            model configs

        checkpoint_path: Path or str
            the path of pretrained model checkpoint, without extension name

        Returns
        -------
        DeepSpeech2Model
            The model built from pretrained result.
        """
        collator = dataloader.collate_fn
        model = cls(
            feat_size=collator.feature_size,
            dict_size=len(collator.vocab_list),
            num_conv_layers=config.num_conv_layers,
            num_rnn_layers=config.num_rnn_layers,
            rnn_size=config.rnn_layer_size,
            use_gru=config.use_gru,
            share_rnn_weights=config.share_rnn_weights)
        infos = Checkpoint().load_parameters(
            model, checkpoint_path=checkpoint_path)
        logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model

    @classmethod
    def from_config(cls, config):
        """Build a DeepSpeech2Model from config

        Parameters
        ----------
        config: yacs.config.CfgNode
            config

        Returns
        -------
        DeepSpeech2Model
            The model built from config.
        """
        return cls(
            feat_size=config.feat_size,
            dict_size=config.dict_size,
            num_conv_layers=config.num_conv_layers,
            num_rnn_layers=config.num_rnn_layers,
            rnn_size=config.rnn_layer_size,
            use_gru=config.use_gru,
            share_rnn_weights=config.share_rnn_weights,
            blank_id=config.blank_id)
|
||||
|
||||
|
||||
class DeepSpeech2InferModel(DeepSpeech2Model):
    """Inference/export variant of ``DeepSpeech2Model``.

    Takes the same constructor arguments as the parent; ``forward`` returns
    probabilities instead of a loss.
    """

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True,
                 blank_id=0):
        parent_kwargs = dict(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights,
            blank_id=blank_id)
        super().__init__(**parent_kwargs)

    def forward(self, audio, audio_len):
        """export model function

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]

        Returns:
            probs: probs after softmax
            eouts_len: encoder output lengths
        """
        enc_out, enc_len = self.encoder(audio, audio_len)
        return self.decoder.softmax(enc_out), enc_len

    def export(self):
        """Wrap this model with ``paddle.jit.to_static`` and return the result."""
        audio_spec = paddle.static.InputSpec(
            shape=[None, None, self.encoder.feat_size],
            dtype='float32')  # audio, [B,T,D]
        audio_len_spec = paddle.static.InputSpec(
            shape=[None], dtype='int64')  # audio_length, [B]
        return paddle.jit.to_static(
            self, input_spec=[audio_spec, audio_len_spec])
|
@ -1,334 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle.nn import initializer as I
|
||||
|
||||
from paddlespeech.s2t.modules.activation import brelu
|
||||
from paddlespeech.s2t.modules.mask import make_non_pad_mask
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['RNNStack']
|
||||
|
||||
|
||||
class RNNCell(nn.RNNCellBase):
    r"""
    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
    computes the outputs and updates states.
    The formula used is as follows:
    .. math::
        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
        y_{t} & = h_{t}

    where :math:`act` is for :attr:`activation`. The cell owns no
    input-to-hidden weight: :math:`x_{t}` is added as given, and
    :math:`b_{ih}` is always ``None`` in this implementation.

    Args:
        hidden_size: size of the hidden state.
        activation: one of ``"tanh"``, ``"relu"`` or ``"brelu"``.
        weight_ih_attr: accepted but unused.
        weight_hh_attr: parameter attribute for the hidden-hidden weight.
        bias_ih_attr: accepted but unused.
        bias_hh_attr: parameter attribute for the hidden-hidden bias.
        name: accepted but unused.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self,
                 hidden_size: int,
                 activation="tanh",
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_hh = self.create_parameter(
            (hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = None
        self.bias_hh = self.create_parameter(
            (hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        if activation not in ["tanh", "relu", "brelu"]:
            # The message now lists every accepted value, including brelu.
            raise ValueError(
                "activation for SimpleRNNCell should be tanh, relu or brelu, "
                "but get {}".format(activation))
        self.activation = activation
        if activation == "tanh":
            self._activation_fn = paddle.tanh
        elif activation == "brelu":
            self._activation_fn = brelu
        else:
            self._activation_fn = F.relu

    def forward(self, inputs, states=None):
        """Run one step and return ``(h, h)``: the output and the new state."""
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        pre_h = states
        i2h = inputs
        if self.bias_ih is not None:
            # Out-of-place add so the caller's ``inputs`` is never modified.
            i2h = i2h + self.bias_ih
        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h2h = h2h + self.bias_hh
        h = self._activation_fn(i2h + h2h)
        return h, h

    @property
    def state_shape(self):
        """Shape of the state without the batch dimension."""
        return (self.hidden_size, )
|
||||
|
||||
|
||||
class GRUCell(nn.RNNCellBase):
    r"""
    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
    it computes the outputs and updates states.

    The inputs are used as given (the cell owns no input-to-hidden weight)
    and are split into three equal parts :math:`x_{u}, x_{r}, x_{c}` along
    axis 1. The computation implemented in ``forward`` is:

    .. math::
        u_{t} & = \sigma(x_{u} + h_{t-1}W_{u} + b_{u})
        r_{t} & = \sigma(x_{r} + h_{t-1}W_{r} + b_{r})
        c_{t} & = act(x_{c} + (r_{t} * h_{t-1})W_{c} + b_{c})
        h_{t} & = (1 - u_{t}) * h_{t-1} + u_{t} * c_{t}
        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, :math:`act` is
    ``paddle.relu`` as set in ``__init__``, and * is the elementwise
    multiplication operator.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        # Single hidden-hidden parameter holding the weights of all three gates.
        self.weight_hh = self.create_parameter(
            (3 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        # No input-side bias is created; `forward` skips it while it is None.
        self.bias_ih = None
        self.bias_hh = self.create_parameter(
            (3 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        # Stored only; not read anywhere else in this class.
        self.input_size = input_size
        self._gate_activation = F.sigmoid
        self._activation = paddle.relu

    def forward(self, inputs, states=None):
        """Run one GRU step and return ``(h, h)``: the output and the new state."""
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)

        pre_hidden = states  # shape [batch_size, hidden_size]

        x_gates = inputs
        if self.bias_ih is not None:
            x_gates = x_gates + self.bias_ih
        # The bias is stored as [update | reset | candidate] thirds.
        bias_u, bias_r, bias_c = paddle.split(
            self.bias_hh, num_or_sections=3, axis=0)

        weight_hh = paddle.transpose(
            self.weight_hh,
            perm=[1, 0])  #weight_hh:shape[hidden_size, 3 * hidden_size]
        # The transposed weight is flattened and re-cut: the first
        # 2 * hidden_size * hidden_size values form the update/reset block of
        # shape [hidden_size, 2 * hidden_size], the remainder is the candidate
        # weight of shape [hidden_size, hidden_size].
        # NOTE(review): presumably this mirrors the 1.x checkpoint layout -
        # confirm against the conversion script.
        w_u_r_c = paddle.flatten(weight_hh)
        size_u_r = self.hidden_size * 2 * self.hidden_size
        w_u_r = paddle.reshape(w_u_r_c[:size_u_r],
                               (self.hidden_size, self.hidden_size * 2))
        w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1)
        w_c = paddle.reshape(w_u_r_c[size_u_r:],
                             (self.hidden_size, self.hidden_size))

        h_u = paddle.matmul(
            pre_hidden, w_u,
            transpose_y=False) + bias_u  #shape [batch_size, hidden_size]
        h_r = paddle.matmul(
            pre_hidden, w_r,
            transpose_y=False) + bias_r  #shape [batch_size, hidden_size]

        x_u, x_r, x_c = paddle.split(
            x_gates, num_or_sections=3, axis=1)  #shape[batch_size, hidden_size]

        u = self._gate_activation(x_u + h_u)  #shape [batch_size, hidden_size]
        r = self._gate_activation(x_r + h_r)  #shape [batch_size, hidden_size]
        # The reset gate is applied to the previous state before the matmul.
        c = self._activation(
            x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) +
            bias_c)  # [batch_size, hidden_size]

        h = (1 - u) * pre_hidden + u * c
        # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
        return h, h

    @property
    def state_shape(self):
        r"""
        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
        size would be automatically inserted into shape). The shape corresponds
        to the shape of :math:`h_{t-1}`.
        """
        return (self.hidden_size, )
|
||||
|
||||
|
||||
class BiRNNWithBN(nn.Layer):
    """Bidirectional simple rnn layer with sequence-wise batch normalization.
    The batch normalization is only performed on the input-state projection.

    Args:
        i_size (int): input feature size of the projection layers.
        h_size (int): dimension of the RNN cells.
        share_weights (bool): when True the backward direction reuses the
            forward direction's projection and batch-norm layers.
    """

    def __init__(self, i_size: int, h_size: int, share_weights: bool):
        super().__init__()
        self.share_weights = share_weights

        # Forward-direction projection and its batch norm always exist.
        self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
        self.fw_bn = nn.BatchNorm1D(
            h_size, bias_attr=None, data_format='NLC')
        if self.share_weights:
            # input-hidden weights shared between bi-directional rnn.
            self.bw_fc = self.fw_fc
            self.bw_bn = self.fw_bn
        else:
            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
            self.bw_bn = nn.BatchNorm1D(
                h_size, bias_attr=None, data_format='NLC')

        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]

    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
        """Run both directions over ``x`` ([B, T, D]) and concatenate them.

        Returns the concatenated outputs together with the unchanged
        ``x_len``.
        """
        fw_in = self.fw_bn(self.fw_fc(x))
        bw_in = self.bw_bn(self.bw_fc(x))
        fw_out, _ = self.fw_rnn(inputs=fw_in, sequence_length=x_len)
        bw_out, _ = self.bw_rnn(inputs=bw_in, sequence_length=x_len)
        merged = paddle.concat([fw_out, bw_out], axis=-1)
        return merged, x_len
|
||||
|
||||
|
||||
class BiGRUWithBN(nn.Layer):
    """Bidirectional gru layer with sequence-wise batch normalization.
    The batch normalization is only performed on the input-state projection.

    Args:
        i_size (int): input feature size of the projection layers.
        h_size (int): dimension of the GRU cells; the projections produce
            ``3 * h_size`` features.
    """

    def __init__(self, i_size: int, h_size: int):
        super().__init__()
        # One projected slice per part the GRU cell splits its input into.
        proj_size = h_size * 3

        self.fw_fc = nn.Linear(i_size, proj_size, bias_attr=False)
        self.fw_bn = nn.BatchNorm1D(
            proj_size, bias_attr=None, data_format='NLC')
        self.bw_fc = nn.Linear(i_size, proj_size, bias_attr=False)
        self.bw_bn = nn.BatchNorm1D(
            proj_size, bias_attr=None, data_format='NLC')

        self.fw_cell = GRUCell(input_size=proj_size, hidden_size=h_size)
        self.bw_cell = GRUCell(input_size=proj_size, hidden_size=h_size)
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]

    def forward(self, x, x_len):
        """Run both directions over ``x`` ([B, T, D]) and concatenate them.

        Returns the concatenated outputs together with the unchanged
        ``x_len``.
        """
        fw_in = self.fw_bn(self.fw_fc(x))
        bw_in = self.bw_bn(self.bw_fc(x))
        fw_out, _ = self.fw_rnn(inputs=fw_in, sequence_length=x_len)
        bw_out, _ = self.bw_rnn(inputs=bw_in, sequence_length=x_len)
        merged = paddle.concat([fw_out, bw_out], axis=-1)
        return merged, x_len
|
||||
|
||||
|
||||
class RNNStack(nn.Layer):
    """RNN group with stacked bidirectional simple RNN or GRU layers.

    Args:
        i_size (int): input feature size of the first layer.
        h_size (int): dimension of RNN cells in each layer.
        num_stacks (int): number of stacked rnn layers.
        use_gru (bool): use ``BiGRUWithBN`` if True, ``BiRNNWithBN`` if False.
        share_rnn_weights (bool): whether to share input-hidden weights
            between forward and backward directional RNNs. Only passed on
            when ``use_gru`` is False.
    """

    def __init__(self,
                 i_size: int,
                 h_size: int,
                 num_stacks: int,
                 use_gru: bool,
                 share_rnn_weights: bool):
        super().__init__()
        layers = []
        in_size = i_size
        for _ in range(num_stacks):
            if use_gru:
                layer = BiGRUWithBN(i_size=in_size, h_size=h_size)
            else:
                layer = BiRNNWithBN(
                    i_size=in_size,
                    h_size=h_size,
                    share_weights=share_rnn_weights)
            layers.append(layer)
            # Every later layer consumes the two concatenated directions.
            in_size = h_size * 2

        self.rnn_stacks = nn.LayerList(layers)

    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
        """Apply each layer in turn, masking its output after every layer.

        x: shape [B, T, D]
        x_len: shape [B]
        """
        for layer in self.rnn_stacks:
            x, x_len = layer(x, x_len)
            # [B, T] -> [B, T, 1]; cast because bool multiply is unsupported.
            pad_mask = make_non_pad_mask(x_len).unsqueeze(-1).astype(x.dtype)
            x = x.multiply(pad_mask)
        return x, x_len
|
@ -1,357 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader
|
||||
from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
|
||||
from src_deepspeech2x.models.ds2 import DeepSpeech2Model
|
||||
|
||||
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from paddlespeech.s2t.io.collator import SpeechCollator
|
||||
from paddlespeech.s2t.io.dataset import ManifestDataset
|
||||
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
|
||||
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
|
||||
from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
|
||||
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
|
||||
from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
|
||||
from paddlespeech.s2t.training.trainer import Trainer
|
||||
from paddlespeech.s2t.utils import error_rate
|
||||
from paddlespeech.s2t.utils import layer_tools
|
||||
from paddlespeech.s2t.utils import mp_tools
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
|
||||
# Module-level logger, named after this module.
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class DeepSpeech2Trainer(Trainer):
    """Trainer for the DeepSpeech2 offline and online models.

    Builds the train/valid/test dataloaders, the model, the optimizer and the
    learning-rate scheduler from ``config``, and implements the per-batch
    training step and the validation loop.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    def train_batch(self, batch_index, batch_data, msg):
        """Run forward/backward for one batch and step the optimizer when due.

        Gradients are accumulated over ``config.accum_grad`` batches; the
        optimizer is stepped (and ``self.iteration`` incremented) only on the
        last batch of each accumulation window.

        Args:
            batch_index: Zero-based index of the batch within the epoch.
            batch_data: Tuple ``(utt, audio, audio_len, text, text_len)``.
            msg: Log-message prefix; timing and loss information is appended
                to it before it is logged.
        """
        train_conf = self.config
        start = time.time()

        # forward
        _, audio, audio_len, text, text_len = batch_data
        loss = self.model(audio, audio_len, text, text_len)
        losses_np = {'train_loss': float(loss)}

        # loss backward
        is_update_step = (batch_index + 1) % train_conf.accum_grad == 0
        if is_update_step:
            # Used for single gpu training and DDP gradient synchronization
            # processes.
            context = nullcontext
        else:
            # Disable gradient synchronizations across DDP processes.
            # Within this context, gradients will be accumulated on module
            # variables, which will later be synchronized.
            # The model is only wrapped in paddle.DataParallel when
            # self.parallel is set, so a plain model has no `no_sync`;
            # fall back to a no-op context instead of raising AttributeError.
            context = getattr(self.model, "no_sync", nullcontext)

        with context():
            loss.backward()
            layer_tools.print_grads(self.model, print_func=None)

        # optimizer step
        if is_update_step:
            self.optimizer.step()
            self.optimizer.clear_grad()
            self.iteration += 1

        iteration_time = time.time() - start

        msg += "train time: {:>.3f}s, ".format(iteration_time)
        msg += "batch size: {}, ".format(self.config.batch_size)
        msg += "accum: {}, ".format(train_conf.accum_grad)
        msg += ', '.join('{}: {:>.6f}'.format(k, v)
                         for k, v in losses_np.items())
        logger.info(msg)

        if dist.get_rank() == 0 and self.visualizer:
            for k, v in losses_np.items():
                # `step -1` since we update `step` after optimizer.step().
                self.visualizer.add_scalar("train/{}".format(k), v,
                                           self.iteration - 1)

    @paddle.no_grad()
    def valid(self):
        """Evaluate the model on the validation loader.

        Batches whose loss is not finite are skipped.

        Returns:
            Tuple ``(total_loss, num_seen_utts)`` where ``total_loss`` is the
            sum of per-batch losses weighted by batch size.
        """
        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        self.model.eval()
        valid_losses = defaultdict(list)
        # NOTE(review): the counter starts at 1, not 0, so the divisions below
        # never hit zero; presumably intentional, but it also makes the
        # returned count one larger than the number of utterances seen.
        num_seen_utts = 1
        total_loss = 0.0
        for i, batch in enumerate(self.valid_loader):
            _, audio, audio_len, text, text_len = batch
            loss = self.model(audio, audio_len, text, text_len)
            if paddle.isfinite(loss):
                # Convert the loss tensor to a Python float only once.
                loss_value = float(loss)
                num_utts = audio.shape[0]
                num_seen_utts += num_utts
                total_loss += loss_value * num_utts
                valid_losses['val_loss'].append(loss_value)

            if (i + 1) % self.config.log_interval == 0:
                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                valid_dump['val_history_loss'] = total_loss / num_seen_utts

                # logging
                msg = f"Valid: Rank: {dist.get_rank()}, "
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)

        logger.info('Rank {} Val info val_loss {}'.format(
            dist.get_rank(), total_loss / num_seen_utts))
        return total_loss, num_seen_utts

    def setup_model(self):
        """Create the model, LR scheduler and optimizer and store them on self.

        Requires ``self.train_loader`` to exist already, because the feature
        size and vocabulary size are read from its collator.

        Raises:
            Exception: If ``args.model_type`` is neither ``'offline'`` nor
                ``'online'``.
        """
        config = self.config.clone()
        config.defrost()
        config.feat_size = self.train_loader.collate_fn.feature_size
        #config.dict_size = self.train_loader.collate_fn.vocab_size
        config.dict_size = len(self.train_loader.collate_fn.vocab_list)
        config.freeze()

        if self.args.model_type == 'offline':
            model = DeepSpeech2Model.from_config(config)
        elif self.args.model_type == 'online':
            model = DeepSpeech2ModelOnline.from_config(config)
        else:
            raise Exception("wrong model type")
        if self.parallel:
            model = paddle.DataParallel(model)

        logger.info(f"{model}")
        layer_tools.print_params(model, logger.info)

        grad_clip = ClipGradByGlobalNormWithLog(config.global_grad_clip)
        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
            learning_rate=config.lr, gamma=config.lr_decay, verbose=True)
        optimizer = paddle.optimizer.Adam(
            learning_rate=lr_scheduler,
            parameters=model.parameters(),
            weight_decay=paddle.regularizer.L2Decay(config.weight_decay),
            grad_clip=grad_clip)

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        logger.info("Setup model/optimizer/lr_scheduler!")

    def setup_dataloader(self):
        """Create the train/valid/test dataloaders and store them on self.

        The train loader uses a sortagrad batch sampler (the distributed
        variant when ``self.parallel`` is set). The dev and test collators are
        built with ``augmentation_config`` set to an empty string, and the
        test collator additionally keeps the transcription text.
        """
        config = self.config.clone()
        config.defrost()
        config.keep_transcription_text = False

        config.manifest = config.train_manifest
        train_dataset = ManifestDataset.from_config(config)

        config.manifest = config.dev_manifest
        dev_dataset = ManifestDataset.from_config(config)

        config.manifest = config.test_manifest
        test_dataset = ManifestDataset.from_config(config)

        if self.parallel:
            batch_sampler = SortagradDistributedBatchSampler(
                train_dataset,
                batch_size=config.batch_size,
                num_replicas=None,
                rank=None,
                shuffle=True,
                drop_last=True,
                sortagrad=config.sortagrad,
                shuffle_method=config.shuffle_method)
        else:
            batch_sampler = SortagradBatchSampler(
                train_dataset,
                shuffle=True,
                batch_size=config.batch_size,
                drop_last=True,
                sortagrad=config.sortagrad,
                shuffle_method=config.shuffle_method)

        collate_fn_train = SpeechCollator.from_config(config)

        config.augmentation_config = ""
        collate_fn_dev = SpeechCollator.from_config(config)

        config.keep_transcription_text = True
        config.augmentation_config = ""
        collate_fn_test = SpeechCollator.from_config(config)

        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn_train,
            num_workers=config.num_workers)
        self.valid_loader = DataLoader(
            dev_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_dev)
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decode.decode_batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_test)

        # Drop "<eos>" from every collator's vocabulary when it is present.
        for loader in (self.test_loader, self.valid_loader,
                       self.train_loader):
            vocab_list = loader.collate_fn.vocab_list
            if "<eos>" in vocab_list:
                vocab_list.remove("<eos>")
        logger.info("Setup train/valid/test Dataloader!")
|
||||
|
||||
|
||||
class DeepSpeech2Tester(DeepSpeech2Trainer):
    """Runs decoding/evaluation on the test set and exports the model."""

    def __init__(self, config, args):
        # Created before the base-class initializer runs, as in the original
        # ordering.
        self._text_featurizer = TextFeaturizer(
            unit_type=config.unit_type, vocab=None)
        super().__init__(config, args)

    def ordid2token(self, texts, texts_len):
        """ ord() id to chr() chr

        Each row of ``texts`` is truncated to its length from ``texts_len``
        and every id is mapped through ``chr`` to rebuild the string.

        Returns:
            List of decoded strings, one per row.
        """
        trans = []
        for text, n in zip(texts, texts_len):
            n = n.numpy().item()
            ids = text[:n]
            trans.append(''.join(chr(i) for i in ids))
        return trans

    def compute_metrics(self,
                        utts,
                        audio,
                        audio_len,
                        texts,
                        texts_len,
                        fout=None):
        """Decode one batch and accumulate its error statistics.

        Args:
            utts: Utterance ids of the batch.
            audio: Batched audio features passed to the decoder.
            audio_len: Length of each item in ``audio``.
            texts: Reference transcripts encoded as ``ord()`` ids.
            texts_len: Length of each reference transcript.
            fout: Optional writable file; when given, one
                ``"<utt> <result>"`` line is written per utterance.

        Returns:
            Dict with keys ``errors_sum``, ``len_refs``, ``num_ins``,
            ``error_rate`` and ``error_rate_type``.
        """
        cfg = self.config.decode
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        use_cer = cfg.error_rate_type == 'cer'
        errors_func = error_rate.char_errors if use_cer else error_rate.word_errors
        error_rate_func = error_rate.cer if use_cer else error_rate.wer

        target_transcripts = self.ordid2token(texts, texts_len)

        result_transcripts = self.compute_result_transcripts(audio, audio_len)

        for utt, target, result in zip(utts, target_transcripts,
                                       result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
            len_refs += len_ref
            num_ins += 1
            if fout:
                fout.write(utt + " " + result + "\n")
            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
                        (target, result))
            logger.info("Current error rate [%s] = %f" %
                        (cfg.error_rate_type, error_rate_func(target, result)))

        return dict(
            errors_sum=errors_sum,
            len_refs=len_refs,
            num_ins=num_ins,
            error_rate=errors_sum / len_refs,
            error_rate_type=cfg.error_rate_type)

    def compute_result_transcripts(self, audio, audio_len):
        """Decode a batch with the model and detokenize every hypothesis."""
        result_transcripts = self.model.decode(audio, audio_len)

        return [
            self._text_featurizer.detokenize(item)
            for item in result_transcripts
        ]

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def test(self):
        """Decode the whole test set, log error rates and write the results.

        Hypotheses are written to ``args.result_file``. The decoder is
        initialized before decoding and always released afterwards, even when
        decoding fails part-way.
        """
        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
        self.model.eval()
        error_rate_type = None
        errors_sum, len_refs, num_ins = 0.0, 0, 0

        # Initialized the decoder in model
        decode_cfg = self.config.decode
        vocab_list = self.test_loader.collate_fn.vocab_list
        decode_batch_size = self.test_loader.batch_size
        self.model.decoder.init_decoder(
            decode_batch_size, vocab_list, decode_cfg.decoding_method,
            decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
            decode_cfg.beam_size, decode_cfg.cutoff_prob,
            decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)

        try:
            # Explicit UTF-8 so non-ASCII transcripts are written the same
            # way on every platform.
            with open(self.args.result_file, 'w', encoding='utf-8') as fout:
                for batch in self.test_loader:
                    utts, audio, audio_len, texts, texts_len = batch
                    metrics = self.compute_metrics(utts, audio, audio_len,
                                                   texts, texts_len, fout)
                    errors_sum += metrics['errors_sum']
                    len_refs += metrics['len_refs']
                    num_ins += metrics['num_ins']
                    error_rate_type = metrics['error_rate_type']
                    logger.info("Error rate [%s] (%d/?) = %f" %
                                (error_rate_type, num_ins,
                                 errors_sum / len_refs))

            # logging
            msg = "Test: "
            msg += "epoch: {}, ".format(self.epoch)
            msg += "step: {}, ".format(self.iteration)
            msg += "Final error rate [%s] (%d/%d) = %f" % (
                error_rate_type, num_ins, num_ins, errors_sum / len_refs)
            logger.info(msg)
        finally:
            # Release the decoder even if decoding or logging raised.
            self.model.decoder.del_decoder()

    def run_test(self):
        """Restore (or initialize) the model, then run ``test``.

        A ``KeyboardInterrupt`` during testing terminates the process with
        exit status -1.
        """
        self.resume_or_scratch()
        try:
            self.test()
        except KeyboardInterrupt:
            # Equivalent to exit(-1) without relying on the site-provided
            # `exit` builtin.
            raise SystemExit(-1)

    def export(self):
        """Export the inference model to ``args.export_path``.

        Raises:
            Exception: If ``args.model_type`` is neither ``'offline'`` nor
                ``'online'``.
        """
        if self.args.model_type == 'offline':
            infer_model = DeepSpeech2InferModel.from_pretrained(
                self.test_loader, self.config, self.args.checkpoint_path)
        elif self.args.model_type == 'online':
            infer_model = DeepSpeech2InferModelOnline.from_pretrained(
                self.test_loader, self.config, self.args.checkpoint_path)
        else:
            raise Exception("wrong model type")

        infer_model.eval()
        static_model = infer_model.export()
        logger.info(f"Export code: {static_model.forward.code}")
        paddle.jit.save(static_model, self.args.export_path)

    def run_export(self):
        """Run ``export``; a ``KeyboardInterrupt`` exits with status -1."""
        try:
            self.export()
        except KeyboardInterrupt:
            # Equivalent to exit(-1) without relying on the site-provided
            # `exit` builtin.
            raise SystemExit(-1)
|
Loading…
Reference in new issue