[ASR] Support Hubert, fintuned on the librispeech dataset (#3088)
* librispeech hubert, test=asr * librispeech hubert, test=asr * hubert decode * review * copyright, notes, example related * hubert cli * pre-commit format * fix conflicts * fix conflicts * doc related * doc and train config * librispeech.py * support hubert clipull/3221/head
parent
8205343c65
commit
12e3e76092
@ -0,0 +1,9 @@
|
||||
# LibriSpeech
|
||||
|
||||
## hubertASR
|
||||
Fintuning on train-clean-100
|
||||
train: Epoch 3, 1*V100-32G, batchsize: 4, accum_grad: 8
|
||||
|
||||
| Model | Params | Config | Augmentation| Test set | Decode method | WER |
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
| hubertASR | 326.16M | conf/hubertASR.yaml | spec_aug | test-clean | greedy search | 0.05868 |
|
@ -0,0 +1,77 @@
|
||||
{
|
||||
"_name_or_path": "facebook/hubert-large-ll60k",
|
||||
"activation_dropout": 0.0,
|
||||
"apply_spec_augment": true,
|
||||
"architectures": [
|
||||
"HubertModel"
|
||||
],
|
||||
"attention_dropout": 0.1,
|
||||
"bos_token_id": 1,
|
||||
"conv_bias": true,
|
||||
"conv_dim": [
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"conv_kernel": [
|
||||
10,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"conv_stride": [
|
||||
5,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"ctc_loss_reduction": "sum",
|
||||
"ctc_zero_infinity": false,
|
||||
"do_stable_layer_norm": true,
|
||||
"eos_token_id": 2,
|
||||
"feat_extract_activation": "gelu",
|
||||
"feat_extract_dropout": 0.0,
|
||||
"feat_extract_norm": "layer",
|
||||
"feat_proj_dropout": 0.1,
|
||||
"final_dropout": 0.0,
|
||||
"gradient_checkpointing": false,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout": 0.1,
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"layerdrop": 0.1,
|
||||
"mask_channel_length": 10,
|
||||
"mask_channel_min_space": 1,
|
||||
"mask_channel_other": 0.0,
|
||||
"mask_channel_prob": 0.0,
|
||||
"mask_channel_selection": "static",
|
||||
"mask_feature_length": 10,
|
||||
"mask_feature_prob": 0.0,
|
||||
"mask_time_length": 10,
|
||||
"mask_time_min_space": 1,
|
||||
"mask_time_other": 0.0,
|
||||
"mask_time_prob": 0.075,
|
||||
"mask_time_selection": "static",
|
||||
"model_type": "hubert",
|
||||
"num_attention_heads": 16,
|
||||
"num_conv_pos_embedding_groups": 16,
|
||||
"num_conv_pos_embeddings": 128,
|
||||
"num_feat_extract_layers": 7,
|
||||
"num_hidden_layers": 24,
|
||||
"pad_token_id": 0,
|
||||
"transformers_version": "4.10.0.dev0",
|
||||
"vocab_size": 32,
|
||||
"tokenizer_class": "Wav2Vec2CTCTokenizer"
|
||||
}
|
@ -0,0 +1,142 @@
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
freeze_hubert: False
|
||||
normalize_wav: True
|
||||
output_norm: True
|
||||
init_type: kaiming_uniform # !Warning: need to convergence
|
||||
enc:
|
||||
input_shape: 1024
|
||||
dnn_blocks: 2
|
||||
dnn_neurons: 1024
|
||||
activation: True
|
||||
ctc:
|
||||
enc_n_units: 1024
|
||||
blank_id: 0
|
||||
dropout_rate: 0.0
|
||||
hubert_params_path: "exp/hubert/hubert-large-lv60.pdparams"
|
||||
|
||||
|
||||
task_cfg:
|
||||
label_rate: 50.0
|
||||
sample_rate: 16000
|
||||
normalize: True
|
||||
enable_padding: False
|
||||
max_keep_size: None
|
||||
max_sample_size: 250000
|
||||
min_sample_size: 32000
|
||||
single_target: False
|
||||
random_crop: True
|
||||
pad_audio: False
|
||||
|
||||
model_cfg:
|
||||
dropout_input: 0.0
|
||||
final_dropout: 0.0
|
||||
dropout: 0.0
|
||||
attention_dropout: 0.0
|
||||
activation_dropout: 0.1
|
||||
apply_mask: True
|
||||
mask_length: 10
|
||||
mask_prob: 0.5
|
||||
mask_selection: static
|
||||
mask_other: 0.0
|
||||
no_mask_overlap: False
|
||||
mask_channel_length: 64
|
||||
mask_channel_prob: 0.25
|
||||
mask_channel_selection: static
|
||||
mask_channel_other: 0.0
|
||||
no_mask_channel_overlap: False
|
||||
feature_grad_mult: 0.0
|
||||
layerdrop: 0.1
|
||||
normalize: True
|
||||
fp16: True
|
||||
label_rate: 50
|
||||
extractor_mode: layer_norm
|
||||
encoder_layers: 24
|
||||
encoder_embed_dim: 1024
|
||||
encoder_ffn_embed_dim: 4096
|
||||
encoder_attention_heads: 16
|
||||
activation_fn: gelu
|
||||
encoder_layerdrop: 0.1
|
||||
dropout_features: 0.0
|
||||
final_dim: 768
|
||||
untie_final_proj: True
|
||||
layer_norm_first: True
|
||||
conv_feature_layers: "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
|
||||
conv_bias: False
|
||||
logit_temp: 0.1
|
||||
target_glu: False
|
||||
mask_min_space: 1
|
||||
mask_channel_min_space: 1
|
||||
conv_pos: 128
|
||||
conv_pos_groups: 16
|
||||
latent_temp: [2.0, 0.5, 0.999995]
|
||||
skip_masked: False
|
||||
skip_nomask: True
|
||||
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
train_manifest: data/manifest.train-clean-100
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test-clean
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
vocab_filepath: data/lang_char/vocab.txt
|
||||
unit_type: char
|
||||
mean_std_filepath: ""
|
||||
preprocess_config: conf/preprocess.yaml
|
||||
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for other epochs
|
||||
batch_size: 4 # Different batch_size may cause large differences in results
|
||||
maxlen_in: 1500 # if input length > maxlen-in batchsize is automatically reduced
|
||||
maxlen_out: 150 # if output length > maxlen-out batchsize is automatically reduced
|
||||
minibatches: 0 # for debug
|
||||
batch_count: auto
|
||||
batch_bins: 0
|
||||
batch_frames_in: 0
|
||||
batch_frames_out: 0
|
||||
batch_frames_inout: 0
|
||||
num_workers: 0
|
||||
subsampling_factor: 1
|
||||
num_encs: 1
|
||||
dist_sampler: True
|
||||
shortest_first: True
|
||||
return_lens_rate: True
|
||||
|
||||
############################################
|
||||
# Data Augmentation #
|
||||
############################################
|
||||
audio_augment: # for raw audio
|
||||
sample_rate: 16000
|
||||
speeds: [95, 100, 105]
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 3
|
||||
accum_grad: 8
|
||||
global_grad_clip: 5.0
|
||||
model_optim: adadelta
|
||||
model_optim_conf:
|
||||
lr: 1.0
|
||||
epsilon: 1.0e-6
|
||||
rho: 0.95
|
||||
model_scheduler: constantlr
|
||||
model_scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
hubert_optim: adadelta
|
||||
hubert_optim_conf:
|
||||
lr: 0.95
|
||||
epsilon: 1.0e-6
|
||||
rho: 0.95
|
||||
hubert_scheduler: constantlr
|
||||
hubert_scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 1
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
@ -0,0 +1,3 @@
|
||||
process:
|
||||
# use raw audio
|
||||
- type: wav_process
|
@ -0,0 +1,9 @@
|
||||
{
|
||||
"do_normalize": true,
|
||||
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
||||
"feature_size": 1,
|
||||
"padding_side": "right",
|
||||
"padding_value": 0,
|
||||
"return_attention_mask": true,
|
||||
"sampling_rate": 16000
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
decode_batch_size: 1
|
||||
error_rate_type: wer
|
||||
decoding_method: ctc_greedy_search # 'ctc_greedy_search', 'ctc_prefix_beam_search'
|
||||
beam_size: 10
|
@ -0,0 +1,110 @@
|
||||
#!/bin/bash
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
unit_type=char
|
||||
dict_dir=data/lang_char
|
||||
|
||||
source ${MAIN_ROOT}/utils/parse_options.sh
|
||||
|
||||
mkdir -p data
|
||||
mkdir -p ${dict_dir}
|
||||
TARGET_DIR=${MAIN_ROOT}/dataset
|
||||
mkdir -p ${TARGET_DIR}
|
||||
|
||||
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||
# download data, generate manifests
|
||||
python3 ${TARGET_DIR}/librispeech/librispeech.py \
|
||||
--manifest_prefix="data/manifest" \
|
||||
--target_dir="${TARGET_DIR}/librispeech" \
|
||||
--full_download="True"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Prepare LibriSpeech failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mv data/manifest.${set} data/manifest.${set}.raw
|
||||
done
|
||||
|
||||
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
|
||||
for set in train-clean-100 train-clean-360 train-other-500; do
|
||||
cat data/manifest.${set}.raw >> data/manifest.train.raw
|
||||
done
|
||||
|
||||
for set in dev-clean dev-other; do
|
||||
cat data/manifest.${set}.raw >> data/manifest.dev.raw
|
||||
done
|
||||
|
||||
for set in test-clean test-other; do
|
||||
cat data/manifest.${set}.raw >> data/manifest.test.raw
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
# compute mean and stddev for normalizer
|
||||
num_workers=$(nproc)
|
||||
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
|
||||
--manifest_path="data/manifest.train.raw" \
|
||||
--num_samples=2000 \
|
||||
--spectrum_type="fbank" \
|
||||
--feat_dim=161 \
|
||||
--delta_delta=false \
|
||||
--sample_rate=16000 \
|
||||
--stride_ms=10 \
|
||||
--window_ms=25 \
|
||||
--use_dB_normalization=False \
|
||||
--num_workers=${num_workers} \
|
||||
--output_path="data/mean_std.json"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Compute mean and stddev failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# build vocabulary
|
||||
python3 ${MAIN_ROOT}/utils/build_vocab.py \
|
||||
--unit_type ${unit_type} \
|
||||
--count_threshold=0 \
|
||||
--vocab_path="${dict_dir}/vocab.txt" \
|
||||
--manifest_paths="data/manifest.train.raw"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Build vocabulary failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# format manifest with tokenids, vocab size
|
||||
for set in train dev test dev-clean dev-other test-clean test-other; do
|
||||
{
|
||||
python3 ${MAIN_ROOT}/utils/format_data.py \
|
||||
--cmvn_path "data/mean_std.json" \
|
||||
--unit_type ${unit_type} \
|
||||
--vocab_path="${dict_dir}/vocab.txt" \
|
||||
--manifest_path="data/manifest.${set}.raw" \
|
||||
--output_path="data/manifest.${set}"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Formt mnaifest.${set} failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
}&
|
||||
done
|
||||
wait
|
||||
fi
|
||||
|
||||
echo "LibriSpeech Data preparation done."
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
mkdir -p exp/hubert
|
||||
echo "Pretrained hubert model download"
|
||||
wget -P exp/hubert https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams
|
||||
fi
|
||||
|
||||
exit 0
|
@ -0,0 +1,83 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
expdir=exp
|
||||
datadir=data
|
||||
|
||||
recog_set="test-clean test-other dev-clean dev-other"
|
||||
recog_set="test-clean"
|
||||
|
||||
config_path=$1
|
||||
decode_config_path=$2
|
||||
ckpt_prefix=$3
|
||||
|
||||
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
|
||||
# download language model
|
||||
#bash local/download_lm_en.sh
|
||||
#if [ $? -ne 0 ]; then
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
python3 utils/format_rsl.py \
|
||||
--origin_ref data/manifest.test-clean.raw \
|
||||
--trans_ref data/manifest.test-clean.text
|
||||
|
||||
|
||||
for type in ctc_greedy_search; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=16
|
||||
python3 -u ${BIN_DIR}/test.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--decode_cfg ${decode_config_path} \
|
||||
--result_file ${ckpt_prefix}.${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decode.decoding_method ${type} \
|
||||
--opts decode.decode_batch_size ${batch_size}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
python3 utils/format_rsl.py \
|
||||
--origin_hyp ${ckpt_prefix}.${type}.rsl \
|
||||
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
|
||||
|
||||
python3 utils/compute-wer.py --char=1 --v=1 \
|
||||
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
|
||||
echo "decoding ${type} done."
|
||||
done
|
||||
|
||||
for type in ctc_prefix_beam_search; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=1
|
||||
python3 -u ${BIN_DIR}/test.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--decode_cfg ${decode_config_path} \
|
||||
--result_file ${ckpt_prefix}.${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decode.decoding_method ${type} \
|
||||
--opts decode.decode_batch_size ${batch_size}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
python3 utils/format_rsl.py \
|
||||
--origin_hyp ${ckpt_prefix}.${type}.rsl \
|
||||
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
|
||||
|
||||
python3 utils/compute-wer.py --char=1 --v=1 \
|
||||
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
|
||||
echo "decoding ${type} done."
|
||||
done
|
||||
|
||||
echo "Finished"
|
||||
|
||||
exit 0
|
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 4 ];then
|
||||
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
decode_config_path=$2
|
||||
ckpt_prefix=$3
|
||||
audio_file=$4
|
||||
|
||||
mkdir -p data
|
||||
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f ${audio_file} ]; then
|
||||
echo "Plase input the right audio_file path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
chunk_mode=false
|
||||
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
|
||||
chunk_mode=true
|
||||
fi
|
||||
|
||||
# download language model
|
||||
#bash local/download_lm_ch.sh
|
||||
#if [ $? -ne 0 ]; then
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
for type in ctc_greedy_search; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=1
|
||||
output_dir=${ckpt_prefix}
|
||||
mkdir -p ${output_dir}
|
||||
python3 -u ${BIN_DIR}/test_wav.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--decode_cfg ${decode_config_path} \
|
||||
--result_file ${output_dir}/${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decode.decoding_method ${type} \
|
||||
--opts decode.decode_batch_size ${batch_size} \
|
||||
--audio_file ${audio_file}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
exit 0
|
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# -lt 2 ] && [ $# -gt 3 ];then
|
||||
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
ckpt_name=$2
|
||||
resume=$3
|
||||
ips=$4
|
||||
|
||||
if [ ! $ips ];then
|
||||
ips_config=
|
||||
else
|
||||
ips_config="--ips="${ips}
|
||||
fi
|
||||
|
||||
mkdir -p exp
|
||||
|
||||
# seed may break model convergence
|
||||
seed=1988
|
||||
if [ ${seed} != 0 ]; then
|
||||
export FLAGS_cudnn_deterministic=True
|
||||
fi
|
||||
|
||||
# export FLAGS_cudnn_exhaustive_search=true
|
||||
# export FLAGS_conv_workspace_size_limit=4000
|
||||
export FLAGS_allocator_strategy=naive_best_fit
|
||||
if [ ${ngpu} == 0 ]; then
|
||||
python3 -u ${BIN_DIR}/train.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--output exp/${ckpt_name} \
|
||||
--seed ${seed} \
|
||||
--resume ${resume}
|
||||
else
|
||||
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--output exp/${ckpt_name} \
|
||||
--seed ${seed} \
|
||||
--resume ${resume}
|
||||
fi
|
||||
|
||||
if [ ${seed} != 0 ]; then
|
||||
unset FLAGS_cudnn_deterministic
|
||||
fi
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -0,0 +1,13 @@
|
||||
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||
|
||||
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
|
||||
export LC_ALL=C
|
||||
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
|
||||
|
||||
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/hubert/bin
|
@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
. ./cmd.sh || exit 1;
|
||||
|
||||
gpus=0
|
||||
stage=0
|
||||
stop_stage=0
|
||||
conf_path=conf/hubertASR.yaml
|
||||
ips= #xx.xx.xx.xx,xx.xx.xx.xx
|
||||
decode_conf_path=conf/tuning/decode.yaml
|
||||
avg_num=1
|
||||
resume= # xx e.g. 30
|
||||
|
||||
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
|
||||
audio_file=data/demo_002_en.wav
|
||||
|
||||
avg_ckpt=avg_${avg_num}
|
||||
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
|
||||
echo "checkpoint name ${ckpt}"
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
# prepare data
|
||||
bash ./local/data.sh || exit -1
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# train model, all `ckpt` under `exp` dir
|
||||
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# avg n best model
|
||||
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# greedy search decoder
|
||||
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# test a single .wav file
|
||||
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
|
||||
fi
|
@ -0,0 +1 @@
|
||||
../../../utils
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,64 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluation for hubert model."""
|
||||
import cProfile
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.s2t.exps.hubert.model import HubertASRTester as Tester
|
||||
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||
from paddlespeech.s2t.utils.utility import print_arguments
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Tester(config, args)
|
||||
with exp.eval():
|
||||
exp.setup()
|
||||
exp.run_test()
|
||||
|
||||
|
||||
def main(config, args):
|
||||
main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = default_argument_parser()
|
||||
# save asr result to
|
||||
parser.add_argument(
|
||||
'--dict-path', type=str, default=None, help='dict path.')
|
||||
parser.add_argument(
|
||||
"--result_file", type=str, help="path of save the asr result")
|
||||
args = parser.parse_args()
|
||||
print_arguments(args, globals())
|
||||
|
||||
# https://yaml.org/type/float.html
|
||||
config = CfgNode(new_allowed=True)
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.decode_cfg:
|
||||
decode_confs = CfgNode(new_allowed=True)
|
||||
decode_confs.merge_from_file(args.decode_cfg)
|
||||
config.decode = decode_confs
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
if args.dump_config:
|
||||
with open(args.dump_config, 'w') as f:
|
||||
print(config, file=f)
|
||||
|
||||
# Setting for profiling
|
||||
pr = cProfile.Profile()
|
||||
pr.runcall(main, config, args)
|
||||
pr.dump_stats('test.profile')
|
@ -0,0 +1,118 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluation for hubert model."""
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import paddle
|
||||
import soundfile
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR
|
||||
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class HubertInfer():
|
||||
def __init__(self, config, args):
|
||||
self.args = args
|
||||
self.config = config
|
||||
self.audio_file = args.audio_file
|
||||
|
||||
self.text_feature = TextFeaturizer(
|
||||
unit_type=config.unit_type, vocab=config.vocab_filepath)
|
||||
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
|
||||
|
||||
# model
|
||||
model_conf = config
|
||||
with UpdateConfig(model_conf):
|
||||
model_conf.output_dim = self.text_feature.vocab_size
|
||||
model = HubertASR.from_config(model_conf)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
# load model
|
||||
params_path = self.args.checkpoint_path + ".pdparams"
|
||||
model_dict = paddle.load(params_path)
|
||||
self.model.set_state_dict(model_dict)
|
||||
|
||||
def run(self):
|
||||
check(args.audio_file)
|
||||
|
||||
with paddle.no_grad():
|
||||
# read
|
||||
audio, _ = soundfile.read(
|
||||
self.audio_file, dtype="int16", always_2d=True)
|
||||
logger.info(f"audio shape: {audio.shape}")
|
||||
|
||||
xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
|
||||
decode_config = self.config.decode
|
||||
result_transcripts, result_tokenids = self.model.decode(
|
||||
xs,
|
||||
text_feature=self.text_feature,
|
||||
decoding_method=decode_config.decoding_method,
|
||||
beam_size=decode_config.beam_size)
|
||||
rsl = result_transcripts[0]
|
||||
utt = Path(self.audio_file).name
|
||||
logger.info(f"hyp: {utt} {rsl}")
|
||||
return rsl
|
||||
|
||||
|
||||
def check(audio_file):
|
||||
if not os.path.isfile(audio_file):
|
||||
print("Please input the right audio file path")
|
||||
sys.exit(-1)
|
||||
|
||||
logger.info("checking the audio file format......")
|
||||
try:
|
||||
sig, sample_rate = soundfile.read(audio_file)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
logger.error(
|
||||
"can not open the wav file, please check the audio file format")
|
||||
sys.exit(-1)
|
||||
logger.info("The sample rate is %d" % sample_rate)
|
||||
assert (sample_rate == 16000)
|
||||
logger.info("The audio file format is right")
|
||||
|
||||
|
||||
def main(config, args):
|
||||
HubertInfer(config, args).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = default_argument_parser()
|
||||
# save asr result to
|
||||
parser.add_argument(
|
||||
"--result_file", type=str, help="path of save the asr result")
|
||||
parser.add_argument(
|
||||
"--audio_file", type=str, help="path of the input audio file")
|
||||
args = parser.parse_args()
|
||||
|
||||
config = CfgNode(new_allowed=True)
|
||||
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.decode_cfg:
|
||||
decode_confs = CfgNode(new_allowed=True)
|
||||
decode_confs.merge_from_file(args.decode_cfg)
|
||||
config.decode = decode_confs
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
main(config, args)
|
@ -0,0 +1,55 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trainer for hubert model."""
|
||||
import cProfile
|
||||
import os
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.s2t.exps.hubert.model import HubertASRTrainer as Trainer
|
||||
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||
from paddlespeech.s2t.utils.utility import print_arguments
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Trainer(config, args)
|
||||
exp.setup()
|
||||
exp.run()
|
||||
|
||||
|
||||
def main(config, args):
|
||||
main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = default_argument_parser()
|
||||
parser.add_argument(
|
||||
'--resume', type=str, default="", nargs="?", help='resume ckpt path.')
|
||||
args = parser.parse_args()
|
||||
print_arguments(args, globals())
|
||||
# https://yaml.org/type/float.html
|
||||
config = CfgNode(new_allowed=True)
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
if args.dump_config:
|
||||
with open(args.dump_config, 'w') as f:
|
||||
print(config, file=f)
|
||||
|
||||
# Setting for profiling
|
||||
pr = cProfile.Profile()
|
||||
pr.runcall(main, config, args)
|
||||
pr.dump_stats(os.path.join(args.output, 'train.profile'))
|
@ -0,0 +1,918 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains hubert model."""
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from contextlib import nullcontext
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import paddle
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from paddle import distributed as dist
|
||||
from paddlenlp.transformers import AutoTokenizer
|
||||
|
||||
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
|
||||
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
|
||||
from paddlespeech.s2t.io.speechbrain import data_pipeline
|
||||
from paddlespeech.s2t.io.speechbrain import dataio
|
||||
from paddlespeech.s2t.io.speechbrain import dataset
|
||||
from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
|
||||
from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR
|
||||
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
|
||||
from paddlespeech.s2t.training.optimizer import OptimizerFactory
|
||||
from paddlespeech.s2t.training.reporter import ObsScope
|
||||
from paddlespeech.s2t.training.reporter import report
|
||||
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
|
||||
from paddlespeech.s2t.training.timer import Timer
|
||||
from paddlespeech.s2t.training.trainer import Trainer
|
||||
from paddlespeech.s2t.utils import error_rate
|
||||
from paddlespeech.s2t.utils import layer_tools
|
||||
from paddlespeech.s2t.utils import mp_tools
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
# Todo: change this when paddle supports this api
|
||||
def clip_grad_norm_(
|
||||
parameters,
|
||||
max_norm,
|
||||
norm_type=2.0,
|
||||
error_if_nonfinite=False, ):
|
||||
r"""Clips gradient norm of the iteratable parameters.
|
||||
|
||||
Norms are calculated together on all gradients, just as they are
|
||||
connected into one vector. The gradient will be modified in place.
|
||||
|
||||
This API can only run in dynamic graph mode, not static graph mode.
|
||||
|
||||
Args:
|
||||
parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor
|
||||
that will be normalized gradients
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be `inf` for
|
||||
infinity norm.
|
||||
error_if_nonfinite (bool): if True, throw an error if the total
|
||||
norm of the gradients from :attr:`parameters` is `nan`,
|
||||
`inf`, or `-inf`.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameter gradients (treated as a single vector).
|
||||
Example:
|
||||
.. code-block:: python
|
||||
import paddle
|
||||
|
||||
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
|
||||
max_norm = float(5.0)
|
||||
linear = paddle.nn.Linear(in_features=10, out_features=10)
|
||||
out = linear(x)
|
||||
loss = paddle.mean(out)
|
||||
loss.backward()
|
||||
|
||||
paddle.nn.utils.clip_grad_norm_(linear.parameters(), max_norm)
|
||||
|
||||
sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
|
||||
sdg.step()
|
||||
"""
|
||||
if not paddle.in_dynamic_mode():
|
||||
raise RuntimeError('this API can only run in dynamic mode.')
|
||||
|
||||
if isinstance(parameters, paddle.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
support_norm_type = [float("inf"), 0, 1, 2]
|
||||
if norm_type not in support_norm_type:
|
||||
raise ValueError(f'norm_type only support {support_norm_type}')
|
||||
|
||||
grads = [p.grad for p in parameters if p.grad is not None]
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
if len(grads) == 0:
|
||||
return paddle.to_tensor(0.0)
|
||||
if norm_type == float("inf"):
|
||||
norms = [g.detach().abs().max() for g in grads]
|
||||
total_norm = (norms[0]
|
||||
if len(norms) == 1 else paddle.max(paddle.stack(norms)))
|
||||
else:
|
||||
total_norm = paddle.linalg.norm(
|
||||
paddle.stack(
|
||||
[paddle.linalg.norm(g.detach(), norm_type) for g in grads]),
|
||||
norm_type, )
|
||||
|
||||
if error_if_nonfinite and paddle.logical_or(total_norm.isnan(),
|
||||
total_norm.isinf()):
|
||||
raise RuntimeError(
|
||||
f'The total norm of {norm_type} order of the gradients from '
|
||||
'`parameters` is non-finite, so it cannot be clipped. In any case, '
|
||||
'disable this error and scale the gradient by non-finite norm, '
|
||||
'set `error_if_nonfinite=False`')
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
# Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this
|
||||
# avoids the `if clip_coef < 1:` condition.
|
||||
clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
|
||||
with paddle.no_grad():
|
||||
for _, p in enumerate(parameters):
|
||||
g = p.grad
|
||||
if g is not None:
|
||||
p.grad = paddle.multiply(x=g, y=clip_coef_clamped)
|
||||
return total_norm
|
||||
|
||||
|
||||
class HubertASRTrainer(Trainer):
|
||||
def __init__(self, config, args):
|
||||
super().__init__(config, args)
|
||||
self.avg_train_loss = 0.0
|
||||
self.loss_isfinite = True # while flag is 'False', loss in Nan or inf, and can not be avg
|
||||
self.use_sb = True # whether use speech brain dataloader
|
||||
|
||||
def update_average(self, batch_index, loss):
|
||||
"""Update running average of the loss.
|
||||
Arguments
|
||||
---------
|
||||
batch_index : int
|
||||
current batch index
|
||||
loss : paddle.tensor
|
||||
detached loss, a single float value.
|
||||
"""
|
||||
if math.isfinite(loss):
|
||||
self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
|
||||
self.avg_train_loss += loss / (batch_index + 1)
|
||||
else:
|
||||
self.loss_isfinite = False
|
||||
logger.info('loss:{} in Nan or inf, error'.format(loss))
|
||||
|
||||
def before_train(self):
|
||||
from_scratch = self.resume_or_scratch()
|
||||
if from_scratch:
|
||||
# scratch: save init model, i.e. 0 epoch
|
||||
self.save(tag='init', infos=None)
|
||||
else:
|
||||
# resume: train next_epoch and next_iteration
|
||||
self.epoch += 1
|
||||
logger.info(
|
||||
f"Resume train: epoch {self.epoch }, step {self.iteration}!")
|
||||
|
||||
self.maybe_batch_sampler_step()
|
||||
|
||||
def train_batch(self, batch_index, batch, msg):
|
||||
train_conf = self.config
|
||||
start = time.time()
|
||||
|
||||
# forward
|
||||
## sb data pipeline
|
||||
if self.use_sb:
|
||||
wav, wavs_lens_rate = batch['sig']
|
||||
target, target_lens_rate = batch['tokens']
|
||||
target_lens = (target_lens_rate *
|
||||
target.shape[1]).round().astype(paddle.int64)
|
||||
else:
|
||||
utt, wav, wavs_lens, target, target_lens = batch
|
||||
wavs_lens_rate = wavs_lens / wav.shape[1]
|
||||
wav = wav[:, :, 0]
|
||||
logger.info('training utt ids: {}'.format(utt))
|
||||
if hasattr(train_conf, 'audio_augment'):
|
||||
wav = self.speech_augmentation(wav, wavs_lens_rate)
|
||||
|
||||
loss = self.model(wav, wavs_lens_rate, target, target_lens)
|
||||
|
||||
# loss div by `batch_size * accum_grad`
|
||||
loss /= train_conf.accum_grad
|
||||
# update self.avg_train_loss
|
||||
self.update_average(batch_index, float(loss))
|
||||
|
||||
# loss backward
|
||||
if (batch_index + 1) % train_conf.accum_grad != 0:
|
||||
# Disable gradient synchronizations across DDP processes.
|
||||
# Within this context, gradients will be accumulated on module
|
||||
# variables, which will later be synchronized.
|
||||
# When using cpu w/o DDP, model does not have `no_sync`
|
||||
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
|
||||
self.parallel) else nullcontext
|
||||
else:
|
||||
# Used for single gpu training and DDP gradient synchronization
|
||||
# processes.
|
||||
context = nullcontext
|
||||
with context():
|
||||
loss.backward()
|
||||
|
||||
layer_tools.print_grads(self.model, print_func=None)
|
||||
|
||||
# optimizer step old
|
||||
if (batch_index + 1) % train_conf.accum_grad == 0:
|
||||
#do global grad clip
|
||||
if train_conf.global_grad_clip != 0:
|
||||
clip_grad_norm_(self.model.parameters(),
|
||||
train_conf.global_grad_clip)
|
||||
self.model_optimizer.step()
|
||||
self.model_optimizer.clear_grad()
|
||||
if not train_conf.freeze_hubert:
|
||||
self.hubert_optimizer.step()
|
||||
self.hubert_optimizer.clear_grad()
|
||||
if self.config.model_scheduler != 'newbobscheduler':
|
||||
self.model_lr_scheduler.step()
|
||||
if self.config.hubert_scheduler != 'newbobscheduler':
|
||||
if not train_conf.freeze_hubert:
|
||||
self.hubert_lr_scheduler.step()
|
||||
self.iteration += 1
|
||||
|
||||
losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
|
||||
iteration_time = time.time() - start
|
||||
for k, v in losses_np.items():
|
||||
report(k, v)
|
||||
report("loss_whitoutavg", float(loss))
|
||||
report("batch_size", self.config.batch_size)
|
||||
report("accum", train_conf.accum_grad)
|
||||
report("step_cost", iteration_time)
|
||||
|
||||
if (batch_index + 1) % train_conf.accum_grad == 0:
|
||||
if dist.get_rank() == 0 and self.visualizer:
|
||||
losses_np_v = losses_np.copy()
|
||||
losses_np_v.update({
|
||||
"model_lr": self.model_lr_scheduler(),
|
||||
"hubert_lr": self.hubert_lr_scheduler()
|
||||
})
|
||||
for key, val in losses_np_v.items():
|
||||
self.visualizer.add_scalar(
|
||||
tag='train/' + key, value=val, step=self.iteration - 1)
|
||||
|
||||
@paddle.no_grad()
|
||||
def valid(self):
|
||||
self.model.eval()
|
||||
if not self.use_streamdata:
|
||||
logger.info(
|
||||
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
|
||||
valid_losses = {}
|
||||
step = 0
|
||||
total_loss = 0.0
|
||||
num_seen_utts = 1 # use update_average and no need for num_seen_utts here
|
||||
for i, batch in enumerate(self.valid_loader):
|
||||
if self.use_sb:
|
||||
wav, wavs_lens_rate = batch['sig']
|
||||
target, target_lens_rate = batch['tokens']
|
||||
target_lens = (target_lens_rate *
|
||||
target.shape[1]).round().astype(paddle.int64)
|
||||
else:
|
||||
utt, wav, wavs_lens, target, target_lens = batch
|
||||
wavs_lens_rate = wavs_lens / wav.shape[1]
|
||||
wav = wav[:, :, 0]
|
||||
|
||||
loss = self.model(wav, wavs_lens_rate, target, target_lens)
|
||||
# use update_average
|
||||
total_loss -= total_loss / (step + 1)
|
||||
total_loss += loss / (step + 1)
|
||||
|
||||
if math.isfinite(float(loss)):
|
||||
step += 1
|
||||
valid_losses['val_loss'] = float(loss)
|
||||
else:
|
||||
logger.info('loss:{} in Nan or inf, error'.format(float(loss)))
|
||||
|
||||
if (i + 1) % self.config.log_interval == 0:
|
||||
valid_losses['val_history_loss'] = float(total_loss)
|
||||
|
||||
# logging
|
||||
msg = f"Valid: Rank: {dist.get_rank()}, "
|
||||
msg += "epoch: {}, ".format(self.epoch)
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
if not self.use_streamdata:
|
||||
msg += "batch: {}/{}, ".format(i + 1,
|
||||
len(self.valid_loader))
|
||||
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in valid_losses.items())
|
||||
logger.info(msg)
|
||||
|
||||
logger.info(
|
||||
'Rank {} Val info val_loss {}'.format(dist.get_rank(), total_loss))
|
||||
return total_loss, num_seen_utts
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def save(self, tag=None, infos: dict=None):
|
||||
"""Save checkpoint (model parameters and optimizer states).
|
||||
|
||||
Args:
|
||||
tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None.
|
||||
infos (dict, optional): meta data to save. Defaults to None.
|
||||
"""
|
||||
|
||||
infos = infos if infos else dict()
|
||||
infos.update({
|
||||
"epoch": self.epoch,
|
||||
"model_lr": self.model_optimizer.get_lr(),
|
||||
"hubert_lr": self.hubert_optimizer.get_lr()
|
||||
})
|
||||
|
||||
checkpoint_path = os.path.join(
|
||||
self.checkpoint_dir,
|
||||
"{}".format(self.iteration if tag is None else tag))
|
||||
|
||||
model_dict = self.model.state_dict()
|
||||
params_path = checkpoint_path + ".pdparams"
|
||||
paddle.save(model_dict, params_path)
|
||||
logger.info("Saved model to {}".format(params_path))
|
||||
|
||||
model_opt_dict = self.model_optimizer.state_dict()
|
||||
hubert_opt_dict = self.hubert_optimizer.state_dict()
|
||||
|
||||
opt_dict = {'model': model_opt_dict, 'hubert': hubert_opt_dict}
|
||||
|
||||
optimizer_path = checkpoint_path + ".pdopt"
|
||||
paddle.save(opt_dict, optimizer_path)
|
||||
logger.info("Saved optimzier state to {}".format(optimizer_path))
|
||||
|
||||
scheduler_dict = {}
|
||||
|
||||
if self.config.model_scheduler == 'newbobscheduler':
|
||||
scheduler_dict['model'] = self.model_lr_scheduler.save()
|
||||
if self.config.hubert_scheduler == 'newbobscheduler':
|
||||
scheduler_dict['hubert'] = self.hubert_lr_scheduler.save()
|
||||
if scheduler_dict:
|
||||
scheduler_path = checkpoint_path + ".pdlrs"
|
||||
paddle.save(scheduler_dict, scheduler_path)
|
||||
logger.info("Saved scheduler state to {}".format(scheduler_path))
|
||||
info_path = re.sub('.pdparams$', '.json', params_path)
|
||||
infos = {} if infos is None else infos
|
||||
with open(info_path, 'w', encoding='utf8') as fout:
|
||||
data = json.dumps(infos)
|
||||
fout.write(data)
|
||||
|
||||
def resume_or_scratch(self):
|
||||
"""Resume from latest checkpoint at checkpoints in the output
|
||||
directory or load a specified checkpoint.
|
||||
|
||||
If ``args.checkpoint_path`` is not None, load the checkpoint, else
|
||||
resume training.
|
||||
"""
|
||||
scratch = None
|
||||
if self.args.resume:
|
||||
# just restore ckpt
|
||||
# lr will resotre from optimizer ckpt
|
||||
resume_json_path = os.path.join(self.checkpoint_dir,
|
||||
self.args.resume + '.json')
|
||||
with open(resume_json_path, 'r', encoding='utf8') as f:
|
||||
resume_json = json.load(f)
|
||||
self.iteration = 0
|
||||
self.epoch = resume_json["epoch"]
|
||||
|
||||
# resotre model from *.pdparams
|
||||
params_path = os.path.join(self.checkpoint_dir,
|
||||
"{}".format(self.epoch)) + '.pdparams'
|
||||
model_dict = paddle.load(params_path)
|
||||
self.model.set_state_dict(model_dict)
|
||||
|
||||
# resotre optimizer from *.pdopt
|
||||
optimizer_path = os.path.join(self.checkpoint_dir,
|
||||
"{}".format(self.epoch)) + '.pdopt'
|
||||
optimizer_dict = paddle.load(optimizer_path)
|
||||
self.model_optimizer.set_state_dict(optimizer_dict['model'])
|
||||
self.hubert_optimizer.set_state_dict(optimizer_dict['hubert'])
|
||||
|
||||
# resotre lr_scheduler from *.pdlrs
|
||||
scheduler_path = os.path.join(self.checkpoint_dir,
|
||||
"{}".format(self.epoch)) + '.pdlrs'
|
||||
if os.path.isfile(os.path.join(scheduler_path)):
|
||||
scheduler_dict = paddle.load(scheduler_path)
|
||||
if self.config.model_scheduler == 'newbobscheduler':
|
||||
self.model_lr_scheduler.load(scheduler_dict['model'])
|
||||
if self.config.hubert_scheduler == 'newbobscheduler':
|
||||
self.hubert_lr_scheduler.load(scheduler_dict['hubert'])
|
||||
logger.info(
|
||||
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
|
||||
scratch = False
|
||||
else:
|
||||
self.iteration = 0
|
||||
self.epoch = 0
|
||||
scratch = True
|
||||
logger.info("Init from scratch!")
|
||||
return scratch
|
||||
|
||||
def do_train(self):
|
||||
"""The training process control by step."""
|
||||
# !!!IMPORTANT!!!
|
||||
# Try to export the model by script, if fails, we should refine
|
||||
# the code to satisfy the script export requirements
|
||||
# script_model = paddle.jit.to_static(self.model)
|
||||
# script_model_path = str(self.checkpoint_dir / 'init')
|
||||
# paddle.jit.save(script_model, script_model_path)
|
||||
|
||||
self.before_train()
|
||||
if not self.use_streamdata:
|
||||
logger.info(
|
||||
f"Train Total Examples: {len(self.train_loader.dataset)}")
|
||||
while self.epoch < self.config.n_epoch:
|
||||
with Timer("Epoch-Train Time Cost: {}"):
|
||||
self.model.train()
|
||||
try:
|
||||
data_start_time = time.time()
|
||||
for batch_index, batch in enumerate(self.train_loader):
|
||||
dataload_time = time.time() - data_start_time
|
||||
msg = "Train:"
|
||||
observation = OrderedDict()
|
||||
with ObsScope(observation):
|
||||
report("Rank", dist.get_rank())
|
||||
report("epoch", self.epoch)
|
||||
report('step', self.iteration)
|
||||
report("model_lr", self.model_optimizer.get_lr())
|
||||
report("hubert_lr", self.hubert_optimizer.get_lr())
|
||||
self.train_batch(batch_index, batch, msg)
|
||||
self.after_train_batch()
|
||||
report('iter', batch_index + 1)
|
||||
if not self.use_streamdata:
|
||||
report('total', len(self.train_loader))
|
||||
report('reader_cost', dataload_time)
|
||||
observation['batch_cost'] = observation[
|
||||
'reader_cost'] + observation['step_cost']
|
||||
observation['samples'] = observation['batch_size']
|
||||
observation['ips,samples/s'] = observation[
|
||||
'batch_size'] / observation['batch_cost']
|
||||
for k, v in observation.items():
|
||||
msg += f" {k.split(',')[0]}: "
|
||||
msg += f"{v:>.8f}" if isinstance(v,
|
||||
float) else f"{v}"
|
||||
msg += f" {k.split(',')[1]}" if len(
|
||||
k.split(',')) == 2 else ""
|
||||
msg += ","
|
||||
msg = msg[:-1] # remove the last ","
|
||||
if (batch_index + 1) % self.config.log_interval == 0:
|
||||
logger.info(msg)
|
||||
data_start_time = time.time()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
with Timer("Eval Time Cost: {}"):
|
||||
total_loss, num_seen_utts = self.valid()
|
||||
if dist.get_world_size() > 1:
|
||||
num_seen_utts = paddle.to_tensor(num_seen_utts)
|
||||
dist.all_reduce(num_seen_utts)
|
||||
total_loss = paddle.to_tensor(total_loss)
|
||||
dist.all_reduce(total_loss)
|
||||
cv_loss = total_loss / num_seen_utts
|
||||
cv_loss = float(cv_loss)
|
||||
else:
|
||||
cv_loss = float(total_loss)
|
||||
logger.info(
|
||||
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
|
||||
if self.visualizer:
|
||||
self.visualizer.add_scalar(
|
||||
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
|
||||
self.visualizer.add_scalar(
|
||||
tag='eval/model_lr',
|
||||
value=self.model_lr_scheduler(),
|
||||
step=self.epoch)
|
||||
self.visualizer.add_scalar(
|
||||
tag='eval/hubert_lr',
|
||||
value=self.hubert_lr_scheduler(),
|
||||
step=self.epoch)
|
||||
|
||||
if self.config.model_scheduler == 'newbobscheduler':
|
||||
self.model_lr_scheduler.step(cv_loss)
|
||||
if self.config.hubert_scheduler == 'newbobscheduler':
|
||||
if not self.config.freeze_hubert:
|
||||
self.hubert_lr_scheduler.step(cv_loss)
|
||||
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
||||
self.avg_train_loss = 0.0
|
||||
self.new_epoch()
|
||||
|
||||
def dataio_prepare(self, hparams):
|
||||
"""This function prepares the datasets to be used in the brain class.
|
||||
It also defines the data processing pipeline through user-defined functions."""
|
||||
data_folder = hparams["data_folder"]
|
||||
|
||||
train_data = dataset.DynamicItemDataset.from_csv(
|
||||
csv_path=hparams["train_data"],
|
||||
replacements={"data_root": data_folder}, )
|
||||
|
||||
if hparams["sorting"] == "ascending":
|
||||
# we sort training data to speed up training and get better results.
|
||||
train_data = train_data.filtered_sorted(sort_key="duration")
|
||||
# when sorting do not shuffle in dataloader ! otherwise is pointless
|
||||
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||
|
||||
elif hparams["sorting"] == "descending":
|
||||
train_data = train_data.filtered_sorted(
|
||||
sort_key="duration", reverse=True)
|
||||
# when sorting do not shuffle in dataloader ! otherwise is pointless
|
||||
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||
|
||||
elif hparams["sorting"] == "random":
|
||||
pass
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"sorting must be random, ascending or descending")
|
||||
|
||||
valid_data = dataset.DynamicItemDataset.from_csv(
|
||||
csv_path=hparams["valid_data"],
|
||||
replacements={"data_root": data_folder}, )
|
||||
valid_data = valid_data.filtered_sorted(sort_key="duration")
|
||||
|
||||
test_data = dataset.DynamicItemDataset.from_csv(
|
||||
csv_path=hparams["test_data"],
|
||||
replacements={"data_root": data_folder}, )
|
||||
test_data = test_data.filtered_sorted(sort_key="duration")
|
||||
|
||||
datasets = [train_data, valid_data, test_data]
|
||||
|
||||
# Defining tokenizer and loading it
|
||||
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# 2. Define audio pipeline:
|
||||
@data_pipeline.takes("wav")
|
||||
@data_pipeline.provides("sig")
|
||||
def audio_pipeline(wav):
|
||||
sig = dataio.read_audio(wav)
|
||||
return sig
|
||||
|
||||
dataset.add_dynamic_item(datasets, audio_pipeline)
|
||||
|
||||
# 3. Define text pipeline:
|
||||
@data_pipeline.takes("transcript")
|
||||
@data_pipeline.provides("wrd", "tokens_list", "tokens")
|
||||
def text_pipeline(wrd):
|
||||
wrd = "".join(wrd.split(" "))
|
||||
yield wrd
|
||||
tokens_list = tokenizer(wrd)["input_ids"]
|
||||
yield tokens_list
|
||||
tokens = np.array(tokens_list, dtype="int64")
|
||||
# tokens = paddle.to_tensor(tokens_list, dtype="int64")
|
||||
yield tokens
|
||||
|
||||
dataset.add_dynamic_item(datasets, text_pipeline)
|
||||
|
||||
# 4. Set output:
|
||||
dataset.set_output_keys(
|
||||
datasets,
|
||||
["id", "sig", "wrd", "tokens"], )
|
||||
|
||||
# 5. If Dynamic Batching is used, we instantiate the needed samplers.
|
||||
train_batch_sampler = None
|
||||
valid_batch_sampler = None
|
||||
if hparams["dynamic_batching"]:
|
||||
from sampler import DynamicBatchSampler # noqa
|
||||
|
||||
dynamic_hparams = hparams["dynamic_batch_sampler"]
|
||||
num_buckets = dynamic_hparams["num_buckets"]
|
||||
|
||||
train_batch_sampler = DynamicBatchSampler(
|
||||
train_data,
|
||||
dynamic_hparams["max_batch_len"],
|
||||
num_buckets=num_buckets,
|
||||
length_func=lambda x: x["duration"],
|
||||
shuffle=dynamic_hparams["shuffle_ex"],
|
||||
batch_ordering=dynamic_hparams["batch_ordering"], )
|
||||
|
||||
valid_batch_sampler = DynamicBatchSampler(
|
||||
valid_data,
|
||||
dynamic_hparams["max_batch_len"],
|
||||
num_buckets=num_buckets,
|
||||
length_func=lambda x: x["duration"],
|
||||
shuffle=dynamic_hparams["shuffle_ex"],
|
||||
batch_ordering=dynamic_hparams["batch_ordering"], )
|
||||
|
||||
return (train_data, valid_data, test_data, tokenizer,
|
||||
train_batch_sampler, valid_batch_sampler, )
|
||||
|
||||
def setup_dataloader(self):
|
||||
config = self.config.clone()
|
||||
self.use_streamdata = config.get("use_stream_data", False)
|
||||
self.use_sb = config.get("use_sb_pipeline", False)
|
||||
if self.use_sb:
|
||||
hparams_file = config.sb_pipeline_conf
|
||||
with open(hparams_file, 'r', encoding='utf8') as fin:
|
||||
hparams = load_hyperpyyaml(fin, None)
|
||||
|
||||
(train_data, valid_data, test_data, tokenizer, train_bsampler,
|
||||
valid_bsampler, ) = self.dataio_prepare(hparams)
|
||||
|
||||
train_dataloader_opts = hparams["train_dataloader_opts"]
|
||||
valid_dataloader_opts = hparams["valid_dataloader_opts"]
|
||||
|
||||
if train_bsampler is not None:
|
||||
train_dataloader_opts = {
|
||||
"batch_sampler": train_bsampler,
|
||||
"num_workers": hparams["num_workers"],
|
||||
}
|
||||
|
||||
if valid_bsampler is not None:
|
||||
valid_dataloader_opts = {"batch_sampler": valid_bsampler}
|
||||
|
||||
if self.train:
|
||||
self.train_loader = make_dataloader(
|
||||
train_data, stage='train', **train_dataloader_opts)
|
||||
self.valid_loader = make_dataloader(
|
||||
valid_data,
|
||||
stage='val',
|
||||
**valid_dataloader_opts, )
|
||||
logger.info("Setup train/valid Dataloader!")
|
||||
else:
|
||||
self.test_loader = make_dataloader(
|
||||
test_data, stage='test', **hparams["test_dataloader_opts"])
|
||||
else:
|
||||
if self.train:
|
||||
self.train_loader = DataLoaderFactory.get_dataloader(
|
||||
'train', config, self.args)
|
||||
self.valid_loader = DataLoaderFactory.get_dataloader(
|
||||
'valid', config, self.args)
|
||||
logger.info("Setup train/valid Dataloader!")
|
||||
else:
|
||||
decode_batch_size = config.get('decode', dict()).get(
|
||||
'decode_batch_size', 1)
|
||||
self.test_loader = DataLoaderFactory.get_dataloader(
|
||||
'test', config, self.args)
|
||||
self.align_loader = DataLoaderFactory.get_dataloader(
|
||||
'align', config, self.args)
|
||||
logger.info("Setup test/align Dataloader!")
|
||||
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
model_conf = config
|
||||
|
||||
with UpdateConfig(model_conf):
|
||||
if self.use_sb:
|
||||
model_conf.output_dim = self.tokenizer.vocab_size
|
||||
else:
|
||||
if self.train:
|
||||
model_conf.input_dim = self.train_loader.feat_dim
|
||||
model_conf.output_dim = self.train_loader.vocab_size
|
||||
else:
|
||||
model_conf.input_dim = self.test_loader.feat_dim
|
||||
model_conf.output_dim = self.test_loader.vocab_size
|
||||
|
||||
model = HubertASR.from_config(model_conf)
|
||||
|
||||
model_dict = paddle.load(config.hubert_params_path)
|
||||
model.set_state_dict(model_dict)
|
||||
|
||||
if self.parallel:
|
||||
model = paddle.DataParallel(model, find_unused_parameters=True)
|
||||
|
||||
layer_tools.print_params(model, logger.info)
|
||||
self.model = model
|
||||
logger.info("Setup model!")
|
||||
|
||||
# setup speech augmentation for hubert
|
||||
if hasattr(config, 'audio_augment') and self.train:
|
||||
self.speech_augmentation = TimeDomainSpecAugment(
|
||||
**config.audio_augment)
|
||||
|
||||
if not self.train:
|
||||
return
|
||||
|
||||
train_config = config
|
||||
model_optim_type = train_config.model_optim
|
||||
model_optim_conf = train_config.model_optim_conf
|
||||
logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
|
||||
hubert_optim_type = train_config.hubert_optim
|
||||
hubert_optim_conf = train_config.hubert_optim_conf
|
||||
logger.info("optim_model:{},{}", hubert_optim_type, hubert_optim_conf)
|
||||
|
||||
model_scheduler_type = train_config.model_scheduler
|
||||
model_scheduler_conf = train_config.model_scheduler_conf
|
||||
hubert_scheduler_type = train_config.hubert_scheduler
|
||||
hubert_scheduler_conf = train_config.hubert_scheduler_conf
|
||||
|
||||
model_scheduler_args = dict(
|
||||
**{"learning_rate": model_optim_conf.lr,
|
||||
"verbose": False}, **(dict(model_scheduler_conf)))
|
||||
|
||||
hubert_scheduler_args = dict(
|
||||
**{"learning_rate": hubert_optim_conf.lr,
|
||||
"verbose": False}, **(dict(hubert_scheduler_conf)))
|
||||
|
||||
model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
|
||||
model_scheduler_args)
|
||||
hubert_lr_scheduler = LRSchedulerFactory.from_args(
|
||||
hubert_scheduler_type, hubert_scheduler_args)
|
||||
|
||||
def optimizer_args(
|
||||
config,
|
||||
optim_type,
|
||||
optim_conf,
|
||||
parameters,
|
||||
lr_scheduler=None, ):
|
||||
optim_arg = dict(optim_conf)
|
||||
optim_arg.update({
|
||||
"learning_rate":
|
||||
lr_scheduler if lr_scheduler else optim_conf.lr,
|
||||
"parameters":
|
||||
parameters
|
||||
})
|
||||
return optim_arg
|
||||
|
||||
model_optimizer_args = optimizer_args(config, model_optim_type,
|
||||
model_optim_conf, [{
|
||||
'params':
|
||||
model._layers.enc.parameters()
|
||||
}, {
|
||||
'params':
|
||||
model._layers.ctc.parameters()
|
||||
}] if self.parallel else [{
|
||||
'params':
|
||||
model.enc.parameters()
|
||||
}, {
|
||||
'params':
|
||||
model.ctc.parameters()
|
||||
}], model_lr_scheduler)
|
||||
|
||||
hubert_optimizer_args = optimizer_args(
|
||||
config, hubert_optim_type, hubert_optim_conf,
|
||||
model._layers.hubert.parameters() if self.parallel else
|
||||
model.hubert.parameters(), hubert_lr_scheduler)
|
||||
|
||||
model_optimizer = OptimizerFactory.from_args(model_optim_type,
|
||||
model_optimizer_args)
|
||||
hubert_optimizer = OptimizerFactory.from_args(hubert_optim_type,
|
||||
hubert_optimizer_args)
|
||||
|
||||
self.model_optimizer = model_optimizer
|
||||
self.hubert_optimizer = hubert_optimizer
|
||||
self.model_lr_scheduler = model_lr_scheduler
|
||||
self.hubert_lr_scheduler = hubert_lr_scheduler
|
||||
logger.info("Setup optimizer/lr_scheduler!")
|
||||
|
||||
|
||||
class HubertASRTester(HubertASRTrainer):
|
||||
def __init__(self, config, args):
|
||||
super().__init__(config, args)
|
||||
self.text_featurizer = TextFeaturizer(
|
||||
unit_type=config.unit_type, vocab=config.vocab_filepath)
|
||||
self.vocab_list = self.text_featurizer.vocab_list
|
||||
|
||||
def id2token(self, texts, texts_len):
|
||||
""" ord() id to chr() chr """
|
||||
trans = []
|
||||
for text, n in zip(texts, texts_len):
|
||||
n = n.numpy().item()
|
||||
ids = text[:n]
|
||||
trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
|
||||
return trans
|
||||
|
||||
def compute_metrics(self, id, audio, audio_len, texts, texts_len,
|
||||
fout=None):
|
||||
decode_cfg = self.config.decode
|
||||
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
||||
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
|
||||
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
|
||||
|
||||
start_time = time.time()
|
||||
target_transcripts = self.id2token(texts, texts_len)
|
||||
result_transcripts, result_tokenids = self.model.decode(
|
||||
audio,
|
||||
text_feature=self.text_featurizer,
|
||||
decoding_method=decode_cfg.decoding_method,
|
||||
beam_size=decode_cfg.beam_size)
|
||||
decode_time = time.time() - start_time
|
||||
|
||||
for utt, target, result, rec_tids in zip(
|
||||
id, target_transcripts, result_transcripts, result_tokenids):
|
||||
errors, len_ref = errors_func(target, result)
|
||||
errors_sum += errors
|
||||
len_refs += len_ref
|
||||
num_ins += 1
|
||||
if fout:
|
||||
fout.write({
|
||||
"utt": utt,
|
||||
"refs": [target],
|
||||
"hyps": [result],
|
||||
"hyps_tokenid": [rec_tids],
|
||||
})
|
||||
logger.info(f"Utt: {utt}")
|
||||
logger.info(f"Ref: {target}")
|
||||
logger.info(f"Hyp: {result}")
|
||||
logger.info("One example error rate [%s] = %f" % (
|
||||
decode_cfg.error_rate_type, error_rate_func(target, result)))
|
||||
|
||||
return dict(
|
||||
errors_sum=errors_sum,
|
||||
len_refs=len_refs,
|
||||
num_ins=num_ins, # num examples
|
||||
error_rate=errors_sum / len_refs,
|
||||
error_rate_type=decode_cfg.error_rate_type,
|
||||
num_frames=audio_len.sum().numpy().item(),
|
||||
decode_time=decode_time)
|
||||
|
||||
def sb_compute_metrics(self, id, sig, wrd, tokens, fout=None):
|
||||
decode_cfg = self.config.decode
|
||||
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
||||
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
|
||||
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
|
||||
start_time = time.time()
|
||||
target_transcripts = wrd
|
||||
result_transcripts, result_tokenids = self.model.decode(
|
||||
sig[0],
|
||||
text_feature=self.tokenizer,
|
||||
decoding_method=decode_cfg.decoding_method,
|
||||
beam_size=decode_cfg.beam_size,
|
||||
sb_pipeline=True)
|
||||
decode_time = time.time() - start_time
|
||||
|
||||
for utt, target, result, rec_tids in zip(
|
||||
id, target_transcripts, result_transcripts, result_tokenids):
|
||||
errors, len_ref = errors_func(target, result)
|
||||
errors_sum += errors
|
||||
len_refs += len_ref
|
||||
num_ins += 1
|
||||
if fout:
|
||||
fout.write({
|
||||
"utt": utt,
|
||||
"refs": [target],
|
||||
"hyps": [result],
|
||||
"hyps_tokenid": [rec_tids],
|
||||
})
|
||||
logger.info(f"Utt: {utt}")
|
||||
logger.info(f"Ref: {target}")
|
||||
logger.info(f"Hyp: {result}")
|
||||
logger.info("One example error rate [%s] = %f" % (
|
||||
decode_cfg.error_rate_type, error_rate_func(target, result)))
|
||||
|
||||
return dict(
|
||||
errors_sum=errors_sum,
|
||||
len_refs=len_refs,
|
||||
num_ins=num_ins, # num examples
|
||||
error_rate=errors_sum / len_refs,
|
||||
error_rate_type=decode_cfg.error_rate_type,
|
||||
num_frames=sig[1].sum().numpy().item(),
|
||||
decode_time=decode_time)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def test(self):
|
||||
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
|
||||
self.model.eval()
|
||||
|
||||
error_rate_type = None
|
||||
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
||||
num_frames = 0.0
|
||||
num_time = 0.0
|
||||
# Initialized the decoder in model
|
||||
decode_cfg = self.config.decode
|
||||
vocab_list = self.vocab_list
|
||||
decode_batch_size = decode_cfg.decode_batch_size
|
||||
|
||||
with jsonlines.open(self.args.result_file, 'w') as fout:
|
||||
for i, batch in enumerate(self.test_loader):
|
||||
if self.use_sb:
|
||||
metrics = self.sb_compute_metrics(**batch, fout=fout)
|
||||
else:
|
||||
metrics = self.compute_metrics(*batch, fout=fout)
|
||||
num_frames += metrics['num_frames']
|
||||
num_time += metrics["decode_time"]
|
||||
errors_sum += metrics['errors_sum']
|
||||
len_refs += metrics['len_refs']
|
||||
num_ins += metrics['num_ins']
|
||||
error_rate_type = metrics['error_rate_type']
|
||||
rtf = num_time / (num_frames)
|
||||
logger.info(
|
||||
"RTF: %f, Error rate [%s] (%d/?) = %f" %
|
||||
(rtf, error_rate_type, num_ins, errors_sum / len_refs))
|
||||
|
||||
# logging
|
||||
msg = "Test: "
|
||||
msg += "epoch: {}, ".format(self.epoch)
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += "Final error rate [%s] (%d/%d) = %f" % (
|
||||
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
||||
logger.info(msg)
|
||||
|
||||
err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
|
||||
err_type_str = "{}".format(error_rate_type)
|
||||
with open(err_meta_path, 'w', encoding='utf8') as f:
|
||||
data = json.dumps({
|
||||
"epoch":
|
||||
self.epoch,
|
||||
"step":
|
||||
self.iteration,
|
||||
"rtf":
|
||||
rtf,
|
||||
error_rate_type:
|
||||
errors_sum / len_refs,
|
||||
"dataset_hour": (num_frames) / 1000.0 / 3600.0,
|
||||
"process_hour":
|
||||
num_time / 1000.0 / 3600.0,
|
||||
"num_examples":
|
||||
num_ins,
|
||||
"err_sum":
|
||||
errors_sum,
|
||||
"ref_len":
|
||||
len_refs,
|
||||
"decode_method":
|
||||
self.config.decode.decoding_method,
|
||||
})
|
||||
f.write(data + '\n')
|
@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .hubert_ASR import HubertASR
|
||||
from .hubert_ASR import HubertBase
|
||||
|
||||
__all__ = ["HubertASR", "HubertBase"]
|
@ -0,0 +1,368 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""HubertASR model."""
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import is_dataclass
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertConfig
|
||||
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertModel
|
||||
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertPretrainingConfig
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
|
||||
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
|
||||
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
|
||||
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
|
||||
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
from paddlespeech.s2t.utils.utility import log_add
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class HubertASR(nn.Layer):
|
||||
def __init__(self, config: dict):
|
||||
super().__init__()
|
||||
init_type = config.get("init_type", None)
|
||||
with DefaultInitializerContext(init_type):
|
||||
self.config = config
|
||||
task_cfg = self.merge_with_parent(HubertPretrainingConfig,
|
||||
dict(self.config.task_cfg))
|
||||
model_cfg = self.merge_with_parent(HubertConfig,
|
||||
dict(self.config.model_cfg))
|
||||
hubert = HubertModel(model_cfg, task_cfg, [None])
|
||||
|
||||
self.normalize_wav = config.normalize_wav
|
||||
self.output_norm = config.output_norm
|
||||
if hasattr(config, 'spec_augment'):
|
||||
self.spec_augment = SpecAugment(**config.spec_augment)
|
||||
|
||||
if config.freeze_hubert:
|
||||
hubert.eval()
|
||||
for parm in hubert.parameters():
|
||||
parm.trainable = False
|
||||
self.hubert = hubert
|
||||
self.enc = VanillaNN(**config.enc)
|
||||
self.ctc = CTC(**config.ctc,
|
||||
odim=config.output_dim,
|
||||
batch_average=False,
|
||||
reduction='mean')
|
||||
|
||||
def merge_with_parent(self, dc: dataclass, cfg: dict):
|
||||
assert is_dataclass(dc)
|
||||
assert type(cfg) == dict
|
||||
cfg = deepcopy(cfg)
|
||||
|
||||
def fix_cfg(cfg):
|
||||
target_keys = set(dc.__dataclass_fields__.keys())
|
||||
for k in list(cfg.keys()):
|
||||
if k not in target_keys:
|
||||
del cfg[k]
|
||||
|
||||
fix_cfg(cfg)
|
||||
assert len(cfg) > 0
|
||||
return dc(**cfg)
|
||||
|
||||
def forward(self, wav, wavs_lens_rate, target, target_lens):
|
||||
|
||||
if self.normalize_wav:
|
||||
wav = F.layer_norm(wav, wav.shape)
|
||||
|
||||
# Extract wav2vec output
|
||||
out = self.hubert.extract_features(wav)[0]
|
||||
# We normalize the output if required
|
||||
if self.output_norm:
|
||||
out = F.layer_norm(out, out.shape)
|
||||
|
||||
if self.training and hasattr(self.config, 'spec_augment'):
|
||||
feats = self.spec_augment(out)
|
||||
else:
|
||||
feats = out
|
||||
|
||||
x = self.enc(feats)
|
||||
|
||||
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
|
||||
|
||||
ctc_loss = self.ctc(x, x_lens, target, target_lens)
|
||||
|
||||
return ctc_loss
|
||||
|
||||
@paddle.no_grad()
|
||||
def decode(self,
|
||||
feats: paddle.Tensor,
|
||||
text_feature: Dict[str, int],
|
||||
decoding_method: str,
|
||||
beam_size: int,
|
||||
tokenizer: str=None,
|
||||
sb_pipeline=False):
|
||||
batch_size = feats.shape[0]
|
||||
|
||||
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
|
||||
logger.error(
|
||||
f"decoding mode {decoding_method} must be running with batch_size == 1"
|
||||
)
|
||||
logger.error(f"current batch_size is {batch_size}")
|
||||
|
||||
if decoding_method == 'ctc_greedy_search':
|
||||
if tokenizer is None and sb_pipeline is False:
|
||||
hyps = self.ctc_greedy_search(feats)
|
||||
res = [text_feature.defeaturize(hyp) for hyp in hyps]
|
||||
res_tokenids = [hyp for hyp in hyps]
|
||||
else:
|
||||
if sb_pipeline is True:
|
||||
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
|
||||
else:
|
||||
hyps = self.ctc_greedy_search(feats)
|
||||
res = []
|
||||
res_tokenids = []
|
||||
for sequence in hyps:
|
||||
# Decode token terms to words
|
||||
predicted_tokens = text_feature.convert_ids_to_tokens(
|
||||
sequence)
|
||||
tmp_res = []
|
||||
tmp_res_tokenids = []
|
||||
for c in predicted_tokens:
|
||||
if c == "[CLS]":
|
||||
continue
|
||||
elif c == "[SEP]" or c == "[PAD]":
|
||||
break
|
||||
else:
|
||||
tmp_res.append(c)
|
||||
tmp_res_tokenids.append(text_feature.vocab[c])
|
||||
res.append(''.join(tmp_res))
|
||||
res_tokenids.append(tmp_res_tokenids)
|
||||
|
||||
# ctc_prefix_beam_search and attention_rescoring only return one
|
||||
# result in List[int], change it to List[List[int]] for compatible
|
||||
# with other batch decoding mode
|
||||
elif decoding_method == 'ctc_prefix_beam_search':
|
||||
assert feats.shape[0] == 1
|
||||
if tokenizer is None and sb_pipeline is False:
|
||||
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
||||
res = [text_feature.defeaturize(hyp)]
|
||||
res_tokenids = [hyp]
|
||||
else:
|
||||
if sb_pipeline is True:
|
||||
hyp = self.ctc_prefix_beam_search(
|
||||
feats.unsqueeze(-1), beam_size)
|
||||
else:
|
||||
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
||||
res = []
|
||||
res_tokenids = []
|
||||
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
|
||||
tmp_res = []
|
||||
tmp_res_tokenids = []
|
||||
for c in predicted_tokens:
|
||||
if c == "[CLS]":
|
||||
continue
|
||||
elif c == "[SEP]" or c == "[PAD]":
|
||||
break
|
||||
else:
|
||||
tmp_res.append(c)
|
||||
tmp_res_tokenids.append(text_feature.vocab[c])
|
||||
res.append(''.join(tmp_res))
|
||||
res_tokenids.append(tmp_res_tokenids)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"wav2vec2 not support decoding method: {decoding_method}")
|
||||
|
||||
return res, res_tokenids
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
model = cls(config)
|
||||
return model
|
||||
|
||||
def ctc_greedy_search(self, wav) -> List[List[int]]:
|
||||
""" Apply CTC greedy search
|
||||
Args:
|
||||
speech (paddle.Tensor): (batch, max_len)
|
||||
speech_length (paddle.Tensor): (batch, )
|
||||
Returns:
|
||||
List[List[int]]: best path result
|
||||
"""
|
||||
batch_size = wav.shape[0]
|
||||
wav = wav[:, :, 0]
|
||||
if self.normalize_wav:
|
||||
wav = F.layer_norm(wav, wav.shape[1:])
|
||||
# Extract wav2vec output
|
||||
out = self.hubert.extract_features(wav)[0]
|
||||
# We normalize the output if required
|
||||
if self.output_norm:
|
||||
out = F.layer_norm(out, out.shape[1:])
|
||||
feats = out
|
||||
x = self.enc(feats)
|
||||
x_lens = x.shape[1]
|
||||
ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
|
||||
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
|
||||
topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen)
|
||||
|
||||
hyps = [hyp.tolist() for hyp in topk_index]
|
||||
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
||||
return hyps
|
||||
|
||||
def _ctc_prefix_beam_search(
|
||||
self,
|
||||
wav,
|
||||
beam_size,
|
||||
blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
|
||||
""" CTC prefix beam search inner implementation
|
||||
Args:
|
||||
speech (paddle.Tensor): (batch, max_len, feat_dim)
|
||||
speech_length (paddle.Tensor): (batch, )
|
||||
beam_size (int): beam size for beam search
|
||||
decoding_chunk_size (int): decoding chunk for dynamic chunk
|
||||
trained model.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
0: used for training, it's prohibited here
|
||||
simulate_streaming (bool): whether do encoder forward in a
|
||||
streaming fashion
|
||||
Returns:
|
||||
List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
|
||||
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
|
||||
it will be used for rescoring in attention rescoring mode
|
||||
"""
|
||||
wav = wav[:, :, 0]
|
||||
|
||||
if self.normalize_wav:
|
||||
wav = F.layer_norm(wav, wav.shape[1:])
|
||||
# Extract wav2vec output
|
||||
out = self.hubert.extract_features(wav)[0]
|
||||
# We normalize the output if required
|
||||
if self.output_norm:
|
||||
out = F.layer_norm(out, out.shape[1:])
|
||||
feats = out
|
||||
|
||||
x = self.enc(feats)
|
||||
maxlen = x.shape[1]
|
||||
ctc_probs = self.ctc.log_softmax(x) # (1, maxlen, vocab_size)
|
||||
ctc_probs = ctc_probs.squeeze(0)
|
||||
|
||||
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
|
||||
# blank_ending_score and none_blank_ending_score in ln domain
|
||||
cur_hyps = [(tuple(), (0.0, -float('inf')))]
|
||||
# 2. CTC beam search step by step
|
||||
for t in range(0, maxlen):
|
||||
logp = ctc_probs[t] # (vocab_size,)
|
||||
# key: prefix, value (pb, pnb), default value(-inf, -inf)
|
||||
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
|
||||
# 2.1 First beam prune: select topk best
|
||||
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
|
||||
for s in top_k_index:
|
||||
s = s.item()
|
||||
ps = logp[s].item()
|
||||
for prefix, (pb, pnb) in cur_hyps:
|
||||
last = prefix[-1] if len(prefix) > 0 else None
|
||||
if s == blank_id: # blank
|
||||
n_pb, n_pnb = next_hyps[prefix]
|
||||
n_pb = log_add([n_pb, pb + ps, pnb + ps])
|
||||
next_hyps[prefix] = (n_pb, n_pnb)
|
||||
elif s == last:
|
||||
# Update *ss -> *s;
|
||||
n_pb, n_pnb = next_hyps[prefix]
|
||||
n_pnb = log_add([n_pnb, pnb + ps])
|
||||
next_hyps[prefix] = (n_pb, n_pnb)
|
||||
# Update *s-s -> *ss, - is for blank
|
||||
n_prefix = prefix + (s, )
|
||||
n_pb, n_pnb = next_hyps[n_prefix]
|
||||
n_pnb = log_add([n_pnb, pb + ps])
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb)
|
||||
else:
|
||||
n_prefix = prefix + (s, )
|
||||
n_pb, n_pnb = next_hyps[n_prefix]
|
||||
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb)
|
||||
|
||||
# 2.2 Second beam prune
|
||||
next_hyps = sorted(
|
||||
next_hyps.items(),
|
||||
key=lambda x: log_add(list(x[1])),
|
||||
reverse=True)
|
||||
cur_hyps = next_hyps[:beam_size]
|
||||
|
||||
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
|
||||
return hyps
|
||||
|
||||
def ctc_prefix_beam_search(self, wav, beam_size) -> List[int]:
|
||||
""" Apply CTC prefix beam search
|
||||
Args:
|
||||
speech (paddle.Tensor): (batch, max_len, feat_dim)
|
||||
speech_length (paddle.Tensor): (batch, )
|
||||
beam_size (int): beam size for beam search
|
||||
decoding_chunk_size (int): decoding chunk for dynamic chunk
|
||||
trained model.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
0: used for training, it's prohibited here
|
||||
simulate_streaming (bool): whether do encoder forward in a
|
||||
streaming fashion
|
||||
Returns:
|
||||
List[int]: CTC prefix beam search nbest results
|
||||
"""
|
||||
hyps = self._ctc_prefix_beam_search(wav, beam_size)
|
||||
return hyps[0][0]
|
||||
|
||||
|
||||
class HubertBase(nn.Layer):
|
||||
"""Hubert model"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
task_cfg = self.merge_with_parent(HubertPretrainingConfig,
|
||||
dict(self.config.task_cfg))
|
||||
model_cfg = self.merge_with_parent(HubertConfig,
|
||||
dict(self.config.model_cfg))
|
||||
hubert = HubertModel(model_cfg, task_cfg, [None])
|
||||
self.hubert = hubert
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, configs: dict):
|
||||
"""init model.
|
||||
Args:
|
||||
configs (dict): config dict.
|
||||
Raises:
|
||||
ValueError: raise when using not support encoder type.
|
||||
Returns:
|
||||
nn.Layer: HubertBase
|
||||
"""
|
||||
model = cls(configs)
|
||||
return model
|
||||
|
||||
def merge_with_parent(self, dc: dataclass, cfg: dict):
|
||||
assert is_dataclass(dc)
|
||||
assert type(cfg) == dict
|
||||
cfg = deepcopy(cfg)
|
||||
|
||||
def fix_cfg(cfg):
|
||||
target_keys = set(dc.__dataclass_fields__.keys())
|
||||
for k in list(cfg.keys()):
|
||||
if k not in target_keys:
|
||||
del cfg[k]
|
||||
|
||||
fix_cfg(cfg)
|
||||
assert len(cfg) > 0
|
||||
return dc(**cfg)
|
||||
|
||||
def forward(self, wav):
|
||||
out = self.hubert.extract_features(wav)
|
||||
return out
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,586 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Paddle Hubert model."""
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import ChoiceEnum
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import compute_mask_indices
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import ConvFeatureExtractionModel
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import EXTRACTOR_MODE_CHOICES
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import get_available_activation_fns
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import GLU
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import GradMultiply
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import LAYER_TYPE_CHOICES
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import MASKING_DISTRIBUTION_CHOICES
|
||||
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import TransformerEncoder
|
||||
from paddlespeech.s2t.modules.align import LayerNorm
|
||||
from paddlespeech.s2t.modules.align import Linear
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
@dataclass
|
||||
class HubertPretrainingConfig:
|
||||
label_rate: float = field(
|
||||
default=-1.0,
|
||||
metadata={"help": "label frame rate. -1.0 for sequence label"}, )
|
||||
sample_rate: int = field(
|
||||
default=16_000,
|
||||
metadata={
|
||||
"help":
|
||||
"target sample rate. audio files will be up/down "
|
||||
"sampled to this rate"
|
||||
}, )
|
||||
normalize: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, normalizes input to have 0 mean and unit variance"
|
||||
}, )
|
||||
enable_padding: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "pad shorter samples instead of cropping"}, )
|
||||
max_keep_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "exclude sample longer than this"}, )
|
||||
max_sample_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "max sample size to crop to for batching"}, )
|
||||
min_sample_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "min sample size to crop to for batching"}, )
|
||||
random_crop: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "always crop from the beginning if false"}, )
|
||||
pad_audio: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "pad audio to the longest one in the batch if true"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HubertConfig:
|
||||
label_rate: float
|
||||
|
||||
extractor_mode: EXTRACTOR_MODE_CHOICES = field(
|
||||
default="default",
|
||||
metadata={
|
||||
"help":
|
||||
"mode for feature extractor. default has a single group "
|
||||
"norm with d groups in the first conv block, whereas layer_norm "
|
||||
"has layer norms in every block (meant to use with normalize=True)"
|
||||
}, )
|
||||
encoder_layers: int = field(
|
||||
default=12, metadata={"help": "num encoder layers in the transformer"})
|
||||
encoder_embed_dim: int = field(
|
||||
default=768, metadata={"help": "encoder embedding dimension"})
|
||||
encoder_ffn_embed_dim: int = field(
|
||||
default=3072, metadata={"help": "encoder embedding dimension for FFN"})
|
||||
encoder_attention_heads: int = field(
|
||||
default=12, metadata={"help": "num encoder attention heads"})
|
||||
activation_fn: ChoiceEnum(get_available_activation_fns()) = field(
|
||||
default="gelu", metadata={"help": "activation function to use"})
|
||||
layer_type: LAYER_TYPE_CHOICES = field(
|
||||
default="transformer", metadata={"help": "layer type in encoder"})
|
||||
|
||||
# dropouts
|
||||
dropout: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "dropout probability for the transformer"}, )
|
||||
attention_dropout: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "dropout probability for attention weights"}, )
|
||||
activation_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout probability after activation in FFN"}, )
|
||||
encoder_layerdrop: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "probability of dropping a tarnsformer layer"}, )
|
||||
dropout_input: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout to apply to the input (after feat extr)"}, )
|
||||
dropout_features: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout to apply to the features (after feat extr)"},
|
||||
)
|
||||
|
||||
final_dim: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help":
|
||||
"project final representations and targets to this many "
|
||||
"dimensions. set to encoder_embed_dim is <= 0"
|
||||
}, )
|
||||
untie_final_proj: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "use separate projection for each target"}, )
|
||||
layer_norm_first: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "apply layernorm first in the transformer"}, )
|
||||
conv_feature_layers: str = field(
|
||||
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
|
||||
metadata={
|
||||
"help":
|
||||
"string describing convolutional feature extraction "
|
||||
"layers in form of a python list that contains "
|
||||
"[(dim, kernel_size, stride), ...]"
|
||||
}, )
|
||||
conv_bias: bool = field(
|
||||
default=False, metadata={"help": "include bias in conv encoder"})
|
||||
logit_temp: float = field(
|
||||
default=0.1, metadata={"help": "temperature to divide logits by"})
|
||||
target_glu: bool = field(
|
||||
default=False, metadata={"help": "adds projection + glu to targets"})
|
||||
feature_grad_mult: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "multiply feature extractor var grads by this"}, )
|
||||
|
||||
# masking
|
||||
mask_length: int = field(default=10, metadata={"help": "mask length"})
|
||||
mask_prob: float = field(
|
||||
default=0.65,
|
||||
metadata={"help": "probability of replacing a token with mask"}, )
|
||||
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
||||
default="static", metadata={"help": "how to choose mask length"})
|
||||
mask_other: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help":
|
||||
"secondary mask argument "
|
||||
"(used for more complex distributions), "
|
||||
"see help in compute_mask_indicesh"
|
||||
}, )
|
||||
no_mask_overlap: bool = field(
|
||||
default=False, metadata={"help": "whether to allow masks to overlap"})
|
||||
mask_min_space: int = field(
|
||||
default=1,
|
||||
metadata={"help": "min space between spans (if no overlap is enabled)"},
|
||||
)
|
||||
|
||||
# channel masking
|
||||
mask_channel_length: int = field(
|
||||
default=10,
|
||||
metadata={"help": "length of the mask for features (channels)"}, )
|
||||
mask_channel_prob: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "probability of replacing a feature with 0"}, )
|
||||
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
||||
default="static",
|
||||
metadata={"help": "how to choose mask length for channel masking"}, )
|
||||
mask_channel_other: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help":
|
||||
"secondary mask argument "
|
||||
"(used for more complex distributions), "
|
||||
"see help in compute_mask_indicesh"
|
||||
}, )
|
||||
no_mask_channel_overlap: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to allow channel masks to overlap"}, )
|
||||
mask_channel_min_space: int = field(
|
||||
default=1,
|
||||
metadata={"help": "min space between spans (if no overlap is enabled)"},
|
||||
)
|
||||
|
||||
# positional embeddings
|
||||
conv_pos: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "number of filters for convolutional positional embeddings"
|
||||
}, )
|
||||
conv_pos_groups: int = field(
|
||||
default=16,
|
||||
metadata={
|
||||
"help": "number of groups for convolutional positional embedding"
|
||||
}, )
|
||||
|
||||
latent_temp: Tuple[float, float, float] = field(
|
||||
default=(2, 0.5, 0.999995),
|
||||
metadata={"help": "legacy (to be removed)"}, )
|
||||
|
||||
# loss computation
|
||||
skip_masked: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "skip computing losses over masked frames"}, )
|
||||
skip_nomask: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "skip computing losses over unmasked frames"}, )
|
||||
|
||||
checkpoint_activations: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "recompute activations and save memory for extra compute"
|
||||
}, )
|
||||
|
||||
# FP16 optimization
|
||||
required_seq_len_multiple: int = field(
|
||||
default=2,
|
||||
metadata={
|
||||
"help":
|
||||
"pad the input to encoder such that the sequence length is divisible by multiple"
|
||||
}, )
|
||||
|
||||
# Conformer
|
||||
depthwise_conv_kernel_size: int = field(
|
||||
default=31,
|
||||
metadata={
|
||||
"help":
|
||||
"depthwise-conv-kernel-size for convolution in conformer layer"
|
||||
}, )
|
||||
attn_type: str = field(
|
||||
default="",
|
||||
metadata={"help": "if espnet use ESPNET MHA"}, )
|
||||
pos_enc_type: str = field(
|
||||
default="abs",
|
||||
metadata={"help": "Positional encoding type to use in conformer"}, )
|
||||
fp16: bool = field(
|
||||
default=False, metadata={"help": "If fp16 is being used"})
|
||||
|
||||
|
||||
class HubertModel(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
cfg: HubertConfig,
|
||||
task_cfg: HubertPretrainingConfig,
|
||||
dictionaries: List[Any], ) -> None:
|
||||
super().__init__()
|
||||
logger.info(f"HubertModel Config: {cfg}")
|
||||
|
||||
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
|
||||
self.embed = feature_enc_layers[-1][0]
|
||||
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=cfg.extractor_mode,
|
||||
conv_bias=cfg.conv_bias, )
|
||||
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
||||
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
|
||||
|
||||
self.post_extract_proj = (Linear(self.embed, cfg.encoder_embed_dim) if
|
||||
self.embed != cfg.encoder_embed_dim else None)
|
||||
|
||||
self.mask_prob = cfg.mask_prob
|
||||
self.mask_selection = cfg.mask_selection
|
||||
self.mask_other = cfg.mask_other
|
||||
self.mask_length = cfg.mask_length
|
||||
self.no_mask_overlap = cfg.no_mask_overlap
|
||||
self.mask_min_space = cfg.mask_min_space
|
||||
|
||||
self.mask_channel_prob = cfg.mask_channel_prob
|
||||
self.mask_channel_selection = cfg.mask_channel_selection
|
||||
self.mask_channel_other = cfg.mask_channel_other
|
||||
self.mask_channel_length = cfg.mask_channel_length
|
||||
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
||||
self.mask_channel_min_space = cfg.mask_channel_min_space
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
||||
|
||||
self.feature_grad_mult = cfg.feature_grad_mult
|
||||
self.logit_temp = cfg.logit_temp
|
||||
self.skip_masked = cfg.skip_masked
|
||||
self.skip_nomask = cfg.skip_nomask
|
||||
|
||||
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
|
||||
|
||||
self.mask_emb = paddle.create_parameter(
|
||||
shape=[cfg.encoder_embed_dim],
|
||||
dtype='float32',
|
||||
default_initializer=paddle.nn.initializer.Uniform(low=0), )
|
||||
|
||||
self.encoder = TransformerEncoder(cfg)
|
||||
self.layer_norm = LayerNorm(self.embed)
|
||||
|
||||
self.target_glu = None
|
||||
if cfg.target_glu:
|
||||
self.target_glu = nn.Sequential(
|
||||
Linear(final_dim, final_dim * 2), GLU())
|
||||
|
||||
self.untie_final_proj = cfg.untie_final_proj
|
||||
if self.untie_final_proj:
|
||||
self.final_proj = Linear(cfg.encoder_embed_dim,
|
||||
final_dim * len(dictionaries))
|
||||
else:
|
||||
self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)
|
||||
|
||||
# modules below are not needed during fine-tuning
|
||||
if any([d is None for d in dictionaries]):
|
||||
logger.info(
|
||||
"cannot find dictionary. assume will be used for fine-tuning")
|
||||
else:
|
||||
self.num_classes = [len(d) for d in dictionaries]
|
||||
self.label_embs_concat = paddle.create_parameter(
|
||||
shape=[sum(self.num_classes), final_dim],
|
||||
dtype='float32',
|
||||
default_initializer=paddle.nn.initializer.Uniform(low=0), )
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: HubertConfig, task):
|
||||
"""Build a new model instance."""
|
||||
|
||||
model = HubertModel(cfg, task.cfg, task.dictionaries)
|
||||
return model
|
||||
|
||||
def apply_mask(self, x, padding_mask, target_list):
|
||||
B, T, C = x.shape
|
||||
if self.mask_prob > 0:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_prob,
|
||||
self.mask_length,
|
||||
self.mask_selection,
|
||||
self.mask_other,
|
||||
min_masks=2,
|
||||
no_overlap=self.no_mask_overlap,
|
||||
min_space=self.mask_min_space, )
|
||||
|
||||
mask_indices = paddle.to_tensor(
|
||||
mask_indices, dtype='int64', place=x.place)
|
||||
x[mask_indices] = self.mask_emb
|
||||
else:
|
||||
mask_indices = None
|
||||
|
||||
if self.mask_channel_prob > 0:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space, )
|
||||
mask_channel_indices = (paddle.to_tensor(
|
||||
mask_channel_indices, dtype='int64', place=x.place).unsqueeze(1)
|
||||
.expand([-1, T, -1]))
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
return x, mask_indices
|
||||
|
||||
def compute_nce(self, x, pos, negs):
|
||||
neg_is_pos = (pos == negs).all(-1)
|
||||
pos = pos.unsqueeze(0)
|
||||
targets = paddle.concat([pos, negs], axis=0)
|
||||
|
||||
logits = paddle.nn.functional.cosine_similarity(
|
||||
x.astype('float32'), targets.astype('float32'), axis=-1)
|
||||
logits /= self.logit_temp
|
||||
if paddle.any(neg_is_pos):
|
||||
logits[1:][neg_is_pos] = float("-inf")
|
||||
logits = logits.transpose([1, 0]) # (num_x, num_cls+1)
|
||||
return logits
|
||||
|
||||
def forward_features(self, source: paddle.Tensor) -> paddle.Tensor:
|
||||
if self.feature_grad_mult > 0:
|
||||
features = self.feature_extractor(source)
|
||||
if self.feature_grad_mult != 1.0:
|
||||
features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||
else:
|
||||
with paddle.no_grad():
|
||||
features = self.feature_extractor(source)
|
||||
return features
|
||||
|
||||
def forward_targets(
|
||||
self,
|
||||
features: paddle.Tensor,
|
||||
target_list: List[paddle.Tensor],
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
# Trim features to ensure labels exist and then get aligned labels
|
||||
feat_tsz = features.shape[2]
|
||||
targ_tsz = min([t.shape[1] for t in target_list])
|
||||
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
||||
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
||||
features = features[:, :, :feat_tsz]
|
||||
target_inds = paddle.arange(feat_tsz).astype(
|
||||
'float32') * self.feat2tar_ratio
|
||||
target_list = [t[:, target_inds.astype('int64')] for t in target_list]
|
||||
return features, target_list
|
||||
|
||||
def forward_padding_mask(
|
||||
self,
|
||||
features: paddle.Tensor,
|
||||
padding_mask: paddle.Tensor, ) -> paddle.Tensor:
|
||||
extra = padding_mask.shape[1] % features.shape[1]
|
||||
if extra > 0:
|
||||
padding_mask = padding_mask[:, :-extra]
|
||||
padding_mask = paddle.reshape(
|
||||
padding_mask, [padding_mask.shape[0], features.shape[1], -1])
|
||||
padding_mask = paddle.all(padding_mask, axis=-1)
|
||||
return padding_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
source: paddle.Tensor,
|
||||
target_list: Optional[List[paddle.Tensor]]=None,
|
||||
padding_mask: Optional[paddle.Tensor]=None,
|
||||
mask: bool=True,
|
||||
features_only: bool=False,
|
||||
output_layer: Optional[int]=None, ) -> Dict[str, paddle.Tensor]:
|
||||
"""output layer is 1-based"""
|
||||
features = self.forward_features(source)
|
||||
if target_list is not None:
|
||||
features, target_list = self.forward_targets(features, target_list)
|
||||
|
||||
features_pen = features.pow(2).mean()
|
||||
|
||||
features = features.transpose([0, 2, 1])
|
||||
features = self.layer_norm(features)
|
||||
unmasked_features = features.clone()
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
features = self.dropout_input(features)
|
||||
unmasked_features = self.dropout_features(unmasked_features)
|
||||
|
||||
if mask:
|
||||
x, mask_indices = self.apply_mask(features, padding_mask,
|
||||
target_list)
|
||||
else:
|
||||
x = features
|
||||
mask_indices = None
|
||||
|
||||
# feature: (B, T, D), float
|
||||
# target: (B, T), long
|
||||
# x: (B, T, D), float
|
||||
# padding_mask: (B, T), bool
|
||||
# mask_indices: (B, T), bool
|
||||
x, _ = self.encoder(
|
||||
x,
|
||||
padding_mask=padding_mask,
|
||||
layer=None if output_layer is None else output_layer - 1, )
|
||||
|
||||
if features_only:
|
||||
return {"x": x, "padding_mask": padding_mask, "features": features}
|
||||
|
||||
def compute_pred(self, proj_x, target, label_embs):
|
||||
# compute logits for the i-th label set
|
||||
y = paddle.index_select(
|
||||
label_embs, index=target.astype('int64'), axis=0)
|
||||
negs = paddle.expand(
|
||||
label_embs.unsqueeze(1),
|
||||
[label_embs.shape[0], proj_x.shape[0], label_embs.shape[-1]])
|
||||
if self.target_glu:
|
||||
y = self.target_glu(y)
|
||||
negs = self.target_glu(negs)
|
||||
# proj_x: (S, D)
|
||||
# y: (S, D)
|
||||
# negs: (Neg, S, D)
|
||||
return self.compute_nce(proj_x, y, negs)
|
||||
|
||||
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
||||
|
||||
if not self.skip_masked:
|
||||
masked_indices = paddle.logical_and(~padding_mask, mask_indices)
|
||||
proj_x_m = self.final_proj(x[masked_indices])
|
||||
if self.untie_final_proj:
|
||||
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
|
||||
else:
|
||||
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
|
||||
logit_m_list = [
|
||||
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
|
||||
for i, (proj_x_m, t
|
||||
) in enumerate(zip(proj_x_m_list, target_list))
|
||||
]
|
||||
else:
|
||||
logit_m_list = [None for _ in target_list]
|
||||
|
||||
if not self.skip_nomask:
|
||||
nomask_indices = paddle.logical_and(~padding_mask, ~mask_indices)
|
||||
proj_x_u = self.final_proj(x[nomask_indices])
|
||||
if self.untie_final_proj:
|
||||
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
|
||||
else:
|
||||
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
|
||||
|
||||
logit_u_list = [
|
||||
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
|
||||
for i, (proj_x_u, t
|
||||
) in enumerate(zip(proj_x_u_list, target_list))
|
||||
]
|
||||
else:
|
||||
logit_u_list = [None for _ in target_list]
|
||||
|
||||
result = {
|
||||
"logit_m_list": logit_m_list,
|
||||
"logit_u_list": logit_u_list,
|
||||
"padding_mask": padding_mask,
|
||||
"features_pen": features_pen,
|
||||
}
|
||||
return result
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
source: paddle.Tensor,
|
||||
padding_mask: Optional[paddle.Tensor]=None,
|
||||
mask: bool=False,
|
||||
ret_conv: bool=False,
|
||||
output_layer: Optional[int]=None,
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
res = self.forward(
|
||||
source,
|
||||
padding_mask=padding_mask,
|
||||
mask=mask,
|
||||
features_only=True,
|
||||
output_layer=output_layer, )
|
||||
feature = res["features"] if ret_conv else res["x"]
|
||||
return feature, res["padding_mask"]
|
||||
|
||||
def get_logits(self, net_output, is_masked=True):
|
||||
if is_masked:
|
||||
logits_list = net_output["logit_m_list"]
|
||||
else:
|
||||
logits_list = net_output["logit_u_list"]
|
||||
logits_list = [
|
||||
paddle.cast(x, 'float32') for x in logits_list if x is not None
|
||||
]
|
||||
return logits_list
|
||||
|
||||
def get_targets(self, net_output, is_masked=True):
|
||||
logits_list = self.get_logits(net_output, is_masked)
|
||||
targets_list = [
|
||||
paddle.zeros_like(x, dtype='int64') for x in logits_list
|
||||
]
|
||||
return targets_list
|
||||
|
||||
def get_extra_losses(self, net_output):
|
||||
extra_losses = []
|
||||
names = []
|
||||
|
||||
if "features_pen" in net_output:
|
||||
extra_losses.append(net_output["features_pen"])
|
||||
names.append("features_pen")
|
||||
|
||||
return extra_losses, names
|
||||
|
||||
def remove_pretraining_modules(self):
|
||||
self.target_glu = None
|
||||
self.final_proj = None
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue