parent
4323526155
commit
3b6651ba7c
@ -0,0 +1,9 @@
|
|||||||
|
# LibriSpeech
|
||||||
|
|
||||||
|
## hubertASR
|
||||||
|
Fintuning on train-clean-100
|
||||||
|
train: Epoch 3, 1*V100-32G, batchsize: 4, accum_grad: 8
|
||||||
|
|
||||||
|
| Model | Params | Config | Augmentation| Test set | Decode method | WER |
|
||||||
|
| --- | --- | --- | --- | --- | --- | --- |
|
||||||
|
| hubertASR | 326.16M | conf/hubertASR.yaml | spec_aug | test-clean | greedy search | 0.05868 |
|
@ -0,0 +1,33 @@
|
|||||||
|
#! /usr/bin/env bash
|
||||||
|
|
||||||
|
if [ $# != 3 ]; then
|
||||||
|
echo "usage: ${0} [best|latest] ckpt_dir avg_num"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
avg_mode=${1} # best,latest
|
||||||
|
ckpt_dir=${2}
|
||||||
|
average_num=${3}
|
||||||
|
decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
|
||||||
|
|
||||||
|
if [ $avg_mode == best ];then
|
||||||
|
# best
|
||||||
|
python avg_model.py \
|
||||||
|
--dst_model ${decode_checkpoint} \
|
||||||
|
--ckpt_dir ${ckpt_dir} \
|
||||||
|
--num ${average_num} \
|
||||||
|
--val_best
|
||||||
|
else
|
||||||
|
# latest
|
||||||
|
python avg_model.py \
|
||||||
|
--dst_model ${decode_checkpoint} \
|
||||||
|
--ckpt_dir ${ckpt_dir} \
|
||||||
|
--num ${average_num}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed in avg ckpt!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 0
|
@ -0,0 +1,18 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from paddlespeech.dataset.s2t import avg_ckpts_main
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
avg_ckpts_main()
|
@ -0,0 +1,3 @@
|
|||||||
|
process:
|
||||||
|
# use raw audio
|
||||||
|
- type: wav_process
|
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"do_normalize": true,
|
||||||
|
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
||||||
|
"feature_size": 1,
|
||||||
|
"padding_side": "right",
|
||||||
|
"padding_value": 0,
|
||||||
|
"return_attention_mask": true,
|
||||||
|
"sampling_rate": 16000
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
decode_batch_size: 1
|
||||||
|
error_rate_type: wer
|
||||||
|
decoding_method: "ctc_greedy_search" # 'ctc_greedy_search', 'ctc_prefix_beam_search'
|
||||||
|
beam_size: 10
|
@ -0,0 +1,137 @@
|
|||||||
|
############################################
|
||||||
|
# Network Architecture #
|
||||||
|
############################################
|
||||||
|
freeze_wavlm: False
|
||||||
|
normalize_wav: True
|
||||||
|
output_norm: True
|
||||||
|
init_type: kaiming_uniform # !Warning: need to convergence
|
||||||
|
enc:
|
||||||
|
input_shape: 768
|
||||||
|
dnn_blocks: 2
|
||||||
|
dnn_neurons: 768
|
||||||
|
activation: True
|
||||||
|
normalization: True
|
||||||
|
dropout_rate: [0.15, 0]
|
||||||
|
ctc:
|
||||||
|
enc_n_units: 768
|
||||||
|
blank_id: 0
|
||||||
|
dropout_rate: 0.0
|
||||||
|
wavlm_params_path: "/home/ubuntu/Documents/Github/wavlm_paddle/wavlm-paddle-ft.pth"
|
||||||
|
|
||||||
|
|
||||||
|
task_cfg:
|
||||||
|
label_rate: 50.0
|
||||||
|
sample_rate: 16000
|
||||||
|
normalize: True
|
||||||
|
enable_padding: False
|
||||||
|
max_keep_size: None
|
||||||
|
max_sample_size: 250000
|
||||||
|
min_sample_size: 32000
|
||||||
|
dropout_input: 0.1
|
||||||
|
final_dropout: 0.0
|
||||||
|
dropout: 0.1
|
||||||
|
attention_dropout: 0.0
|
||||||
|
activation_dropout: 0.1
|
||||||
|
apply_mask: True
|
||||||
|
mask_length: 10
|
||||||
|
mask_prob: 0.5
|
||||||
|
mask_selection: static
|
||||||
|
mask_other: 0.0
|
||||||
|
no_mask_overlap: False
|
||||||
|
mask_channel_length: 10
|
||||||
|
mask_channel_prob: 0.0
|
||||||
|
mask_channel_selection: static
|
||||||
|
mask_channel_other: 0.0
|
||||||
|
no_mask_channel_overlap: False
|
||||||
|
feature_grad_mult: 0.0
|
||||||
|
layerdrop: 0.1
|
||||||
|
fp16: True
|
||||||
|
extractor_mode: layer_norm
|
||||||
|
encoder_layers: 12
|
||||||
|
encoder_embed_dim: 768
|
||||||
|
encoder_ffn_embed_dim: 3072
|
||||||
|
encoder_attention_heads: 12
|
||||||
|
activation_fn: gelu
|
||||||
|
encoder_layerdrop: 0.0
|
||||||
|
dropout_features: 0.0
|
||||||
|
final_dim: 768
|
||||||
|
untie_final_proj: True
|
||||||
|
layer_norm_first: True
|
||||||
|
conv_feature_layers: "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
|
||||||
|
conv_bias: False
|
||||||
|
logit_temp: 0.1
|
||||||
|
target_glu: False
|
||||||
|
mask_min_space: 1
|
||||||
|
mask_channel_min_space: 1
|
||||||
|
conv_pos: 128
|
||||||
|
conv_pos_groups: 16
|
||||||
|
latent_temp: [2.0, 0.5, 0.999995]
|
||||||
|
skip_masked: False
|
||||||
|
skip_nomask: True
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Data #
|
||||||
|
###########################################
|
||||||
|
train_manifest: data/manifest.train
|
||||||
|
dev_manifest: data/manifest.dev
|
||||||
|
test_manifest: data/manifest.test-clean
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Dataloader #
|
||||||
|
###########################################
|
||||||
|
vocab_filepath: data/lang_char/vocab.txt
|
||||||
|
unit_type: char
|
||||||
|
mean_std_filepath: ""
|
||||||
|
preprocess_config: conf/preprocess.yaml
|
||||||
|
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for other epochs
|
||||||
|
batch_size: 8 # Different batch_size may cause large differences in results
|
||||||
|
maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
|
||||||
|
maxlen_out: 160000
|
||||||
|
minibatches: 0 # for debug
|
||||||
|
batch_count: auto
|
||||||
|
batch_bins: 0
|
||||||
|
batch_frames_in: 0
|
||||||
|
batch_frames_out: 0
|
||||||
|
batch_frames_inout: 0
|
||||||
|
num_workers: 0
|
||||||
|
subsampling_factor: 1
|
||||||
|
num_encs: 1
|
||||||
|
dist_sampler: True
|
||||||
|
shortest_first: False
|
||||||
|
return_lens_rate: True
|
||||||
|
|
||||||
|
############################################
|
||||||
|
# Data Augmentation #
|
||||||
|
############################################
|
||||||
|
audio_augment: # for raw audio
|
||||||
|
sample_rate: 16000
|
||||||
|
speeds: [90, 100, 110]
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Training #
|
||||||
|
###########################################
|
||||||
|
n_epoch: 10
|
||||||
|
accum_grad: 8
|
||||||
|
global_grad_clip: 5.0
|
||||||
|
model_scheduler: newbobscheduler
|
||||||
|
model_scheduler_conf:
|
||||||
|
improvement_threshold: 0.0025
|
||||||
|
annealing_factor: 0.8
|
||||||
|
patient: 0
|
||||||
|
model_optim: adam
|
||||||
|
model_optim_conf:
|
||||||
|
lr: 0.0001
|
||||||
|
weight_decay: 0.0
|
||||||
|
# I changed this
|
||||||
|
wavlm_optim: adam
|
||||||
|
wavlm_optim_conf:
|
||||||
|
lr: 0.00005
|
||||||
|
weight_decay: 0.0
|
||||||
|
wavlm_scheduler: constantlr
|
||||||
|
wavlm_scheduler_conf:
|
||||||
|
warmup_steps: 1000
|
||||||
|
lr_decay: 1.0
|
||||||
|
log_interval: 1
|
||||||
|
checkpoint:
|
||||||
|
kbest_n: 50
|
||||||
|
latest_n: 5
|
@ -0,0 +1,143 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
format ref/hyp file for `utt text` format to compute CER/WER/MER.
|
||||||
|
|
||||||
|
norm:
|
||||||
|
BAC009S0764W0196 明确了发展目标和重点任务
|
||||||
|
BAC009S0764W0186 实现我国房地产市场的平稳运行
|
||||||
|
|
||||||
|
|
||||||
|
sclite:
|
||||||
|
加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav)
|
||||||
|
河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav)
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
|
||||||
|
from paddlespeech.utils.argparse import print_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def transform_hyp(origin, trans, trans_sclite):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
origin: The input json file which contains the model output
|
||||||
|
trans: The output file for caculate CER/WER
|
||||||
|
trans_sclite: The output file for caculate CER/WER using sclite
|
||||||
|
"""
|
||||||
|
input_dict = {}
|
||||||
|
|
||||||
|
with open(origin, "r+", encoding="utf8") as f:
|
||||||
|
for item in jsonlines.Reader(f):
|
||||||
|
input_dict[item["utt"]] = item["hyps"][0]
|
||||||
|
|
||||||
|
if trans:
|
||||||
|
with open(trans, "w+", encoding="utf8") as f:
|
||||||
|
for key in input_dict.keys():
|
||||||
|
f.write(key + " " + input_dict[key] + "\n")
|
||||||
|
print(f"transform_hyp output: {trans}")
|
||||||
|
|
||||||
|
if trans_sclite:
|
||||||
|
with open(trans_sclite, "w+") as f:
|
||||||
|
for key in input_dict.keys():
|
||||||
|
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
|
||||||
|
f.write(line)
|
||||||
|
print(f"transform_hyp output: {trans_sclite}")
|
||||||
|
|
||||||
|
|
||||||
|
def transform_ref(origin, trans, trans_sclite):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
origin: The input json file which contains the model output
|
||||||
|
trans: The output file for caculate CER/WER
|
||||||
|
trans_sclite: The output file for caculate CER/WER using sclite
|
||||||
|
"""
|
||||||
|
input_dict = {}
|
||||||
|
|
||||||
|
with open(origin, "r", encoding="utf8") as f:
|
||||||
|
for item in jsonlines.Reader(f):
|
||||||
|
input_dict[item["utt"]] = item["text"]
|
||||||
|
|
||||||
|
if trans:
|
||||||
|
with open(trans, "w", encoding="utf8") as f:
|
||||||
|
for key in input_dict.keys():
|
||||||
|
f.write(key + " " + input_dict[key] + "\n")
|
||||||
|
print(f"transform_hyp output: {trans}")
|
||||||
|
|
||||||
|
if trans_sclite:
|
||||||
|
with open(trans_sclite, "w") as f:
|
||||||
|
for key in input_dict.keys():
|
||||||
|
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
|
||||||
|
f.write(line)
|
||||||
|
print(f"transform_hyp output: {trans_sclite}")
|
||||||
|
|
||||||
|
|
||||||
|
def define_argparse():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog='format ref/hyp file for compute CER/WER', add_help=True)
|
||||||
|
parser.add_argument(
|
||||||
|
'--origin_hyp', type=str, default="", help='origin hyp file')
|
||||||
|
parser.add_argument(
|
||||||
|
'--trans_hyp',
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help='hyp file for caculating CER/WER')
|
||||||
|
parser.add_argument(
|
||||||
|
'--trans_hyp_sclite',
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help='hyp file for caculating CER/WER by sclite')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--origin_ref', type=str, default="", help='origin ref file')
|
||||||
|
parser.add_argument(
|
||||||
|
'--trans_ref',
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help='ref file for caculating CER/WER')
|
||||||
|
parser.add_argument(
|
||||||
|
'--trans_ref_sclite',
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help='ref file for caculating CER/WER by sclite')
|
||||||
|
parser_args = parser.parse_args()
|
||||||
|
return parser_args
|
||||||
|
|
||||||
|
|
||||||
|
def format_result(origin_hyp="",
|
||||||
|
trans_hyp="",
|
||||||
|
trans_hyp_sclite="",
|
||||||
|
origin_ref="",
|
||||||
|
trans_ref="",
|
||||||
|
trans_ref_sclite=""):
|
||||||
|
|
||||||
|
if origin_hyp:
|
||||||
|
transform_hyp(
|
||||||
|
origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite)
|
||||||
|
|
||||||
|
if origin_ref:
|
||||||
|
transform_ref(
|
||||||
|
origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = define_argparse()
|
||||||
|
print_arguments(args, globals())
|
||||||
|
|
||||||
|
format_result(**vars(args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,110 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
unit_type=char
|
||||||
|
dict_dir=data/lang_char
|
||||||
|
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh
|
||||||
|
|
||||||
|
mkdir -p data
|
||||||
|
mkdir -p ${dict_dir}
|
||||||
|
TARGET_DIR=${MAIN_ROOT}/dataset
|
||||||
|
mkdir -p ${TARGET_DIR}
|
||||||
|
|
||||||
|
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||||
|
# download data, generate manifests
|
||||||
|
python3 ${TARGET_DIR}/librispeech/librispeech.py \
|
||||||
|
--manifest_prefix="data/manifest" \
|
||||||
|
--target_dir="${TARGET_DIR}/librispeech" \
|
||||||
|
--full_download="False"
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Prepare LibriSpeech failed. Terminated."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
for set in train-clean-100 dev-clean test-clean; do
|
||||||
|
mv data/manifest.${set} data/manifest.${set}.raw
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
|
||||||
|
for set in train-clean-100; do
|
||||||
|
cat data/manifest.${set}.raw >> data/manifest.train.raw
|
||||||
|
done
|
||||||
|
|
||||||
|
for set in dev-clean; do
|
||||||
|
cat data/manifest.${set}.raw >> data/manifest.dev.raw
|
||||||
|
done
|
||||||
|
|
||||||
|
for set in test-clean; do
|
||||||
|
cat data/manifest.${set}.raw >> data/manifest.test.raw
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# compute mean and stddev for normalizer
|
||||||
|
num_workers=$(nproc)
|
||||||
|
python ${MAIN_ROOT}/utils/compute_mean_std.py \
|
||||||
|
--manifest_path="data/manifest.train.raw" \
|
||||||
|
--num_samples=2000 \
|
||||||
|
--spectrum_type="fbank" \
|
||||||
|
--feat_dim=161 \
|
||||||
|
--delta_delta=false \
|
||||||
|
--sample_rate=16000 \
|
||||||
|
--stride_ms=10 \
|
||||||
|
--window_ms=25 \
|
||||||
|
--use_dB_normalization=False \
|
||||||
|
--num_workers=${num_workers} \
|
||||||
|
--output_path="data/mean_std.json"
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Compute mean and stddev failed. Terminated."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# build vocabulary
|
||||||
|
python3 ${MAIN_ROOT}/utils/build_vocab.py \
|
||||||
|
--unit_type ${unit_type} \
|
||||||
|
--count_threshold=0 \
|
||||||
|
--vocab_path="${dict_dir}/vocab.txt" \
|
||||||
|
--manifest_paths="data/manifest.train.raw"
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Build vocabulary failed. Terminated."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# format manifest with tokenids, vocab size
|
||||||
|
for set in train dev test dev-clean test-clean; do
|
||||||
|
{
|
||||||
|
python3 ${MAIN_ROOT}/utils/format_data.py \
|
||||||
|
--cmvn_path "data/mean_std.json" \
|
||||||
|
--unit_type ${unit_type} \
|
||||||
|
--vocab_path="${dict_dir}/vocab.txt" \
|
||||||
|
--manifest_path="data/manifest.${set}.raw" \
|
||||||
|
--output_path="data/manifest.${set}"
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Formt manifest.${set} failed. Terminated."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}&
|
||||||
|
done
|
||||||
|
wait
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "LibriSpeech Data preparation done."
|
||||||
|
|
||||||
|
# if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# mkdir -p exp/hubert
|
||||||
|
# echo "Pretrained hubert model download"
|
||||||
|
# wget -P exp/hubert https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams
|
||||||
|
# fi
|
||||||
|
|
||||||
|
exit 0
|
@ -0,0 +1,83 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
|
echo "using $ngpu gpus..."
|
||||||
|
|
||||||
|
expdir=exp
|
||||||
|
datadir=data
|
||||||
|
|
||||||
|
recog_set="test-clean test-other dev-clean dev-other"
|
||||||
|
recog_set="test-clean"
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
decode_config_path=$2
|
||||||
|
ckpt_prefix=$3
|
||||||
|
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||||
|
|
||||||
|
# download language model
|
||||||
|
#bash local/download_lm_en.sh
|
||||||
|
#if [ $? -ne 0 ]; then
|
||||||
|
# exit 1
|
||||||
|
#fi
|
||||||
|
|
||||||
|
python3 format_rsl.py \
|
||||||
|
--origin_ref data/manifest.test-clean.raw \
|
||||||
|
--trans_ref data/manifest.test-clean.text
|
||||||
|
|
||||||
|
|
||||||
|
for type in ctc_greedy_search; do
|
||||||
|
echo "decoding ${type}"
|
||||||
|
batch_size=16
|
||||||
|
python3 -u ${BIN_DIR}/test.py \
|
||||||
|
--ngpu ${ngpu} \
|
||||||
|
--config ${config_path} \
|
||||||
|
--decode_cfg ${decode_config_path} \
|
||||||
|
--result_file ${ckpt_prefix}.${type}.rsl \
|
||||||
|
--checkpoint_path ${ckpt_prefix} \
|
||||||
|
--opts decode.decoding_method ${type} \
|
||||||
|
--opts decode.decode_batch_size ${batch_size}
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed in evaluation!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
python3 format_rsl.py \
|
||||||
|
--origin_hyp ${ckpt_prefix}.${type}.rsl \
|
||||||
|
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
|
||||||
|
|
||||||
|
python3 compute_wer.py --char=1 --v=1 \
|
||||||
|
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
|
||||||
|
echo "decoding ${type} done."
|
||||||
|
done
|
||||||
|
|
||||||
|
for type in ctc_prefix_beam_search; do
|
||||||
|
echo "decoding ${type}"
|
||||||
|
batch_size=1
|
||||||
|
python3 -u ${BIN_DIR}/test.py \
|
||||||
|
--ngpu ${ngpu} \
|
||||||
|
--config ${config_path} \
|
||||||
|
--decode_cfg ${decode_config_path} \
|
||||||
|
--result_file ${ckpt_prefix}.${type}.rsl \
|
||||||
|
--checkpoint_path ${ckpt_prefix} \
|
||||||
|
--opts decode.decoding_method ${type} \
|
||||||
|
--opts decode.decode_batch_size ${batch_size}
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed in evaluation!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
python3 format_rsl.py \
|
||||||
|
--origin_hyp ${ckpt_prefix}.${type}.rsl \
|
||||||
|
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
|
||||||
|
|
||||||
|
python3 compute_wer.py --char=1 --v=1 \
|
||||||
|
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
|
||||||
|
echo "decoding ${type} done."
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Finished"
|
||||||
|
|
||||||
|
exit 0
|
@ -0,0 +1,58 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if [ $# != 4 ];then
|
||||||
|
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
|
echo "using $ngpu gpus..."
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
decode_config_path=$2
|
||||||
|
ckpt_prefix=$3
|
||||||
|
audio_file=$4
|
||||||
|
|
||||||
|
mkdir -p data
|
||||||
|
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f ${audio_file} ]; then
|
||||||
|
echo "Plase input the right audio_file path"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
chunk_mode=false
|
||||||
|
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
|
||||||
|
chunk_mode=true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# download language model
|
||||||
|
#bash local/download_lm_ch.sh
|
||||||
|
#if [ $? -ne 0 ]; then
|
||||||
|
# exit 1
|
||||||
|
#fi
|
||||||
|
|
||||||
|
for type in ctc_greedy_search; do
|
||||||
|
echo "decoding ${type}"
|
||||||
|
batch_size=1
|
||||||
|
output_dir=${ckpt_prefix}
|
||||||
|
mkdir -p ${output_dir}
|
||||||
|
python3 -u ${BIN_DIR}/test_wav.py \
|
||||||
|
--ngpu ${ngpu} \
|
||||||
|
--config ${config_path} \
|
||||||
|
--decode_cfg ${decode_config_path} \
|
||||||
|
--result_file ${output_dir}/${type}.rsl \
|
||||||
|
--checkpoint_path ${ckpt_prefix} \
|
||||||
|
--opts decode.decoding_method ${type} \
|
||||||
|
--opts decode.decode_batch_size ${batch_size} \
|
||||||
|
--audio_file ${audio_file}
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed in evaluation!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
exit 0
|
@ -0,0 +1,58 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if [ $# -lt 2 ] && [ $# -gt 3 ];then
|
||||||
|
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
|
echo "using $ngpu gpus..."
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
ckpt_name=$2
|
||||||
|
resume=$3
|
||||||
|
ips=$4
|
||||||
|
|
||||||
|
if [ ! $ips ];then
|
||||||
|
ips_config=
|
||||||
|
else
|
||||||
|
ips_config="--ips="${ips}
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir -p exp
|
||||||
|
|
||||||
|
# seed may break model convergence
|
||||||
|
seed=1988
|
||||||
|
if [ ${seed} != 0 ]; then
|
||||||
|
export FLAGS_cudnn_deterministic=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
# export FLAGS_cudnn_exhaustive_search=true
|
||||||
|
# export FLAGS_conv_workspace_size_limit=4000
|
||||||
|
export FLAGS_allocator_strategy=naive_best_fit
|
||||||
|
if [ ${ngpu} == 0 ]; then
|
||||||
|
python3 -u ${BIN_DIR}/train.py \
|
||||||
|
--ngpu ${ngpu} \
|
||||||
|
--config ${config_path} \
|
||||||
|
--output exp/${ckpt_name} \
|
||||||
|
--seed ${seed} \
|
||||||
|
--resume ${resume}
|
||||||
|
else
|
||||||
|
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
|
||||||
|
--ngpu ${ngpu} \
|
||||||
|
--config ${config_path} \
|
||||||
|
--output exp/${ckpt_name} \
|
||||||
|
--seed ${seed} \
|
||||||
|
--resume ${resume}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${seed} != 0 ]; then
|
||||||
|
unset FLAGS_cudnn_deterministic
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed in training!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 0
|
@ -0,0 +1,13 @@
|
|||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
# export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
|
||||||
|
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/wavlm/bin
|
@ -0,0 +1,48 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
. ./path.sh || exit 1;
|
||||||
|
. ./cmd.sh || exit 1;
|
||||||
|
|
||||||
|
gpus=1,2,3
|
||||||
|
stage=0
|
||||||
|
stop_stage=3
|
||||||
|
conf_path=conf/wavlmASR.yaml
|
||||||
|
ips= #xx.xx.xx.xx,xx.xx.xx.xx
|
||||||
|
decode_conf_path=conf/tuning/decode.yaml
|
||||||
|
avg_num=3
|
||||||
|
resume= # xx e.g. 30
|
||||||
|
|
||||||
|
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||||
|
|
||||||
|
audio_file=data/demo_002_en.wav
|
||||||
|
|
||||||
|
# avg_ckpt=avg_${avg_num}
|
||||||
|
avg_ckpt=4
|
||||||
|
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
|
||||||
|
echo "checkpoint name ${ckpt}"
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
bash ./local/data.sh || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train model, all `ckpt` under `exp` dir
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
|
||||||
|
fi
|
||||||
|
|
||||||
|
# if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# # avg n best model
|
||||||
|
# ./avg.sh best exp/${ckpt}/checkpoints ${avg_num}
|
||||||
|
# fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# greedy search decoder
|
||||||
|
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||||
|
# test a single .wav file
|
||||||
|
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
|
||||||
|
fi
|
Binary file not shown.
@ -0,0 +1 @@
|
|||||||
|
../../../utils
|
@ -0,0 +1,108 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Evaluation for WavLM model."""
|
||||||
|
import cProfile
|
||||||
|
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.s2t.exps.wavlm.model import WavLMASRTester as Tester
|
||||||
|
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||||
|
# from paddlespeech.utils.argparse import print_arguments
|
||||||
|
import distutils.util
|
||||||
|
|
||||||
|
def add_arguments(argname, type, default, help, argparser, **kwargs):
|
||||||
|
"""Add argparse's argument.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
add_argument("name", str, "Jonh", "User name.", parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
"""
|
||||||
|
type = distutils.util.strtobool if type == bool else type
|
||||||
|
argparser.add_argument(
|
||||||
|
"--" + argname,
|
||||||
|
default=default,
|
||||||
|
type=type,
|
||||||
|
help=help + ' Default: %(default)s.',
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def print_arguments(args, info=None):
|
||||||
|
"""Print argparse's arguments.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("name", default="Jonh", type=str, help="User name.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
print_arguments(args)
|
||||||
|
|
||||||
|
:param args: Input argparse.Namespace for printing.
|
||||||
|
:type args: argparse.Namespace
|
||||||
|
"""
|
||||||
|
filename = ""
|
||||||
|
if info:
|
||||||
|
filename = info["__file__"]
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
print(f"----------- {filename} Configuration Arguments -----------")
|
||||||
|
for arg, value in sorted(vars(args).items()):
|
||||||
|
print("%s: %s" % (arg, value))
|
||||||
|
print("-----------------------------------------------------------")
|
||||||
|
|
||||||
|
|
||||||
|
def main_sp(config, args):
|
||||||
|
exp = Tester(config, args)
|
||||||
|
with exp.eval():
|
||||||
|
exp.setup()
|
||||||
|
exp.run_test()
|
||||||
|
|
||||||
|
|
||||||
|
def main(config, args):
|
||||||
|
main_sp(config, args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = default_argument_parser()
|
||||||
|
# save asr result to
|
||||||
|
parser.add_argument(
|
||||||
|
'--dict-path', type=str, default=None, help='dict path.')
|
||||||
|
parser.add_argument(
|
||||||
|
"--result_file", type=str, help="path of save the asr result")
|
||||||
|
args = parser.parse_args()
|
||||||
|
print_arguments(args, globals())
|
||||||
|
|
||||||
|
# https://yaml.org/type/float.html
|
||||||
|
config = CfgNode(new_allowed=True)
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
if args.decode_cfg:
|
||||||
|
decode_confs = CfgNode(new_allowed=True)
|
||||||
|
decode_confs.merge_from_file(args.decode_cfg)
|
||||||
|
config.decode = decode_confs
|
||||||
|
if args.opts:
|
||||||
|
config.merge_from_list(args.opts)
|
||||||
|
config.freeze()
|
||||||
|
print(config)
|
||||||
|
if args.dump_config:
|
||||||
|
with open(args.dump_config, 'w') as f:
|
||||||
|
print(config, file=f)
|
||||||
|
|
||||||
|
# Setting for profiling
|
||||||
|
pr = cProfile.Profile()
|
||||||
|
pr.runcall(main, config, args)
|
||||||
|
pr.dump_stats('test.profile')
|
@ -0,0 +1,125 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Evaluation for wavlm model."""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import soundfile
|
||||||
|
from paddlenlp.transformers import AutoTokenizer
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||||
|
from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
|
||||||
|
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||||
|
from paddlespeech.s2t.utils.log import Log
|
||||||
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2vec2Infer():
|
||||||
|
def __init__(self, config, args):
|
||||||
|
self.args = args
|
||||||
|
self.config = config
|
||||||
|
self.audio_file = args.audio_file
|
||||||
|
self.tokenizer = config.get("tokenizer", None)
|
||||||
|
|
||||||
|
if self.tokenizer:
|
||||||
|
self.text_feature = AutoTokenizer.from_pretrained(
|
||||||
|
self.config.tokenizer)
|
||||||
|
else:
|
||||||
|
self.text_feature = TextFeaturizer(
|
||||||
|
unit_type=config.unit_type, vocab=config.vocab_filepath)
|
||||||
|
|
||||||
|
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
|
||||||
|
|
||||||
|
# model
|
||||||
|
model_conf = config
|
||||||
|
with UpdateConfig(model_conf):
|
||||||
|
model_conf.output_dim = self.text_feature.vocab_size
|
||||||
|
model = WavLMASR.from_config(model_conf)
|
||||||
|
self.model = model
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# load model
|
||||||
|
params_path = self.args.checkpoint_path + ".pdparams"
|
||||||
|
model_dict = paddle.load(params_path)
|
||||||
|
self.model.set_state_dict(model_dict)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
check(args.audio_file)
|
||||||
|
|
||||||
|
with paddle.no_grad():
|
||||||
|
# read
|
||||||
|
audio, _ = soundfile.read(
|
||||||
|
self.audio_file, dtype="int16", always_2d=True)
|
||||||
|
logger.info(f"audio shape: {audio.shape}")
|
||||||
|
xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
|
||||||
|
decode_config = self.config.decode
|
||||||
|
result_transcripts, result_tokenids = self.model.decode(
|
||||||
|
xs,
|
||||||
|
text_feature=self.text_feature,
|
||||||
|
decoding_method=decode_config.decoding_method,
|
||||||
|
beam_size=decode_config.beam_size,
|
||||||
|
tokenizer=self.tokenizer, )
|
||||||
|
rsl = result_transcripts[0]
|
||||||
|
utt = Path(self.audio_file).name
|
||||||
|
logger.info(f"hyp: {utt} {rsl}")
|
||||||
|
return rsl
|
||||||
|
|
||||||
|
|
||||||
|
def check(audio_file):
|
||||||
|
if not os.path.isfile(audio_file):
|
||||||
|
print("Please input the right audio file path")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
logger.info("checking the audio file format......")
|
||||||
|
try:
|
||||||
|
sig, sample_rate = soundfile.read(audio_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
logger.error(
|
||||||
|
"can not open the wav file, please check the audio file format")
|
||||||
|
sys.exit(-1)
|
||||||
|
logger.info("The sample rate is %d" % sample_rate)
|
||||||
|
assert (sample_rate == 16000)
|
||||||
|
logger.info("The audio file format is right")
|
||||||
|
|
||||||
|
|
||||||
|
def main(config, args):
|
||||||
|
Wav2vec2Infer(config, args).run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = default_argument_parser()
|
||||||
|
# save asr result to
|
||||||
|
parser.add_argument(
|
||||||
|
"--result_file", type=str, help="path of save the asr result")
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio_file", type=str, help="path of the input audio file")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = CfgNode(new_allowed=True)
|
||||||
|
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
if args.decode_cfg:
|
||||||
|
decode_confs = CfgNode(new_allowed=True)
|
||||||
|
decode_confs.merge_from_file(args.decode_cfg)
|
||||||
|
config.decode = decode_confs
|
||||||
|
if args.opts:
|
||||||
|
config.merge_from_list(args.opts)
|
||||||
|
config.freeze()
|
||||||
|
main(config, args)
|
@ -0,0 +1,101 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Trainer for wavlm model."""
|
||||||
|
import cProfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.s2t.exps.wavlm.model import WavLMASRTrainer as Trainer
|
||||||
|
from paddlespeech.s2t.training.cli import default_argument_parser
|
||||||
|
# from paddlespeech.utils.argparse import print_arguments
|
||||||
|
|
||||||
|
import distutils.util
|
||||||
|
|
||||||
|
def add_arguments(argname, type, default, help, argparser, **kwargs):
|
||||||
|
"""Add argparse's argument.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
add_argument("name", str, "Jonh", "User name.", parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
"""
|
||||||
|
type = distutils.util.strtobool if type == bool else type
|
||||||
|
argparser.add_argument(
|
||||||
|
"--" + argname,
|
||||||
|
default=default,
|
||||||
|
type=type,
|
||||||
|
help=help + ' Default: %(default)s.',
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def print_arguments(args, info=None):
|
||||||
|
"""Print argparse's arguments.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("name", default="Jonh", type=str, help="User name.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
print_arguments(args)
|
||||||
|
|
||||||
|
:param args: Input argparse.Namespace for printing.
|
||||||
|
:type args: argparse.Namespace
|
||||||
|
"""
|
||||||
|
filename = ""
|
||||||
|
if info:
|
||||||
|
filename = info["__file__"]
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
print(f"----------- {filename} Configuration Arguments -----------")
|
||||||
|
for arg, value in sorted(vars(args).items()):
|
||||||
|
print("%s: %s" % (arg, value))
|
||||||
|
print("-----------------------------------------------------------")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main_sp(config, args):
|
||||||
|
exp = Trainer(config, args)
|
||||||
|
exp.setup()
|
||||||
|
exp.run()
|
||||||
|
|
||||||
|
|
||||||
|
def main(config, args):
|
||||||
|
main_sp(config, args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = default_argument_parser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--resume', type=str, default="", nargs="?", help='resume ckpt path.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
print_arguments(args, globals())
|
||||||
|
# https://yaml.org/type/float.html
|
||||||
|
config = CfgNode(new_allowed=True)
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
if args.opts:
|
||||||
|
config.merge_from_list(args.opts)
|
||||||
|
config.freeze()
|
||||||
|
if args.dump_config:
|
||||||
|
with open(args.dump_config, 'w') as f:
|
||||||
|
print(config, file=f)
|
||||||
|
|
||||||
|
# Setting for profiling
|
||||||
|
pr = cProfile.Profile()
|
||||||
|
pr.runcall(main, config, args)
|
||||||
|
pr.dump_stats(os.path.join(args.output, 'train.profile'))
|
@ -0,0 +1,912 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Contains wavlm model."""
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddlenlp.transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
|
||||||
|
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
|
||||||
|
from paddlespeech.s2t.io.speechbrain import data_pipeline
|
||||||
|
from paddlespeech.s2t.io.speechbrain import dataio
|
||||||
|
from paddlespeech.s2t.io.speechbrain import dataset
|
||||||
|
from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
|
||||||
|
from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment
|
||||||
|
from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
|
||||||
|
from paddlespeech.s2t.training.optimizer import OptimizerFactory
|
||||||
|
from paddlespeech.s2t.training.reporter import ObsScope
|
||||||
|
from paddlespeech.s2t.training.reporter import report
|
||||||
|
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
|
||||||
|
from paddlespeech.s2t.training.timer import Timer
|
||||||
|
from paddlespeech.s2t.training.trainer import Trainer
|
||||||
|
from paddlespeech.s2t.utils import error_rate
|
||||||
|
from paddlespeech.s2t.utils import layer_tools
|
||||||
|
from paddlespeech.s2t.utils import mp_tools
|
||||||
|
from paddlespeech.s2t.utils.log import Log
|
||||||
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||||
|
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
def clip_grad_norm_(
|
||||||
|
parameters,
|
||||||
|
max_norm,
|
||||||
|
norm_type=2.0,
|
||||||
|
error_if_nonfinite=False, ):
|
||||||
|
r"""Clips gradient norm of the iteratable parameters.
|
||||||
|
|
||||||
|
Norms are calculated together on all gradients, just as they are
|
||||||
|
connected into one vector. The gradient will be modified in place.
|
||||||
|
|
||||||
|
This API can only run in dynamic graph mode, not static graph mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor
|
||||||
|
that will be normalized gradients
|
||||||
|
max_norm (float or int): max norm of the gradients
|
||||||
|
norm_type (float or int): type of the used p-norm. Can be `inf` for
|
||||||
|
infinity norm.
|
||||||
|
error_if_nonfinite (bool): if True, throw an error if the total
|
||||||
|
norm of the gradients from :attr:`parameters` is `nan`,
|
||||||
|
`inf`, or `-inf`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total norm of the parameter gradients (treated as a single vector).
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
|
||||||
|
max_norm = float(5.0)
|
||||||
|
linear = paddle.nn.Linear(in_features=10, out_features=10)
|
||||||
|
out = linear(x)
|
||||||
|
loss = paddle.mean(out)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
paddle.nn.utils.clip_grad_norm_(linear.parameters(), max_norm)
|
||||||
|
|
||||||
|
sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
|
||||||
|
sdg.step()
|
||||||
|
"""
|
||||||
|
if not paddle.in_dynamic_mode():
|
||||||
|
raise RuntimeError('this API can only run in dynamic mode.')
|
||||||
|
|
||||||
|
if isinstance(parameters, paddle.Tensor):
|
||||||
|
parameters = [parameters]
|
||||||
|
|
||||||
|
support_norm_type = [float("inf"), 0, 1, 2]
|
||||||
|
if norm_type not in support_norm_type:
|
||||||
|
raise ValueError(f'norm_type only support {support_norm_type}')
|
||||||
|
|
||||||
|
grads = [p.grad for p in parameters if p.grad is not None]
|
||||||
|
max_norm = float(max_norm)
|
||||||
|
norm_type = float(norm_type)
|
||||||
|
if len(grads) == 0:
|
||||||
|
return paddle.to_tensor(0.0)
|
||||||
|
if norm_type == float("inf"):
|
||||||
|
norms = [g.detach().abs().max() for g in grads]
|
||||||
|
total_norm = (norms[0]
|
||||||
|
if len(norms) == 1 else paddle.max(paddle.stack(norms)))
|
||||||
|
else:
|
||||||
|
total_norm = paddle.linalg.norm(
|
||||||
|
paddle.stack(
|
||||||
|
[paddle.linalg.norm(g.detach(), norm_type) for g in grads]),
|
||||||
|
norm_type, )
|
||||||
|
|
||||||
|
if error_if_nonfinite and paddle.logical_or(total_norm.isnan(),
|
||||||
|
total_norm.isinf()):
|
||||||
|
raise RuntimeError(
|
||||||
|
f'The total norm of {norm_type} order of the gradients from '
|
||||||
|
'`parameters` is non-finite, so it cannot be clipped. In any case, '
|
||||||
|
'disable this error and scale the gradient by non-finite norm, '
|
||||||
|
'set `error_if_nonfinite=False`')
|
||||||
|
clip_coef = max_norm / (total_norm + 1e-6)
|
||||||
|
# Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this
|
||||||
|
# avoids the `if clip_coef < 1:` condition.
|
||||||
|
clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
|
||||||
|
with paddle.no_grad():
|
||||||
|
for _, p in enumerate(parameters):
|
||||||
|
g = p.grad
|
||||||
|
if g is not None:
|
||||||
|
p.grad = paddle.multiply(x=g, y=clip_coef_clamped)
|
||||||
|
return total_norm
|
||||||
|
|
||||||
|
|
||||||
|
class WavLMASRTrainer(Trainer):
|
||||||
|
def __init__(self, config, args):
|
||||||
|
super().__init__(config, args)
|
||||||
|
self.avg_train_loss = 0.0
|
||||||
|
self.loss_isfinite = True # while flag is 'False', loss in Nan or inf, and can not be avg
|
||||||
|
self.use_sb = True # whether use speech brain dataloader
|
||||||
|
|
||||||
|
def update_average(self, batch_index, loss):
|
||||||
|
"""Update running average of the loss.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
batch_index : int
|
||||||
|
current batch index
|
||||||
|
loss : paddle.tensor
|
||||||
|
detached loss, a single float value.
|
||||||
|
"""
|
||||||
|
if math.isfinite(loss):
|
||||||
|
self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
|
||||||
|
self.avg_train_loss += loss / (batch_index + 1)
|
||||||
|
else:
|
||||||
|
self.loss_isfinite = False
|
||||||
|
logger.info('loss:{} in Nan or inf, error'.format(loss))
|
||||||
|
|
||||||
|
def before_train(self):
|
||||||
|
from_scratch = self.resume_or_scratch()
|
||||||
|
if from_scratch:
|
||||||
|
# scratch: save init model, i.e. 0 epoch
|
||||||
|
self.save(tag='init', infos=None)
|
||||||
|
else:
|
||||||
|
# resume: train next_epoch and next_iteration
|
||||||
|
self.epoch += 1
|
||||||
|
logger.info(
|
||||||
|
f"Resume train: epoch {self.epoch }, step {self.iteration}!")
|
||||||
|
|
||||||
|
self.maybe_batch_sampler_step()
|
||||||
|
|
||||||
|
def train_batch(self, batch_index, batch, msg):
|
||||||
|
train_conf = self.config
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
# forward
|
||||||
|
## sb data pipeline
|
||||||
|
if self.use_sb:
|
||||||
|
wav, wavs_lens_rate = batch['sig']
|
||||||
|
target, target_lens_rate = batch['tokens']
|
||||||
|
target_lens = (target_lens_rate *
|
||||||
|
target.shape[1]).round().astype(paddle.int64)
|
||||||
|
else:
|
||||||
|
utt, wav, wavs_lens, target, target_lens = batch
|
||||||
|
wavs_lens_rate = wavs_lens / wav.shape[1]
|
||||||
|
wav = wav[:, :, 0]
|
||||||
|
|
||||||
|
if hasattr(train_conf, 'audio_augment'):
|
||||||
|
wav = self.speech_augmentation(wav, wavs_lens_rate)
|
||||||
|
loss = self.model(wav, wavs_lens_rate, target, target_lens)
|
||||||
|
|
||||||
|
# loss div by `batch_size * accum_grad`
|
||||||
|
loss /= train_conf.accum_grad
|
||||||
|
# update self.avg_train_loss
|
||||||
|
self.update_average(batch_index, float(loss))
|
||||||
|
|
||||||
|
# loss backward
|
||||||
|
if (batch_index + 1) % train_conf.accum_grad != 0:
|
||||||
|
# Disable gradient synchronizations across DDP processes.
|
||||||
|
# Within this context, gradients will be accumulated on module
|
||||||
|
# variables, which will later be synchronized.
|
||||||
|
# When using cpu w/o DDP, model does not have `no_sync`
|
||||||
|
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
|
||||||
|
self.parallel) else nullcontext
|
||||||
|
else:
|
||||||
|
# Used for single gpu training and DDP gradient synchronization
|
||||||
|
# processes.
|
||||||
|
context = nullcontext
|
||||||
|
with context():
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
layer_tools.print_grads(self.model, print_func=None)
|
||||||
|
|
||||||
|
# NOTE: the code below asserted that the backward() is problematic, and as more steps are accumulated, the output from wavlm alone will be the same for all frames
|
||||||
|
# optimizer step old
|
||||||
|
if (batch_index + 1) % train_conf.accum_grad == 0:
|
||||||
|
#do global grad clip
|
||||||
|
if train_conf.global_grad_clip != 0:
|
||||||
|
clip_grad_norm_(self.model.parameters(),
|
||||||
|
train_conf.global_grad_clip)
|
||||||
|
self.model_optimizer.step()
|
||||||
|
self.model_optimizer.clear_grad()
|
||||||
|
if not train_conf.freeze_wavlm:
|
||||||
|
self.wavlm_optimizer.step()
|
||||||
|
self.wavlm_optimizer.clear_grad()
|
||||||
|
if self.config.model_scheduler != 'newbobscheduler':
|
||||||
|
self.model_lr_scheduler.step()
|
||||||
|
if self.config.wavlm_scheduler != 'newbobscheduler':
|
||||||
|
if not train_conf.freeze_wavlm:
|
||||||
|
self.wavlm_lr_scheduler.step()
|
||||||
|
self.iteration += 1
|
||||||
|
|
||||||
|
losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
|
||||||
|
iteration_time = time.time() - start
|
||||||
|
for k, v in losses_np.items():
|
||||||
|
report(k, v)
|
||||||
|
report("loss_whitoutavg", float(loss))
|
||||||
|
report("batch_size", self.config.batch_size)
|
||||||
|
report("accum", train_conf.accum_grad)
|
||||||
|
report("step_cost", iteration_time)
|
||||||
|
|
||||||
|
if (batch_index + 1) % train_conf.accum_grad == 0:
|
||||||
|
if dist.get_rank() == 0 and self.visualizer:
|
||||||
|
losses_np_v = losses_np.copy()
|
||||||
|
losses_np_v.update({
|
||||||
|
"model_lr": self.model_lr_scheduler(),
|
||||||
|
"wavlm_lr": self.wavlm_lr_scheduler()
|
||||||
|
})
|
||||||
|
for key, val in losses_np_v.items():
|
||||||
|
self.visualizer.add_scalar(
|
||||||
|
tag='train/' + key, value=val, step=self.iteration - 1)
|
||||||
|
|
||||||
|
@paddle.no_grad()
|
||||||
|
def valid(self):
|
||||||
|
self.model.eval()
|
||||||
|
if not self.use_streamdata:
|
||||||
|
logger.info(
|
||||||
|
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
|
||||||
|
valid_losses = {}
|
||||||
|
step = 0
|
||||||
|
total_loss = 0.0
|
||||||
|
num_seen_utts = 1 # use update_average and no need for num_seen_utts here
|
||||||
|
for i, batch in enumerate(self.valid_loader):
|
||||||
|
if self.use_sb:
|
||||||
|
wav, wavs_lens_rate = batch['sig']
|
||||||
|
target, target_lens_rate = batch['tokens']
|
||||||
|
target_lens = (target_lens_rate *
|
||||||
|
target.shape[1]).round().astype(paddle.int64)
|
||||||
|
else:
|
||||||
|
utt, wav, wavs_lens, target, target_lens = batch
|
||||||
|
wavs_lens_rate = wavs_lens / wav.shape[1]
|
||||||
|
wav = wav[:, :, 0]
|
||||||
|
|
||||||
|
loss = self.model(wav, wavs_lens_rate, target, target_lens)
|
||||||
|
# use update_average
|
||||||
|
total_loss -= total_loss / (step + 1)
|
||||||
|
total_loss += loss / (step + 1)
|
||||||
|
|
||||||
|
if math.isfinite(float(loss)):
|
||||||
|
step += 1
|
||||||
|
valid_losses['val_loss'] = float(loss)
|
||||||
|
else:
|
||||||
|
logger.info('loss:{} in Nan or inf, error'.format(float(loss)))
|
||||||
|
|
||||||
|
if (i + 1) % self.config.log_interval == 0:
|
||||||
|
valid_losses['val_history_loss'] = float(total_loss)
|
||||||
|
|
||||||
|
# logging
|
||||||
|
msg = f"Valid: Rank: {dist.get_rank()}, "
|
||||||
|
msg += "epoch: {}, ".format(self.epoch)
|
||||||
|
msg += "step: {}, ".format(self.iteration)
|
||||||
|
if not self.use_streamdata:
|
||||||
|
msg += "batch: {}/{}, ".format(i + 1,
|
||||||
|
len(self.valid_loader))
|
||||||
|
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||||
|
for k, v in valid_losses.items())
|
||||||
|
logger.info(msg)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
'Rank {} Val info val_loss {}'.format(dist.get_rank(), total_loss))
|
||||||
|
return total_loss, num_seen_utts
|
||||||
|
|
||||||
|
@mp_tools.rank_zero_only
|
||||||
|
def save(self, tag=None, infos: dict=None):
|
||||||
|
"""Save checkpoint (model parameters and optimizer states).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None.
|
||||||
|
infos (dict, optional): meta data to save. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
infos = infos if infos else dict()
|
||||||
|
infos.update({
|
||||||
|
"epoch": self.epoch,
|
||||||
|
"model_lr": self.model_optimizer.get_lr(),
|
||||||
|
"wavlm_lr": self.wavlm_optimizer.get_lr()
|
||||||
|
})
|
||||||
|
|
||||||
|
checkpoint_path = os.path.join(
|
||||||
|
self.checkpoint_dir,
|
||||||
|
"{}".format(self.iteration if tag is None else tag))
|
||||||
|
|
||||||
|
model_dict = self.model.state_dict()
|
||||||
|
params_path = checkpoint_path + ".pdparams"
|
||||||
|
paddle.save(model_dict, params_path)
|
||||||
|
logger.info("Saved model to {}".format(params_path))
|
||||||
|
|
||||||
|
model_opt_dict = self.model_optimizer.state_dict()
|
||||||
|
wavlm_opt_dict = self.wavlm_optimizer.state_dict()
|
||||||
|
|
||||||
|
opt_dict = {'model': model_opt_dict, 'wavlm': wavlm_opt_dict}
|
||||||
|
|
||||||
|
optimizer_path = checkpoint_path + ".pdopt"
|
||||||
|
paddle.save(opt_dict, optimizer_path)
|
||||||
|
logger.info("Saved optimzier state to {}".format(optimizer_path))
|
||||||
|
|
||||||
|
scheduler_dict = {}
|
||||||
|
|
||||||
|
if self.config.model_scheduler == 'newbobscheduler':
|
||||||
|
scheduler_dict['model'] = self.model_lr_scheduler.save()
|
||||||
|
if self.config.wavlm_scheduler == 'newbobscheduler':
|
||||||
|
scheduler_dict['wavlm'] = self.wavlm_lr_scheduler.save()
|
||||||
|
if scheduler_dict:
|
||||||
|
scheduler_path = checkpoint_path + ".pdlrs"
|
||||||
|
paddle.save(scheduler_dict, scheduler_path)
|
||||||
|
logger.info("Saved scheduler state to {}".format(scheduler_path))
|
||||||
|
info_path = re.sub('.pdparams$', '.json', params_path)
|
||||||
|
infos = {} if infos is None else infos
|
||||||
|
with open(info_path, 'w', encoding='utf8') as fout:
|
||||||
|
data = json.dumps(infos)
|
||||||
|
fout.write(data)
|
||||||
|
|
||||||
|
def resume_or_scratch(self):
|
||||||
|
"""Resume from latest checkpoint at checkpoints in the output
|
||||||
|
directory or load a specified checkpoint.
|
||||||
|
|
||||||
|
If ``args.checkpoint_path`` is not None, load the checkpoint, else
|
||||||
|
resume training.
|
||||||
|
"""
|
||||||
|
scratch = None
|
||||||
|
if self.args.resume:
|
||||||
|
# just restore ckpt
|
||||||
|
# lr will resotre from optimizer ckpt
|
||||||
|
resume_json_path = os.path.join(self.checkpoint_dir,
|
||||||
|
self.args.resume + '.json')
|
||||||
|
with open(resume_json_path, 'r', encoding='utf8') as f:
|
||||||
|
resume_json = json.load(f)
|
||||||
|
self.iteration = 0
|
||||||
|
self.epoch = resume_json["epoch"]
|
||||||
|
|
||||||
|
# resotre model from *.pdparams
|
||||||
|
params_path = os.path.join(self.checkpoint_dir,
|
||||||
|
"{}".format(self.epoch)) + '.pdparams'
|
||||||
|
model_dict = paddle.load(params_path)
|
||||||
|
self.model.set_state_dict(model_dict)
|
||||||
|
|
||||||
|
# resotre optimizer from *.pdopt
|
||||||
|
optimizer_path = os.path.join(self.checkpoint_dir,
|
||||||
|
"{}".format(self.epoch)) + '.pdopt'
|
||||||
|
optimizer_dict = paddle.load(optimizer_path)
|
||||||
|
self.model_optimizer.set_state_dict(optimizer_dict['model'])
|
||||||
|
self.wavlm_optimizer.set_state_dict(optimizer_dict['wavlm'])
|
||||||
|
|
||||||
|
# resotre lr_scheduler from *.pdlrs
|
||||||
|
scheduler_path = os.path.join(self.checkpoint_dir,
|
||||||
|
"{}".format(self.epoch)) + '.pdlrs'
|
||||||
|
if os.path.isfile(os.path.join(scheduler_path)):
|
||||||
|
scheduler_dict = paddle.load(scheduler_path)
|
||||||
|
if self.config.model_scheduler == 'newbobscheduler':
|
||||||
|
self.model_lr_scheduler.load(scheduler_dict['model'])
|
||||||
|
if self.config.wavlm_scheduler == 'newbobscheduler':
|
||||||
|
self.wavlm_lr_scheduler.load(scheduler_dict['wavlm'])
|
||||||
|
logger.info(
|
||||||
|
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
|
||||||
|
scratch = False
|
||||||
|
else:
|
||||||
|
self.iteration = 0
|
||||||
|
self.epoch = 0
|
||||||
|
scratch = True
|
||||||
|
logger.info("Init from scratch!")
|
||||||
|
return scratch
|
||||||
|
|
||||||
|
def do_train(self):
|
||||||
|
"""The training process control by step."""
|
||||||
|
# !!!IMPORTANT!!!
|
||||||
|
# Try to export the model by script, if fails, we should refine
|
||||||
|
# the code to satisfy the script export requirements
|
||||||
|
# script_model = paddle.jit.to_static(self.model)
|
||||||
|
# script_model_path = str(self.checkpoint_dir / 'init')
|
||||||
|
# paddle.jit.save(script_model, script_model_path)
|
||||||
|
|
||||||
|
self.before_train()
|
||||||
|
if not self.use_streamdata:
|
||||||
|
logger.info(
|
||||||
|
f"Train Total Examples: {len(self.train_loader.dataset)}")
|
||||||
|
while self.epoch < self.config.n_epoch:
|
||||||
|
with Timer("Epoch-Train Time Cost: {}"):
|
||||||
|
self.model.train()
|
||||||
|
try:
|
||||||
|
data_start_time = time.time()
|
||||||
|
for batch_index, batch in enumerate(self.train_loader):
|
||||||
|
dataload_time = time.time() - data_start_time
|
||||||
|
msg = "Train:"
|
||||||
|
observation = OrderedDict()
|
||||||
|
with ObsScope(observation):
|
||||||
|
report("Rank", dist.get_rank())
|
||||||
|
report("epoch", self.epoch)
|
||||||
|
report('step', self.iteration)
|
||||||
|
report("model_lr", self.model_optimizer.get_lr())
|
||||||
|
report("wavlm_lr",
|
||||||
|
self.wavlm_optimizer.get_lr())
|
||||||
|
self.train_batch(batch_index, batch, msg)
|
||||||
|
self.after_train_batch()
|
||||||
|
report('iter', batch_index + 1)
|
||||||
|
if not self.use_streamdata:
|
||||||
|
report('total', len(self.train_loader))
|
||||||
|
report('reader_cost', dataload_time)
|
||||||
|
observation['batch_cost'] = observation[
|
||||||
|
'reader_cost'] + observation['step_cost']
|
||||||
|
observation['samples'] = observation['batch_size']
|
||||||
|
observation['ips,samples/s'] = observation[
|
||||||
|
'batch_size'] / observation['batch_cost']
|
||||||
|
for k, v in observation.items():
|
||||||
|
msg += f" {k.split(',')[0]}: "
|
||||||
|
msg += f"{v:>.8f}" if isinstance(v,
|
||||||
|
float) else f"{v}"
|
||||||
|
msg += f" {k.split(',')[1]}" if len(
|
||||||
|
k.split(',')) == 2 else ""
|
||||||
|
msg += ","
|
||||||
|
msg = msg[:-1] # remove the last ","
|
||||||
|
if (batch_index + 1) % self.config.log_interval == 0:
|
||||||
|
logger.info(msg)
|
||||||
|
data_start_time = time.time()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise e
|
||||||
|
with Timer("Eval Time Cost: {}"):
|
||||||
|
total_loss, num_seen_utts = self.valid()
|
||||||
|
if dist.get_world_size() > 1:
|
||||||
|
num_seen_utts = paddle.to_tensor(num_seen_utts)
|
||||||
|
dist.all_reduce(num_seen_utts)
|
||||||
|
total_loss = paddle.to_tensor(total_loss)
|
||||||
|
dist.all_reduce(total_loss)
|
||||||
|
cv_loss = total_loss / num_seen_utts
|
||||||
|
cv_loss = float(cv_loss)
|
||||||
|
else:
|
||||||
|
cv_loss = float(total_loss)
|
||||||
|
logger.info(
|
||||||
|
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
|
||||||
|
if self.visualizer:
|
||||||
|
self.visualizer.add_scalar(
|
||||||
|
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
|
||||||
|
self.visualizer.add_scalar(
|
||||||
|
tag='eval/model_lr',
|
||||||
|
value=self.model_lr_scheduler(),
|
||||||
|
step=self.epoch)
|
||||||
|
self.visualizer.add_scalar(
|
||||||
|
tag='eval/wavlm_lr',
|
||||||
|
value=self.wavlm_lr_scheduler(),
|
||||||
|
step=self.epoch)
|
||||||
|
|
||||||
|
if self.config.model_scheduler == 'newbobscheduler':
|
||||||
|
self.model_lr_scheduler.step(cv_loss)
|
||||||
|
if self.config.wavlm_scheduler == 'newbobscheduler':
|
||||||
|
if not self.config.freeze_wavlm:
|
||||||
|
self.wavlm_lr_scheduler.step(cv_loss)
|
||||||
|
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
|
||||||
|
self.avg_train_loss = 0.0
|
||||||
|
self.new_epoch()
|
||||||
|
|
||||||
|
def dataio_prepare(self, hparams):
|
||||||
|
"""This function prepares the datasets to be used in the brain class.
|
||||||
|
It also defines the data processing pipeline through user-defined functions."""
|
||||||
|
data_folder = hparams["data_folder"]
|
||||||
|
|
||||||
|
train_data = dataset.DynamicItemDataset.from_csv(
|
||||||
|
csv_path=hparams["train_data"],
|
||||||
|
replacements={"data_root": data_folder}, )
|
||||||
|
|
||||||
|
if hparams["sorting"] == "ascending":
|
||||||
|
# we sort training data to speed up training and get better results.
|
||||||
|
train_data = train_data.filtered_sorted(sort_key="duration")
|
||||||
|
# when sorting do not shuffle in dataloader ! otherwise is pointless
|
||||||
|
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||||
|
|
||||||
|
elif hparams["sorting"] == "descending":
|
||||||
|
train_data = train_data.filtered_sorted(
|
||||||
|
sort_key="duration", reverse=True)
|
||||||
|
# when sorting do not shuffle in dataloader ! otherwise is pointless
|
||||||
|
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||||
|
|
||||||
|
elif hparams["sorting"] == "random":
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"sorting must be random, ascending or descending")
|
||||||
|
|
||||||
|
valid_data = dataset.DynamicItemDataset.from_csv(
|
||||||
|
csv_path=hparams["valid_data"],
|
||||||
|
replacements={"data_root": data_folder}, )
|
||||||
|
valid_data = valid_data.filtered_sorted(sort_key="duration")
|
||||||
|
|
||||||
|
test_data = dataset.DynamicItemDataset.from_csv(
|
||||||
|
csv_path=hparams["test_data"],
|
||||||
|
replacements={"data_root": data_folder}, )
|
||||||
|
test_data = test_data.filtered_sorted(sort_key="duration")
|
||||||
|
|
||||||
|
datasets = [train_data, valid_data, test_data]
|
||||||
|
|
||||||
|
# Defining tokenizer and loading it
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
# 2. Define audio pipeline:
|
||||||
|
@data_pipeline.takes("wav")
|
||||||
|
@data_pipeline.provides("sig")
|
||||||
|
def audio_pipeline(wav):
|
||||||
|
sig = dataio.read_audio(wav)
|
||||||
|
return sig
|
||||||
|
|
||||||
|
dataset.add_dynamic_item(datasets, audio_pipeline)
|
||||||
|
|
||||||
|
# 3. Define text pipeline:
|
||||||
|
@data_pipeline.takes("transcript")
|
||||||
|
@data_pipeline.provides("wrd", "tokens_list", "tokens")
|
||||||
|
def text_pipeline(wrd):
|
||||||
|
wrd = "".join(wrd.split(" "))
|
||||||
|
yield wrd
|
||||||
|
tokens_list = tokenizer(wrd)["input_ids"]
|
||||||
|
yield tokens_list
|
||||||
|
tokens = np.array(tokens_list, dtype="int64")
|
||||||
|
# tokens = paddle.to_tensor(tokens_list, dtype="int64")
|
||||||
|
yield tokens
|
||||||
|
|
||||||
|
dataset.add_dynamic_item(datasets, text_pipeline)
|
||||||
|
|
||||||
|
# 4. Set output:
|
||||||
|
dataset.set_output_keys(
|
||||||
|
datasets,
|
||||||
|
["id", "sig", "wrd", "tokens"], )
|
||||||
|
|
||||||
|
# 5. If Dynamic Batching is used, we instantiate the needed samplers.
|
||||||
|
train_batch_sampler = None
|
||||||
|
valid_batch_sampler = None
|
||||||
|
if hparams["dynamic_batching"]:
|
||||||
|
from sampler import DynamicBatchSampler # noqa
|
||||||
|
|
||||||
|
dynamic_hparams = hparams["dynamic_batch_sampler"]
|
||||||
|
num_buckets = dynamic_hparams["num_buckets"]
|
||||||
|
|
||||||
|
train_batch_sampler = DynamicBatchSampler(
|
||||||
|
train_data,
|
||||||
|
dynamic_hparams["max_batch_len"],
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
length_func=lambda x: x["duration"],
|
||||||
|
shuffle=dynamic_hparams["shuffle_ex"],
|
||||||
|
batch_ordering=dynamic_hparams["batch_ordering"], )
|
||||||
|
|
||||||
|
valid_batch_sampler = DynamicBatchSampler(
|
||||||
|
valid_data,
|
||||||
|
dynamic_hparams["max_batch_len"],
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
length_func=lambda x: x["duration"],
|
||||||
|
shuffle=dynamic_hparams["shuffle_ex"],
|
||||||
|
batch_ordering=dynamic_hparams["batch_ordering"], )
|
||||||
|
|
||||||
|
return (train_data, valid_data, test_data, tokenizer,
|
||||||
|
train_batch_sampler, valid_batch_sampler, )
|
||||||
|
|
||||||
|
def setup_dataloader(self):
|
||||||
|
config = self.config.clone()
|
||||||
|
self.use_streamdata = config.get("use_stream_data", False)
|
||||||
|
self.use_sb = config.get("use_sb_pipeline", False)
|
||||||
|
if self.use_sb:
|
||||||
|
hparams_file = config.sb_pipeline_conf
|
||||||
|
with open(hparams_file, 'r', encoding='utf8') as fin:
|
||||||
|
hparams = load_hyperpyyaml(fin, None)
|
||||||
|
|
||||||
|
(train_data, valid_data, test_data, tokenizer, train_bsampler,
|
||||||
|
valid_bsampler, ) = self.dataio_prepare(hparams)
|
||||||
|
|
||||||
|
train_dataloader_opts = hparams["train_dataloader_opts"]
|
||||||
|
valid_dataloader_opts = hparams["valid_dataloader_opts"]
|
||||||
|
|
||||||
|
if train_bsampler is not None:
|
||||||
|
train_dataloader_opts = {
|
||||||
|
"batch_sampler": train_bsampler,
|
||||||
|
"num_workers": hparams["num_workers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if valid_bsampler is not None:
|
||||||
|
valid_dataloader_opts = {"batch_sampler": valid_bsampler}
|
||||||
|
|
||||||
|
if self.train:
|
||||||
|
self.train_loader = make_dataloader(
|
||||||
|
train_data, stage='train', **train_dataloader_opts)
|
||||||
|
self.valid_loader = make_dataloader(
|
||||||
|
valid_data,
|
||||||
|
stage='val',
|
||||||
|
**valid_dataloader_opts, )
|
||||||
|
logger.info("Setup train/valid Dataloader!")
|
||||||
|
else:
|
||||||
|
self.test_loader = make_dataloader(
|
||||||
|
test_data, stage='test', **hparams["test_dataloader_opts"])
|
||||||
|
else:
|
||||||
|
if self.train:
|
||||||
|
self.train_loader = DataLoaderFactory.get_dataloader(
|
||||||
|
'train', config, self.args)
|
||||||
|
self.valid_loader = DataLoaderFactory.get_dataloader(
|
||||||
|
'valid', config, self.args)
|
||||||
|
logger.info("Setup train/valid Dataloader!")
|
||||||
|
else:
|
||||||
|
decode_batch_size = config.get('decode', dict()).get(
|
||||||
|
'decode_batch_size', 1)
|
||||||
|
self.test_loader = DataLoaderFactory.get_dataloader(
|
||||||
|
'test', config, self.args)
|
||||||
|
self.align_loader = DataLoaderFactory.get_dataloader(
|
||||||
|
'align', config, self.args)
|
||||||
|
logger.info("Setup test/align Dataloader!")
|
||||||
|
|
||||||
|
def setup_model(self):
|
||||||
|
config = self.config
|
||||||
|
model_conf = config
|
||||||
|
|
||||||
|
with UpdateConfig(model_conf):
|
||||||
|
if self.use_sb:
|
||||||
|
model_conf.output_dim = self.tokenizer.vocab_size
|
||||||
|
else:
|
||||||
|
if self.train:
|
||||||
|
model_conf.input_dim = self.train_loader.feat_dim
|
||||||
|
model_conf.output_dim = self.train_loader.vocab_size
|
||||||
|
else:
|
||||||
|
model_conf.input_dim = self.test_loader.feat_dim
|
||||||
|
model_conf.output_dim = self.test_loader.vocab_size
|
||||||
|
|
||||||
|
model = WavLMASR.from_config(model_conf)
|
||||||
|
|
||||||
|
model_dict = paddle.load(config.wavlm_params_path)
|
||||||
|
model.wavlm.set_state_dict(model_dict)
|
||||||
|
|
||||||
|
if self.parallel:
|
||||||
|
model = paddle.DataParallel(model, find_unused_parameters=True)
|
||||||
|
|
||||||
|
layer_tools.print_params(model, logger.info)
|
||||||
|
self.model = model
|
||||||
|
logger.info("Setup model!")
|
||||||
|
|
||||||
|
# setup speech augmentation for wavlm
|
||||||
|
if hasattr(config, 'audio_augment') and self.train:
|
||||||
|
self.speech_augmentation = TimeDomainSpecAugment(
|
||||||
|
**config.audio_augment)
|
||||||
|
|
||||||
|
if not self.train:
|
||||||
|
return
|
||||||
|
|
||||||
|
train_config = config
|
||||||
|
model_optim_type = train_config.model_optim
|
||||||
|
model_optim_conf = train_config.model_optim_conf
|
||||||
|
logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
|
||||||
|
wavlm_optim_type = train_config.wavlm_optim
|
||||||
|
wavlm_optim_conf = train_config.wavlm_optim_conf
|
||||||
|
logger.info("optim_model:{},{}", wavlm_optim_type,
|
||||||
|
wavlm_optim_conf)
|
||||||
|
|
||||||
|
model_scheduler_type = train_config.model_scheduler
|
||||||
|
model_scheduler_conf = train_config.model_scheduler_conf
|
||||||
|
wavlm_scheduler_type = train_config.wavlm_scheduler
|
||||||
|
wavlm_scheduler_conf = train_config.wavlm_scheduler_conf
|
||||||
|
|
||||||
|
model_scheduler_args = dict(
|
||||||
|
**{"learning_rate": model_optim_conf.lr,
|
||||||
|
"verbose": False}, **(dict(model_scheduler_conf)))
|
||||||
|
|
||||||
|
wavlm_scheduler_args = dict(
|
||||||
|
**{"learning_rate": wavlm_optim_conf.lr,
|
||||||
|
"verbose": False}, **(dict(wavlm_scheduler_conf)))
|
||||||
|
|
||||||
|
model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
|
||||||
|
model_scheduler_args)
|
||||||
|
wavlm_lr_scheduler = LRSchedulerFactory.from_args(
|
||||||
|
wavlm_scheduler_type, wavlm_scheduler_args)
|
||||||
|
|
||||||
|
def optimizer_args(
|
||||||
|
config,
|
||||||
|
optim_type,
|
||||||
|
optim_conf,
|
||||||
|
parameters,
|
||||||
|
lr_scheduler=None, ):
|
||||||
|
optim_arg = dict(optim_conf)
|
||||||
|
optim_arg.update({
|
||||||
|
"learning_rate":
|
||||||
|
lr_scheduler if lr_scheduler else optim_conf.lr,
|
||||||
|
"parameters":
|
||||||
|
parameters
|
||||||
|
})
|
||||||
|
return optim_arg
|
||||||
|
|
||||||
|
model_optimizer_args = optimizer_args(
|
||||||
|
config, model_optim_type,
|
||||||
|
model_optim_conf,
|
||||||
|
[{'params': model._layers.enc.parameters()}, {'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.enc.parameters()}, {'params': model.ctc.parameters()}],
|
||||||
|
model_lr_scheduler
|
||||||
|
)
|
||||||
|
# [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
wavlm_optimizer_args = optimizer_args(
|
||||||
|
config, wavlm_optim_type, wavlm_optim_conf,
|
||||||
|
model._layers.wavlm.parameters() if self.parallel else
|
||||||
|
model.wavlm.parameters(), wavlm_lr_scheduler)
|
||||||
|
|
||||||
|
model_optimizer = OptimizerFactory.from_args(model_optim_type,
|
||||||
|
model_optimizer_args)
|
||||||
|
wavlm_optimizer = OptimizerFactory.from_args(wavlm_optim_type,
|
||||||
|
wavlm_optimizer_args)
|
||||||
|
|
||||||
|
self.model_optimizer = model_optimizer
|
||||||
|
self.wavlm_optimizer = wavlm_optimizer
|
||||||
|
self.model_lr_scheduler = model_lr_scheduler
|
||||||
|
self.wavlm_lr_scheduler = wavlm_lr_scheduler
|
||||||
|
logger.info("Setup optimizer/lr_scheduler!")
|
||||||
|
|
||||||
|
|
||||||
|
class WavLMASRTester(WavLMASRTrainer):
|
||||||
|
def __init__(self, config, args):
|
||||||
|
super().__init__(config, args)
|
||||||
|
self.text_featurizer = TextFeaturizer(
|
||||||
|
unit_type=config.unit_type, vocab=config.vocab_filepath)
|
||||||
|
self.vocab_list = self.text_featurizer.vocab_list
|
||||||
|
|
||||||
|
def id2token(self, texts, texts_len):
|
||||||
|
""" ord() id to chr() chr """
|
||||||
|
trans = []
|
||||||
|
for text, n in zip(texts, texts_len):
|
||||||
|
n = n.numpy().item()
|
||||||
|
ids = text[:n]
|
||||||
|
trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
|
||||||
|
return trans
|
||||||
|
|
||||||
|
def compute_metrics(self, id, audio, audio_len, texts, texts_len,
|
||||||
|
fout=None):
|
||||||
|
decode_cfg = self.config.decode
|
||||||
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
||||||
|
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
|
||||||
|
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
target_transcripts = self.id2token(texts, texts_len)
|
||||||
|
result_transcripts, result_tokenids = self.model.decode(
|
||||||
|
audio,
|
||||||
|
text_feature=self.text_featurizer,
|
||||||
|
decoding_method=decode_cfg.decoding_method,
|
||||||
|
beam_size=decode_cfg.beam_size)
|
||||||
|
decode_time = time.time() - start_time
|
||||||
|
|
||||||
|
for utt, target, result, rec_tids in zip(
|
||||||
|
id, target_transcripts, result_transcripts, result_tokenids):
|
||||||
|
errors, len_ref = errors_func(target, result)
|
||||||
|
errors_sum += errors
|
||||||
|
len_refs += len_ref
|
||||||
|
num_ins += 1
|
||||||
|
if fout:
|
||||||
|
fout.write({
|
||||||
|
"utt": utt,
|
||||||
|
"refs": [target],
|
||||||
|
"hyps": [result],
|
||||||
|
"hyps_tokenid": [rec_tids],
|
||||||
|
})
|
||||||
|
logger.info(f"Utt: {utt}")
|
||||||
|
logger.info(f"Ref: {target}")
|
||||||
|
logger.info(f"Hyp: {result}")
|
||||||
|
logger.info("One example error rate [%s] = %f" % (
|
||||||
|
decode_cfg.error_rate_type, error_rate_func(target, result)))
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
errors_sum=errors_sum,
|
||||||
|
len_refs=len_refs,
|
||||||
|
num_ins=num_ins, # num examples
|
||||||
|
error_rate=errors_sum / len_refs,
|
||||||
|
error_rate_type=decode_cfg.error_rate_type,
|
||||||
|
num_frames=audio_len.sum().numpy().item(),
|
||||||
|
decode_time=decode_time)
|
||||||
|
|
||||||
|
def sb_compute_metrics(self, id, sig, wrd, tokens, fout=None):
|
||||||
|
decode_cfg = self.config.decode
|
||||||
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
||||||
|
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
|
||||||
|
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
|
||||||
|
start_time = time.time()
|
||||||
|
target_transcripts = wrd
|
||||||
|
result_transcripts, result_tokenids = self.model.decode(
|
||||||
|
sig[0],
|
||||||
|
text_feature=self.tokenizer,
|
||||||
|
decoding_method=decode_cfg.decoding_method,
|
||||||
|
beam_size=decode_cfg.beam_size,
|
||||||
|
sb_pipeline=True)
|
||||||
|
decode_time = time.time() - start_time
|
||||||
|
|
||||||
|
for utt, target, result, rec_tids in zip(
|
||||||
|
id, target_transcripts, result_transcripts, result_tokenids):
|
||||||
|
errors, len_ref = errors_func(target, result)
|
||||||
|
errors_sum += errors
|
||||||
|
len_refs += len_ref
|
||||||
|
num_ins += 1
|
||||||
|
if fout:
|
||||||
|
fout.write({
|
||||||
|
"utt": utt,
|
||||||
|
"refs": [target],
|
||||||
|
"hyps": [result],
|
||||||
|
"hyps_tokenid": [rec_tids],
|
||||||
|
})
|
||||||
|
logger.info(f"Utt: {utt}")
|
||||||
|
logger.info(f"Ref: {target}")
|
||||||
|
logger.info(f"Hyp: {result}")
|
||||||
|
logger.info("One example error rate [%s] = %f" % (
|
||||||
|
decode_cfg.error_rate_type, error_rate_func(target, result)))
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
errors_sum=errors_sum,
|
||||||
|
len_refs=len_refs,
|
||||||
|
num_ins=num_ins, # num examples
|
||||||
|
error_rate=errors_sum / len_refs,
|
||||||
|
error_rate_type=decode_cfg.error_rate_type,
|
||||||
|
num_frames=sig[1].sum().numpy().item(),
|
||||||
|
decode_time=decode_time)
|
||||||
|
|
||||||
|
@mp_tools.rank_zero_only
|
||||||
|
@paddle.no_grad()
|
||||||
|
def test(self):
|
||||||
|
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
error_rate_type = None
|
||||||
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
||||||
|
num_frames = 0.0
|
||||||
|
num_time = 0.0
|
||||||
|
# Initialized the decoder in model
|
||||||
|
decode_cfg = self.config.decode
|
||||||
|
vocab_list = self.vocab_list
|
||||||
|
decode_batch_size = decode_cfg.decode_batch_size
|
||||||
|
|
||||||
|
with jsonlines.open(self.args.result_file, 'w') as fout:
|
||||||
|
for i, batch in enumerate(self.test_loader):
|
||||||
|
if self.use_sb:
|
||||||
|
metrics = self.sb_compute_metrics(**batch, fout=fout)
|
||||||
|
else:
|
||||||
|
metrics = self.compute_metrics(*batch, fout=fout)
|
||||||
|
num_frames += metrics['num_frames']
|
||||||
|
num_time += metrics["decode_time"]
|
||||||
|
errors_sum += metrics['errors_sum']
|
||||||
|
len_refs += metrics['len_refs']
|
||||||
|
num_ins += metrics['num_ins']
|
||||||
|
error_rate_type = metrics['error_rate_type']
|
||||||
|
rtf = num_time / (num_frames)
|
||||||
|
logger.info(
|
||||||
|
"RTF: %f, Error rate [%s] (%d/?) = %f" %
|
||||||
|
(rtf, error_rate_type, num_ins, errors_sum / len_refs))
|
||||||
|
|
||||||
|
# logging
|
||||||
|
msg = "Test: "
|
||||||
|
msg += "epoch: {}, ".format(self.epoch)
|
||||||
|
msg += "step: {}, ".format(self.iteration)
|
||||||
|
msg += "Final error rate [%s] (%d/%d) = %f" % (
|
||||||
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
||||||
|
logger.info(msg)
|
||||||
|
|
||||||
|
err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
|
||||||
|
err_type_str = "{}".format(error_rate_type)
|
||||||
|
with open(err_meta_path, 'w', encoding='utf8') as f:
|
||||||
|
data = json.dumps({
|
||||||
|
"epoch":
|
||||||
|
self.epoch,
|
||||||
|
"step":
|
||||||
|
self.iteration,
|
||||||
|
"rtf":
|
||||||
|
rtf,
|
||||||
|
error_rate_type:
|
||||||
|
errors_sum / len_refs,
|
||||||
|
"dataset_hour": (num_frames) / 1000.0 / 3600.0,
|
||||||
|
"process_hour":
|
||||||
|
num_time / 1000.0 / 3600.0,
|
||||||
|
"num_examples":
|
||||||
|
num_ins,
|
||||||
|
"err_sum":
|
||||||
|
errors_sum,
|
||||||
|
"ref_len":
|
||||||
|
len_refs,
|
||||||
|
"decode_method":
|
||||||
|
self.config.decode.decoding_method,
|
||||||
|
})
|
||||||
|
f.write(data + '\n')
|
@ -0,0 +1,88 @@
|
|||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def _gelu_python(x):
|
||||||
|
"""
|
||||||
|
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
||||||
|
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
||||||
|
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in
|
||||||
|
torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||||
|
"""
|
||||||
|
return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_new(x):
|
||||||
|
"""
|
||||||
|
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
||||||
|
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||||
|
"""
|
||||||
|
return 0.5 * x * (1.0 + paddle.tanh(
|
||||||
|
math.sqrt(2.0 / math.pi) * (x + 0.044715 * paddle.pow(x, 3.0))))
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_fast(x):
|
||||||
|
return 0.5 * x * (1.0 + paddle.tanh(x * 0.7978845608 *
|
||||||
|
(1.0 + 0.044715 * x * x)))
|
||||||
|
|
||||||
|
gelu = gelu_fast
|
||||||
|
|
||||||
|
def _silu_python(x):
|
||||||
|
"""
|
||||||
|
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
|
||||||
|
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
|
||||||
|
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
|
||||||
|
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
|
||||||
|
later.
|
||||||
|
"""
|
||||||
|
return x * paddle.nn.functional.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
def mish(x):
|
||||||
|
return x * paddle.tanh(paddle.nn.functional.softplus(x))
|
||||||
|
|
||||||
|
|
||||||
|
def linear_act(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
ACT2FN = {
|
||||||
|
"relu": F.relu,
|
||||||
|
"silu": _silu_python,
|
||||||
|
"swish": _silu_python,
|
||||||
|
"gelu": gelu,
|
||||||
|
"tanh": paddle.tanh,
|
||||||
|
"gelu_new": gelu_new,
|
||||||
|
"gelu_fast": gelu_fast,
|
||||||
|
"mish": mish,
|
||||||
|
"linear": linear_act,
|
||||||
|
"sigmoid": paddle.nn.functional.sigmoid,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation(activation_string):
|
||||||
|
if activation_string in ACT2FN:
|
||||||
|
return ACT2FN[activation_string]
|
||||||
|
else:
|
||||||
|
raise KeyError(
|
||||||
|
f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}"
|
||||||
|
)
|
@ -0,0 +1,508 @@
|
|||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
import math
|
||||||
|
|
||||||
|
def _mha_shape_check(query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor,
|
||||||
|
key_padding_mask: Optional[paddle.Tensor], attn_mask: Optional[paddle.Tensor], num_heads: int):
|
||||||
|
# Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
|
||||||
|
# and returns if the input is batched or not.
|
||||||
|
# Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
|
||||||
|
|
||||||
|
# Shape check.
|
||||||
|
if query.dim() == 3:
|
||||||
|
# Batched Inputs
|
||||||
|
is_batched = True
|
||||||
|
assert key.dim() == 3 and value.dim() == 3, \
|
||||||
|
("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
|
||||||
|
f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.dim() == 2, \
|
||||||
|
("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
|
||||||
|
f" but found {key_padding_mask.dim()}-D tensor instead")
|
||||||
|
if attn_mask is not None:
|
||||||
|
assert attn_mask.dim() in (2, 3), \
|
||||||
|
("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
|
||||||
|
f" but found {attn_mask.dim()}-D tensor instead")
|
||||||
|
elif query.dim() == 2:
|
||||||
|
# Unbatched Inputs
|
||||||
|
is_batched = False
|
||||||
|
assert key.dim() == 2 and value.dim() == 2, \
|
||||||
|
("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
|
||||||
|
f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.dim() == 1, \
|
||||||
|
("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
|
||||||
|
f" but found {key_padding_mask.dim()}-D tensor instead")
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
assert attn_mask.dim() in (2, 3), \
|
||||||
|
("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
|
||||||
|
f" but found {attn_mask.dim()}-D tensor instead")
|
||||||
|
if attn_mask.dim() == 3:
|
||||||
|
expected_shape = (num_heads, query.shape[0], key.shape[0])
|
||||||
|
assert attn_mask.shape == expected_shape, \
|
||||||
|
(f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
|
||||||
|
else:
|
||||||
|
raise AssertionError(
|
||||||
|
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
|
||||||
|
|
||||||
|
def masked_fill(x, mask, value):
|
||||||
|
y = paddle.full(x.shape, value)
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
|
||||||
|
"""
|
||||||
|
Scaled Dot-Product Attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
d_key = k.shape[-1]
|
||||||
|
scaled_q = paddle.scale(x=q, scale=d_key ** -0.5)
|
||||||
|
product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)
|
||||||
|
weights = paddle.nn.functional.softmax(x=product + attn_mask)
|
||||||
|
if dropout_p:
|
||||||
|
weights = paddle.fluid.layers.nn.dropout(
|
||||||
|
weights,
|
||||||
|
dropout_prob=dropout_p,
|
||||||
|
dropout_implementation="upscale_in_train",
|
||||||
|
is_test=False)
|
||||||
|
out = paddle.matmul(x=weights, y=v)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
|
||||||
|
row = vec1.shape[0]
|
||||||
|
column = vec2.shape[0]
|
||||||
|
vec1 = paddle.unsqueeze(vec1, 0)
|
||||||
|
vec1 = paddle.transpose(vec1, [1, 0])
|
||||||
|
vec1 = paddle.expand(vec1, [row, column])
|
||||||
|
new_vec2 = paddle.zeros([column, column], dtype=vec2.dtype)
|
||||||
|
new_vec2[0, :] = vec2
|
||||||
|
out = alpha * paddle.matmul(vec1, new_vec2)
|
||||||
|
out = beta * input + out
|
||||||
|
return out
|
||||||
|
|
||||||
|
def multi_head_attention_forward(
|
||||||
|
x: paddle.Tensor,
|
||||||
|
num_heads: int,
|
||||||
|
q_proj: nn.Linear,
|
||||||
|
k_proj: nn.Linear,
|
||||||
|
v_proj: nn.Linear,
|
||||||
|
c_proj: nn.Linear,
|
||||||
|
attn_mask: Optional[paddle.Tensor] = None,
|
||||||
|
):
|
||||||
|
max_len, batch_size, emb_dim = x.shape
|
||||||
|
head_dim = emb_dim // num_heads
|
||||||
|
scaling = float(head_dim) ** -0.5
|
||||||
|
q = q_proj(x) # L, N, E
|
||||||
|
k = k_proj(x) # L, N, E
|
||||||
|
v = v_proj(x) # L, N, E
|
||||||
|
|
||||||
|
v = v.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
|
||||||
|
k = k.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
|
||||||
|
q = q.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
|
||||||
|
|
||||||
|
q = q * scaling
|
||||||
|
qk = paddle.matmul(q, k, transpose_y=True)
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.ndim == 2:
|
||||||
|
attn_mask.unsqueeze_(0)
|
||||||
|
assert attn_mask.shape[0] == 1 and attn_mask.shape[1] == max_len and attn_mask.shape[2] == max_len
|
||||||
|
qk += attn_mask
|
||||||
|
|
||||||
|
qk = F.softmax(qk, axis=-1)
|
||||||
|
atten = paddle.bmm(qk, v)
|
||||||
|
atten = atten.transpose((1, 0, 2))
|
||||||
|
atten = atten.reshape((max_len, batch_size, emb_dim))
|
||||||
|
atten = c_proj(atten)
|
||||||
|
return atten
|
||||||
|
|
||||||
|
def linear(input, weight, bias=None):
|
||||||
|
# compute y = x A^T + b
|
||||||
|
# Input: (N, in_feature) paddle tensor
|
||||||
|
# weight: (out_feature, in_feature) paddle tensor
|
||||||
|
# bias: (out_feature) paddle tensor
|
||||||
|
if input.dim() == 2 and bias is not None:
|
||||||
|
# fused op is marginally faster
|
||||||
|
return paddle.addmm(bias, input, weight)
|
||||||
|
output = paddle.matmul(input, weight)
|
||||||
|
if bias is not None:
|
||||||
|
output += bias
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _in_projection_packed(
|
||||||
|
q: paddle.Tensor,
|
||||||
|
k: paddle.Tensor,
|
||||||
|
v: paddle.Tensor,
|
||||||
|
w: paddle.Tensor,
|
||||||
|
b: Optional[paddle.Tensor] = None,
|
||||||
|
) -> List[paddle.Tensor]:
|
||||||
|
r"""
|
||||||
|
Performs the in-projection step of the attention operation, using packed weights.
|
||||||
|
Output is a triple containing projection tensors for query, key and value.
|
||||||
|
Args:
|
||||||
|
q, k, v: query, key and value tensors to be projected. For self-attention,
|
||||||
|
these are typically the same tensor; for encoder-decoder attention,
|
||||||
|
k and v are typically the same tensor. (We take advantage of these
|
||||||
|
identities for performance if they are present.) Regardless, q, k and v
|
||||||
|
must share a common embedding dimension; otherwise their shapes may vary.
|
||||||
|
w: projection weights for q, k and v, packed into a single tensor. Weights
|
||||||
|
are packed along dimension 0, in q, k, v order.
|
||||||
|
b: optional projection biases for q, k and v, packed into a single tensor
|
||||||
|
in q, k, v order.
|
||||||
|
Shape:
|
||||||
|
Inputs:
|
||||||
|
- q: :math:`(..., E)` where E is the embedding dimension
|
||||||
|
- k: :math:`(..., E)` where E is the embedding dimension
|
||||||
|
- v: :math:`(..., E)` where E is the embedding dimension
|
||||||
|
- w: :math:`(E * 3, E)` where E is the embedding dimension
|
||||||
|
- b: :math:`E * 3` where E is the embedding dimension
|
||||||
|
Output:
|
||||||
|
- in output list :math:`[q', k', v']`, each output tensor will have the
|
||||||
|
same shape as the corresponding input tensor.
|
||||||
|
"""
|
||||||
|
# E = q.size(-1)
|
||||||
|
E = q.shape[-1]
|
||||||
|
if k is v:
|
||||||
|
if q is k:
|
||||||
|
# self-attention
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
|
||||||
|
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
|
||||||
|
return proj[0], proj[1], proj[2]
|
||||||
|
else:
|
||||||
|
# encoder-decoder attention
|
||||||
|
w_q, w_kv = w.split([E, E * 2])
|
||||||
|
if b is None:
|
||||||
|
b_q = b_kv = None
|
||||||
|
else:
|
||||||
|
b_q, b_kv = b.split([E, E * 2])
|
||||||
|
q_proj = linear(q, w_q, b_q)
|
||||||
|
kv_proj = linear(k, w_kv, b_kv)
|
||||||
|
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
|
||||||
|
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
|
||||||
|
return (q_proj, kv_proj[0], kv_proj[1])
|
||||||
|
else:
|
||||||
|
w_q, w_k, w_v = w.chunk(3)
|
||||||
|
if b is None:
|
||||||
|
b_q = b_k = b_v = None
|
||||||
|
else:
|
||||||
|
b_q, b_k, b_v = b.chunk(3)
|
||||||
|
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
||||||
|
|
||||||
|
def _in_projection(
|
||||||
|
q: paddle.Tensor,
|
||||||
|
k: paddle.Tensor,
|
||||||
|
v: paddle.Tensor,
|
||||||
|
w_q: paddle.Tensor,
|
||||||
|
w_k: paddle.Tensor,
|
||||||
|
w_v: paddle.Tensor,
|
||||||
|
b_q: Optional[paddle.Tensor] = None,
|
||||||
|
b_k: Optional[paddle.Tensor] = None,
|
||||||
|
b_v: Optional[paddle.Tensor] = None,
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
A, B, C = linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
||||||
|
|
||||||
|
return A, B, C
|
||||||
|
# return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
||||||
|
|
||||||
|
def multi_head_attention_forward_paddle(
|
||||||
|
query: paddle.Tensor,
|
||||||
|
key: paddle.Tensor,
|
||||||
|
value: paddle.Tensor,
|
||||||
|
embed_dim_to_check: int,
|
||||||
|
num_heads: int,
|
||||||
|
in_proj_weight: Optional[paddle.Tensor],
|
||||||
|
in_proj_bias: Optional[paddle.Tensor],
|
||||||
|
bias_k: Optional[paddle.Tensor],
|
||||||
|
bias_v: Optional[paddle.Tensor],
|
||||||
|
add_zero_attn: bool,
|
||||||
|
dropout_p: float,
|
||||||
|
out_proj_weight: paddle.Tensor,
|
||||||
|
out_proj_bias: Optional[paddle.Tensor],
|
||||||
|
training: bool = True,
|
||||||
|
key_padding_mask: Optional[paddle.Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[paddle.Tensor] = None,
|
||||||
|
use_separate_proj_weight: bool = False,
|
||||||
|
q_proj_weight: Optional[paddle.Tensor] = None,
|
||||||
|
k_proj_weight: Optional[paddle.Tensor] = None,
|
||||||
|
v_proj_weight: Optional[paddle.Tensor] = None,
|
||||||
|
static_k: Optional[paddle.Tensor] = None,
|
||||||
|
static_v: Optional[paddle.Tensor] = None,
|
||||||
|
average_attn_weights: bool = True,
|
||||||
|
is_causal: bool = False,
|
||||||
|
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
|
See "Attention Is All You Need" for more details.
|
||||||
|
embed_dim_to_check: total dimension of the model.
|
||||||
|
num_heads: parallel attention heads.
|
||||||
|
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||||
|
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
||||||
|
add_zero_attn: add a new batch of zeros to the key and
|
||||||
|
value sequences at dim=1.
|
||||||
|
dropout_p: probability of an element to be zeroed.
|
||||||
|
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||||
|
training: apply dropout if is ``True``.
|
||||||
|
key_padding_mask: if provided, specified padding elements in the key will
|
||||||
|
be ignored by the attention. This is an binary mask. When the value is True,
|
||||||
|
the corresponding value on the attention layer will be filled with -inf.
|
||||||
|
need_weights: output attn_output_weights.
|
||||||
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
|
is_causal: If specified, applies a causal mask as attention mask, and ignores
|
||||||
|
attn_mask for computing scaled dot product attention.
|
||||||
|
Default: ``False``.
|
||||||
|
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
||||||
|
and value in different forms. If false, in_proj_weight will be used, which is
|
||||||
|
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
||||||
|
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
||||||
|
static_k, static_v: static key and value used for attention operators.
|
||||||
|
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
|
||||||
|
Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
|
||||||
|
when ``need_weights=True.``. Default: True
|
||||||
|
Shape:
|
||||||
|
Inputs:
|
||||||
|
- query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||||
|
If a FloatTensor is provided, it will be directly added to the value.
|
||||||
|
If a BoolTensor is provided, the positions with the
|
||||||
|
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||||
|
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||||
|
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||||
|
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||||
|
positions. If a BoolTensor is provided, positions with ``True``
|
||||||
|
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||||
|
is provided, it will be added to the attention weight.
|
||||||
|
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
||||||
|
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
||||||
|
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
||||||
|
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
||||||
|
Outputs:
|
||||||
|
- attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||||
|
E is the embedding dimension.
|
||||||
|
- attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
|
||||||
|
attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
||||||
|
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
||||||
|
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
||||||
|
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
||||||
|
|
||||||
|
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
||||||
|
# is batched, run the computation and before returning squeeze the
|
||||||
|
# batch dimension so that the output doesn't carry this temporary batch dimension.
|
||||||
|
# if not is_batched:
|
||||||
|
# # unsqueeze if the input is unbatched
|
||||||
|
# query = query.unsqueeze(1)
|
||||||
|
# key = key.unsqueeze(1)
|
||||||
|
# value = value.unsqueeze(1)
|
||||||
|
# if key_padding_mask is not None:
|
||||||
|
# key_padding_mask = key_padding_mask.unsqueeze(0)
|
||||||
|
|
||||||
|
# set up shape vars
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
|
tgt_len, bsz, embed_dim = query.shape
|
||||||
|
# tgt_len, bsz, embed_dim = query.shape
|
||||||
|
src_len, _, _ = key.shape
|
||||||
|
|
||||||
|
if is_causal:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
assert embed_dim == embed_dim_to_check, \
|
||||||
|
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
||||||
|
if isinstance(embed_dim, paddle.Tensor):
|
||||||
|
# embed_dim can be a tensor when JIT tracing
|
||||||
|
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
|
||||||
|
else:
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
||||||
|
if use_separate_proj_weight:
|
||||||
|
# allow MHA to have different embedding dimensions when separate projection weights are used
|
||||||
|
assert key.shape[:2] == value.shape[:2], \
|
||||||
|
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
||||||
|
else:
|
||||||
|
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
||||||
|
|
||||||
|
#
|
||||||
|
# compute in-projection
|
||||||
|
#
|
||||||
|
if not use_separate_proj_weight:
|
||||||
|
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
|
||||||
|
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
|
||||||
|
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
|
||||||
|
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
|
||||||
|
if in_proj_bias is None:
|
||||||
|
b_q = b_k = b_v = None
|
||||||
|
else:
|
||||||
|
b_q, b_k, b_v = in_proj_bias.chunk(3)
|
||||||
|
|
||||||
|
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
|
||||||
|
|
||||||
|
# prep attention mask
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
# ensure attn_mask's dim is 3
|
||||||
|
if attn_mask.dim() == 2:
|
||||||
|
correct_2d_size = (tgt_len, src_len)
|
||||||
|
if attn_mask.shape != correct_2d_size:
|
||||||
|
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
elif attn_mask.dim() == 3:
|
||||||
|
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
|
||||||
|
if tuple(attn_mask.shape) != correct_3d_size:
|
||||||
|
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
||||||
|
|
||||||
|
# add bias along batch dimension (currently second)
|
||||||
|
if bias_k is not None and bias_v is not None:
|
||||||
|
assert static_k is None, "bias cannot be added to static key."
|
||||||
|
assert static_v is None, "bias cannot be added to static value."
|
||||||
|
# k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
||||||
|
k = paddle.concat([k, bias_k.repeat(1, bsz, 1)], axis=1)
|
||||||
|
# v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
||||||
|
v = paddle.concat([v, bias_v.repeat(1, bsz, 1)], axis=1)
|
||||||
|
if attn_mask is not None:
|
||||||
|
# attn_mask = pad(attn_mask, (0, 1))
|
||||||
|
# pad last dim with 0 on one side and 1 on the other
|
||||||
|
attn_mask = paddle.concat([attn_mask, paddle.zeros_like(attn_mask[:, :, -1:])], axis=2)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
# key_padding_mask = pad(key_padding_mask, (0, 1))
|
||||||
|
# pad last dim with 0 on one side and 1 on the other
|
||||||
|
key_padding_mask = paddle.concat([key_padding_mask, paddle.zeros_like(key_padding_mask[:, -1:])], axis=1)
|
||||||
|
else:
|
||||||
|
assert bias_k is None
|
||||||
|
assert bias_v is None
|
||||||
|
|
||||||
|
#
|
||||||
|
# reshape q, k, v for multihead attention and make em batch first
|
||||||
|
#
|
||||||
|
# q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
||||||
|
q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
||||||
|
|
||||||
|
|
||||||
|
if static_k is None:
|
||||||
|
# k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||||
|
k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
||||||
|
else:
|
||||||
|
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||||
|
assert static_k.size(0) == bsz * num_heads, \
|
||||||
|
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
||||||
|
assert static_k.size(2) == head_dim, \
|
||||||
|
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
||||||
|
k = static_k
|
||||||
|
if static_v is None:
|
||||||
|
# v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||||
|
v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
|
||||||
|
else:
|
||||||
|
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||||
|
assert static_v.size(0) == bsz * num_heads, \
|
||||||
|
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
||||||
|
assert static_v.size(2) == head_dim, \
|
||||||
|
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
||||||
|
v = static_v
|
||||||
|
|
||||||
|
# add zero attention along batch dimension (now first)
|
||||||
|
if add_zero_attn:
|
||||||
|
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
||||||
|
# k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
||||||
|
k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], axis=1)
|
||||||
|
# v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
||||||
|
v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], axis=1)
|
||||||
|
if attn_mask is not None:
|
||||||
|
# attn_mask = pad(attn_mask, (0, 1))
|
||||||
|
attn_mask = paddle.concat([attn_mask, paddle.zeros_like(attn_mask[:, :, -1:])], axis=2)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
# key_padding_mask = pad(key_padding_mask, (0, 1))
|
||||||
|
key_padding_mask = paddle.concat([key_padding_mask, paddle.zeros_like(key_padding_mask[:, -1:])], axis=1)
|
||||||
|
|
||||||
|
# update source sequence length after adjustments
|
||||||
|
src_len = k.shape[1]
|
||||||
|
|
||||||
|
# merge key padding and attention masks
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.shape == (bsz, src_len), \
|
||||||
|
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
||||||
|
# key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
||||||
|
key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]).expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len])
|
||||||
|
if attn_mask is None:
|
||||||
|
attn_mask = key_padding_mask
|
||||||
|
else:
|
||||||
|
attn_mask = attn_mask + key_padding_mask
|
||||||
|
|
||||||
|
# adjust dropout probability
|
||||||
|
if not training:
|
||||||
|
dropout_p = 0.0
|
||||||
|
|
||||||
|
#
|
||||||
|
# (deep breath) calculate attention and out projection
|
||||||
|
#
|
||||||
|
if need_weights:
|
||||||
|
B, Nt, E = q.shape
|
||||||
|
q_scaled = q / math.sqrt(E)
|
||||||
|
if attn_mask is not None:
|
||||||
|
# attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
||||||
|
attn_output_weights = addr(q_scaled, k.transpose(-2, -1))
|
||||||
|
else:
|
||||||
|
# attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
||||||
|
attn_output_weights = paddle.bmm(q_scaled, k.transpose(0, 2, 1))
|
||||||
|
# attn_output_weights = softmax(attn_output_weights, dim=-1)
|
||||||
|
attn_output_weights = paddle.nn.functional.softmax(attn_output_weights, axis=-1)
|
||||||
|
if dropout_p > 0.0:
|
||||||
|
# attn_output_weights = dropout(attn_output_weights, p=dropout_p)
|
||||||
|
attn_output_weights = paddle.nn.functional.dropout(attn_output_weights, p=dropout_p)
|
||||||
|
|
||||||
|
# attn_output = torch.bmm(attn_output_weights, v)
|
||||||
|
attn_output = paddle.bmm(attn_output_weights, v)
|
||||||
|
attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len * bsz, embed_dim])
|
||||||
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||||
|
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
|
||||||
|
|
||||||
|
# optionally average attention weights over heads
|
||||||
|
# attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
||||||
|
attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len])
|
||||||
|
if average_attn_weights:
|
||||||
|
attn_output_weights = attn_output_weights.mean(dim=1)
|
||||||
|
|
||||||
|
if not is_batched:
|
||||||
|
# squeeze the output if input was unbatched
|
||||||
|
attn_output = attn_output.squeeze(1)
|
||||||
|
attn_output_weights = attn_output_weights.squeeze(0)
|
||||||
|
return attn_output, attn_output_weights
|
||||||
|
else:
|
||||||
|
# attn_mask can be either (L,S) or (N*num_heads, L, S)
|
||||||
|
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
|
||||||
|
# in order to match the input for SDPA of (N, num_heads, L, S)
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.shape[0] == 1 and attn_mask.dim() == 3:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
# attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
|
||||||
|
attn_mask = attn_mask.reshape([bsz, num_heads, -1, src_len])
|
||||||
|
|
||||||
|
q = q.reshape([bsz, num_heads, tgt_len, head_dim])
|
||||||
|
k = k.reshape([bsz, num_heads, src_len, head_dim])
|
||||||
|
v = v.reshape([bsz, num_heads, src_len, head_dim])
|
||||||
|
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
||||||
|
attn_output = attn_output.transpose(perm=[2, 0, 1, 3]).reshape([bsz * tgt_len, embed_dim])
|
||||||
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||||
|
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
|
||||||
|
# if not is_batched:
|
||||||
|
# # squeeze the output if input was unbatched
|
||||||
|
# attn_output = attn_output.squeeze(1)
|
||||||
|
return attn_output, None
|
@ -0,0 +1,892 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# paddle: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/paddle
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
from .functional import multi_head_attention_forward_paddle
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TransposeLast(nn.Layer):
|
||||||
|
def __init__(self, deconstruct_idx=None):
|
||||||
|
super().__init__()
|
||||||
|
self.deconstruct_idx = deconstruct_idx
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.deconstruct_idx is not None:
|
||||||
|
x = x[self.deconstruct_idx]
|
||||||
|
return paddle.transpose(x, perm=[0, 2, 1])
|
||||||
|
|
||||||
|
|
||||||
|
class Fp32LayerNorm(nn.LayerNorm):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
output = F.layer_norm(
|
||||||
|
input.float(),
|
||||||
|
self.normalized_shape,
|
||||||
|
self.weight.float() if self.weight is not None else None,
|
||||||
|
self.bias.float() if self.bias is not None else None,
|
||||||
|
self.eps,
|
||||||
|
)
|
||||||
|
return output.type_as(input)
|
||||||
|
|
||||||
|
|
||||||
|
class Fp32GroupNorm(nn.GroupNorm):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
output = F.group_norm(
|
||||||
|
input.float(),
|
||||||
|
self.num_groups,
|
||||||
|
self.weight.float() if self.weight is not None else None,
|
||||||
|
self.bias.float() if self.bias is not None else None,
|
||||||
|
self.eps,
|
||||||
|
)
|
||||||
|
return output.type_as(input)
|
||||||
|
|
||||||
|
|
||||||
|
# class GradMultiply(torch.autograd.Function):
|
||||||
|
# convert into paddle equivalent
|
||||||
|
# class GradMultiply(torch.autograd.Function):
|
||||||
|
# @staticmethod
|
||||||
|
# def forward(ctx, x, scale):
|
||||||
|
# ctx.scale = scale
|
||||||
|
# res = x.new(x)
|
||||||
|
# return res
|
||||||
|
|
||||||
|
# @staticmethod
|
||||||
|
# def backward(ctx, grad):
|
||||||
|
# return grad * ctx.scale, None
|
||||||
|
|
||||||
|
|
||||||
|
class SamePad(nn.Layer):
|
||||||
|
def __init__(self, kernel_size, causal=False):
|
||||||
|
super().__init__()
|
||||||
|
if causal:
|
||||||
|
self.remove = kernel_size - 1
|
||||||
|
else:
|
||||||
|
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.remove > 0:
|
||||||
|
x = x[:, :, : -self.remove]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Swish(nn.Layer):
|
||||||
|
"""Swish function
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Construct an MultiHeadedAttention object."""
|
||||||
|
super(Swish, self).__init__()
|
||||||
|
# self.act = torch.nn.Sigmoid()
|
||||||
|
self.act = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.act(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GLU_Linear(nn.Layer):
|
||||||
|
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||||
|
super(GLU_Linear, self).__init__()
|
||||||
|
|
||||||
|
self.glu_type = glu_type
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
if glu_type == "sigmoid":
|
||||||
|
self.glu_act = nn.Sigmoid()
|
||||||
|
elif glu_type == "swish":
|
||||||
|
self.glu_act = Swish()
|
||||||
|
elif glu_type == "relu":
|
||||||
|
self.glu_act = nn.ReLU()
|
||||||
|
elif glu_type == "gelu":
|
||||||
|
self.glu_act = nn.GELU()
|
||||||
|
|
||||||
|
if bias_in_glu:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||||
|
else:
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
||||||
|
x = self.linear(x)
|
||||||
|
|
||||||
|
if self.glu_type == "bilinear":
|
||||||
|
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
||||||
|
else:
|
||||||
|
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_accurate(x):
|
||||||
|
if not hasattr(gelu_accurate, "_a"):
|
||||||
|
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||||
|
return (
|
||||||
|
0.5 * x * (1 + paddle.tanh(gelu_accurate._a * (x + 0.044715 * paddle.pow(x, 3))))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x: Tensor) -> Tensor:
|
||||||
|
return nn.functional.gelu(x.astype("float32")).astype(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation_fn(activation: str):
|
||||||
|
"""Returns the activation function corresponding to `activation`"""
|
||||||
|
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
elif activation == "gelu":
|
||||||
|
return gelu
|
||||||
|
elif activation == "gelu_fast":
|
||||||
|
warnings.warn(
|
||||||
|
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
||||||
|
)
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "gelu_accurate":
|
||||||
|
return gelu_accurate
|
||||||
|
elif activation == "tanh":
|
||||||
|
# return torch.tanh
|
||||||
|
return paddle.tanh
|
||||||
|
elif activation == "linear":
|
||||||
|
return lambda x: x
|
||||||
|
elif activation == "glu":
|
||||||
|
return lambda x: x
|
||||||
|
else:
|
||||||
|
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||||
|
|
||||||
|
|
||||||
|
def init_bert_params(module):
|
||||||
|
"""
|
||||||
|
Initialize the weights specific to the BERT Model.
|
||||||
|
This overrides the default initializations depending on the specified arguments.
|
||||||
|
1. If normal_init_linear_weights is set then weights of linear
|
||||||
|
layer will be initialized using the normal distribution and
|
||||||
|
bais will be set to the specified value.
|
||||||
|
2. If normal_init_embed_weights is set then weights of embedding
|
||||||
|
layer will be initialized using the normal distribution.
|
||||||
|
3. If normal_init_proj_weights is set then weights of
|
||||||
|
in_project_weight for MultiHeadAttention initialized using
|
||||||
|
the normal distribution (to be validated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def normal_(data):
|
||||||
|
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||||
|
# so that the RNG is consistent with and without FSDP
|
||||||
|
data.copy_(
|
||||||
|
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
# normal_(module.weight.data)
|
||||||
|
if module.bias is not None:
|
||||||
|
# module.bias.data.zero_()
|
||||||
|
pass
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
# normal_(module.weight.data)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
# module.weight.data[module.padding_idx].zero_()
|
||||||
|
pass
|
||||||
|
if isinstance(module, MultiheadAttention):
|
||||||
|
pass
|
||||||
|
# normal_(module.q_proj.weight.data)
|
||||||
|
# normal_(module.k_proj.weight.data)
|
||||||
|
# normal_(module.v_proj.weight.data)
|
||||||
|
|
||||||
|
|
||||||
|
def quant_noise(module, p, block_size):
|
||||||
|
"""
|
||||||
|
Wraps modules and applies quantization noise to the weights for
|
||||||
|
subsequent quantization with Iterative Product Quantization as
|
||||||
|
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- module: nn.Layer
|
||||||
|
- p: amount of Quantization Noise
|
||||||
|
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||||
|
|
||||||
|
Remarks:
|
||||||
|
- Module weights must have the right sizes wrt the block size
|
||||||
|
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||||
|
- For more detail on how to quantize by blocks with convolutional weights,
|
||||||
|
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||||
|
- We implement the simplest form of noise here as stated in the paper
|
||||||
|
which consists in randomly dropping blocks
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if no quantization noise, don't register hook
|
||||||
|
if p <= 0:
|
||||||
|
return module
|
||||||
|
|
||||||
|
# supported modules
|
||||||
|
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||||
|
|
||||||
|
# test whether module.weight has the right sizes wrt block_size
|
||||||
|
is_conv = module.weight.ndim == 4
|
||||||
|
|
||||||
|
# 2D matrix
|
||||||
|
if not is_conv:
|
||||||
|
assert (
|
||||||
|
module.weight.size(1) % block_size == 0
|
||||||
|
), "Input features must be a multiple of block sizes"
|
||||||
|
|
||||||
|
# 4D matrix
|
||||||
|
else:
|
||||||
|
# 1x1 convolutions
|
||||||
|
if module.kernel_size == (1, 1):
|
||||||
|
assert (
|
||||||
|
module.in_channels % block_size == 0
|
||||||
|
), "Input channels must be a multiple of block sizes"
|
||||||
|
# regular convolutions
|
||||||
|
else:
|
||||||
|
k = module.kernel_size[0] * module.kernel_size[1]
|
||||||
|
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||||
|
|
||||||
|
def _forward_pre_hook(mod, input):
|
||||||
|
# no noise for evaluation
|
||||||
|
if mod.training:
|
||||||
|
if not is_conv:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_features = weight.size(1)
|
||||||
|
out_features = weight.size(0)
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
mask = paddle.zeros(
|
||||||
|
in_features // block_size * out_features, device=weight.device
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# gather weight and sizes
|
||||||
|
weight = mod.weight
|
||||||
|
in_channels = mod.in_channels
|
||||||
|
out_channels = mod.out_channels
|
||||||
|
|
||||||
|
# split weight matrix into blocks and randomly drop selected blocks
|
||||||
|
if mod.kernel_size == (1, 1):
|
||||||
|
mask = paddle.zeros(
|
||||||
|
int(in_channels // block_size * out_channels),
|
||||||
|
device=weight.device,
|
||||||
|
)
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||||
|
else:
|
||||||
|
mask = paddle.zeros(
|
||||||
|
weight.size(0), weight.size(1), device=weight.device
|
||||||
|
)
|
||||||
|
|
||||||
|
mask.bernoulli_(p)
|
||||||
|
mask = (
|
||||||
|
mask.unsqueeze(2)
|
||||||
|
.unsqueeze(3)
|
||||||
|
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale weights and apply mask
|
||||||
|
mask = mask.to(
|
||||||
|
# torch.bool
|
||||||
|
paddle.bool
|
||||||
|
) # x.bool() is not currently supported in TorchScript
|
||||||
|
s = 1 / (1 - p)
|
||||||
|
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||||
|
|
||||||
|
module.register_forward_pre_hook(_forward_pre_hook)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Layer):
|
||||||
|
"""Multi-headed attention.
|
||||||
|
|
||||||
|
See "Attention Is All You Need" for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
kdim=None,
|
||||||
|
vdim=None,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
|
self_attention=False,
|
||||||
|
encoder_decoder_attention=False,
|
||||||
|
q_noise=0.0,
|
||||||
|
qn_block_size=8,
|
||||||
|
has_relative_attention_bias=True,
|
||||||
|
num_buckets=32,
|
||||||
|
max_distance=128,
|
||||||
|
gru_rel_pos=True,
|
||||||
|
rescale_init=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout_module = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.has_relative_attention_bias = has_relative_attention_bias
|
||||||
|
self.num_buckets = num_buckets
|
||||||
|
self.max_distance = max_distance
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
||||||
|
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.q_head_dim = self.head_dim
|
||||||
|
self.k_head_dim = self.head_dim
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.self_attention = self_attention
|
||||||
|
self.encoder_decoder_attention = encoder_decoder_attention
|
||||||
|
|
||||||
|
assert not self.self_attention or self.qkv_same_dim, (
|
||||||
|
"Self-attention requires query, key and " "value to be of the same size"
|
||||||
|
)
|
||||||
|
|
||||||
|
k_bias = True
|
||||||
|
if rescale_init:
|
||||||
|
k_bias = False
|
||||||
|
|
||||||
|
k_embed_dim = embed_dim
|
||||||
|
q_embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.k_proj = quant_noise(
|
||||||
|
nn.Linear(self.kdim, k_embed_dim, bias_attr=k_bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.v_proj = quant_noise(
|
||||||
|
nn.Linear(self.vdim, embed_dim, bias_attr=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
self.q_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, q_embed_dim, bias_attr=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = quant_noise(
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias_attr=bias), q_noise, qn_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = self.create_parameter(
|
||||||
|
shape=[1, 1, embed_dim], dtype="float32"
|
||||||
|
)
|
||||||
|
self.bias_v = self.create_parameter(
|
||||||
|
shape=[1, 1, embed_dim], dtype="float32"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
self.gru_rel_pos = gru_rel_pos
|
||||||
|
if self.gru_rel_pos:
|
||||||
|
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
||||||
|
# self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
||||||
|
self.grep_a = self.create_parameter(
|
||||||
|
shape=[1, num_heads, 1, 1], dtype="float32"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
pass
|
||||||
|
# if self.qkv_same_dim:
|
||||||
|
# # Empirically observed the convergence to be much better with
|
||||||
|
# # the scaled initialization
|
||||||
|
# # nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
# # self.k_proj.weight.set_value(
|
||||||
|
# # paddle.nn.initializer.XavierUniform(1, 1)(self.k_proj.weight.shape)
|
||||||
|
# # )
|
||||||
|
# # self.v_proj.weight.set_value(
|
||||||
|
# # paddle.nn.initializer.XavierUniform(1, 1)(self.v_proj.weight.shape)
|
||||||
|
# # )
|
||||||
|
# # self.q_proj.weight.set_value(
|
||||||
|
# # paddle.nn.initializer.XavierUniform(1, 1)(self.q_proj.weight.shape)
|
||||||
|
# # )
|
||||||
|
# pass
|
||||||
|
# # nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
# # nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
# else:
|
||||||
|
# # nn.init.xavier_uniform_(self.k_proj.weight)
|
||||||
|
# # nn.init.xavier_uniform_(self.v_proj.weight)
|
||||||
|
# # nn.init.xavier_uniform_(self.q_proj.weight)
|
||||||
|
# # self.k_proj.weight.set_value(
|
||||||
|
# # paddle.nn.initializer.XavierUniform()(self.k_proj.weight.shape)
|
||||||
|
# # )
|
||||||
|
# # self.v_proj.weight.set_value(
|
||||||
|
# # paddle.nn.initializer.XavierUniform()(self.v_proj.weight.shape)
|
||||||
|
# # )
|
||||||
|
# # self.q_proj.weight.set_value(
|
||||||
|
# # paddle.nn.initializer.XavierUniform()(self.q_proj.weight.shape)
|
||||||
|
# # )
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
|
# nn.init.xavier_uniform_(self.out_proj.weight)
|
||||||
|
# if self.out_proj.bias is not None:
|
||||||
|
# nn.init.constant_(self.out_proj.bias, 0.0)
|
||||||
|
# if self.bias_k is not None:
|
||||||
|
# nn.init.xavier_normal_(self.bias_k)
|
||||||
|
# if self.bias_v is not None:
|
||||||
|
# nn.init.xavier_normal_(self.bias_v)
|
||||||
|
# if self.has_relative_attention_bias:
|
||||||
|
# nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
||||||
|
|
||||||
|
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
||||||
|
num_buckets = self.num_buckets
|
||||||
|
max_distance = self.max_distance
|
||||||
|
relative_buckets = 0
|
||||||
|
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets = num_buckets // 2
|
||||||
|
relative_buckets += (relative_positions > 0).astype("int64") * num_buckets
|
||||||
|
relative_positions = paddle.abs(relative_positions)
|
||||||
|
else:
|
||||||
|
relative_positions = -paddle.minimum(relative_positions, paddle.zeros_like(relative_positions))
|
||||||
|
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_positions < max_exact
|
||||||
|
|
||||||
|
relative_postion_if_large = max_exact + (
|
||||||
|
paddle.log(relative_positions.astype("float32") / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
|
).astype("int64")
|
||||||
|
relative_postion_if_large = paddle.minimum(
|
||||||
|
relative_postion_if_large, paddle.full_like(relative_postion_if_large, num_buckets - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
relative_buckets += paddle.where(is_small, relative_positions, relative_postion_if_large)
|
||||||
|
return relative_buckets
|
||||||
|
|
||||||
|
def compute_bias(self, query_length, key_length):
|
||||||
|
context_position = paddle.arange(query_length, dtype="int64")[:, None]
|
||||||
|
memory_position = paddle.arange(key_length, dtype="int64")[None, :]
|
||||||
|
relative_position = memory_position - context_position
|
||||||
|
relative_position_bucket = self._relative_positions_bucket(
|
||||||
|
relative_position,
|
||||||
|
bidirectional=True
|
||||||
|
)
|
||||||
|
# relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
||||||
|
values = self.relative_attention_bias(relative_position_bucket)
|
||||||
|
values = values.transpose([2, 0, 1])
|
||||||
|
return values
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key: Optional[Tensor],
|
||||||
|
value: Optional[Tensor],
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
static_kv: bool = False,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
before_softmax: bool = False,
|
||||||
|
need_head_weights: bool = False,
|
||||||
|
position_bias: Optional[Tensor] = None
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||||
|
"""Input shape: Time x Batch x Channel
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||||
|
keys that are pads, of shape `(batch, src_len)`, where
|
||||||
|
padding elements are indicated by 1s.
|
||||||
|
need_weights (bool, optional): return the attention weights,
|
||||||
|
averaged over heads (default: False).
|
||||||
|
attn_mask (ByteTensor, optional): typically used to
|
||||||
|
implement causal attention, where the mask prevents the
|
||||||
|
attention from looking forward in time (default: None).
|
||||||
|
before_softmax (bool, optional): return the raw attention
|
||||||
|
weights and values before the attention softmax.
|
||||||
|
need_head_weights (bool, optional): return the attention
|
||||||
|
weights for each head. Implies *need_weights*. Default:
|
||||||
|
return the average attention weights over all heads.
|
||||||
|
"""
|
||||||
|
if need_head_weights:
|
||||||
|
need_weights = True
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.shape
|
||||||
|
src_len = tgt_len
|
||||||
|
assert embed_dim == self.embed_dim
|
||||||
|
assert list(query.shape) == [tgt_len, bsz, embed_dim]
|
||||||
|
if key is not None:
|
||||||
|
src_len, key_bsz, _ = key.shape
|
||||||
|
|
||||||
|
if self.has_relative_attention_bias and position_bias is None:
|
||||||
|
position_bias = self.compute_bias(tgt_len, src_len)
|
||||||
|
position_bias_ = position_bias.unsqueeze(0)
|
||||||
|
position_bias = paddle.concat([position_bias_ for _ in range(bsz)], axis=0)
|
||||||
|
position_bias = position_bias.reshape([bsz * self.num_heads, tgt_len, src_len])
|
||||||
|
if (
|
||||||
|
# not is_tpu # don't use PyTorch version on TPUs
|
||||||
|
incremental_state is None
|
||||||
|
and not static_kv
|
||||||
|
and self.q_head_dim == self.head_dim
|
||||||
|
):
|
||||||
|
assert key is not None and value is not None
|
||||||
|
assert attn_mask is None
|
||||||
|
|
||||||
|
attn_mask_rel_pos = None
|
||||||
|
if position_bias is not None:
|
||||||
|
attn_mask_rel_pos = position_bias
|
||||||
|
if self.gru_rel_pos:
|
||||||
|
query_layer = query.transpose([1, 0, 2])
|
||||||
|
new_x_shape = query_layer.shape[:-1] + [self.num_heads, -1]
|
||||||
|
query_layer = query_layer.reshape(new_x_shape)
|
||||||
|
query_layer = query_layer.transpose([0, 2, 1, 3])
|
||||||
|
_B, _H, _L, __ = query_layer.shape
|
||||||
|
|
||||||
|
gate_a, gate_b = paddle.nn.functional.sigmoid(self.grep_linear(query_layer).reshape([_B, _H, _L, 2, 4]).sum(-1, keepdim=False)).chunk(2, axis=-1)
|
||||||
|
|
||||||
|
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||||
|
attn_mask_rel_pos = gate_a_1.reshape([bsz * self.num_heads, -1, 1]) * position_bias
|
||||||
|
|
||||||
|
attn_mask_rel_pos = attn_mask_rel_pos.reshape((-1, tgt_len, tgt_len))
|
||||||
|
k_proj_bias = self.k_proj.bias
|
||||||
|
if k_proj_bias is None:
|
||||||
|
k_proj_bias = paddle.zeros_like(self.q_proj.bias)
|
||||||
|
|
||||||
|
|
||||||
|
x, attn = multi_head_attention_forward_paddle(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
self.embed_dim,
|
||||||
|
self.num_heads,
|
||||||
|
paddle.empty([0]),
|
||||||
|
paddle.concat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias), axis=0),
|
||||||
|
self.bias_k,
|
||||||
|
self.bias_v,
|
||||||
|
self.add_zero_attn,
|
||||||
|
self.dropout_module.p,
|
||||||
|
self.out_proj.weight,
|
||||||
|
self.out_proj.bias,
|
||||||
|
self.training,
|
||||||
|
key_padding_mask,
|
||||||
|
need_weights,
|
||||||
|
attn_mask_rel_pos,
|
||||||
|
use_separate_proj_weight=True,
|
||||||
|
q_proj_weight=self.q_proj.weight,
|
||||||
|
k_proj_weight=self.k_proj.weight,
|
||||||
|
v_proj_weight=self.v_proj.weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
return x, attn, position_bias
|
||||||
|
|
||||||
|
if incremental_state is not None:
|
||||||
|
saved_state = self._get_input_buffer(incremental_state)
|
||||||
|
if saved_state is not None and "prev_key" in saved_state:
|
||||||
|
# previous time steps are cached - no need to recompute
|
||||||
|
# key and value if they are static
|
||||||
|
if static_kv:
|
||||||
|
assert self.encoder_decoder_attention and not self.self_attention
|
||||||
|
key = value = None
|
||||||
|
else:
|
||||||
|
saved_state = None
|
||||||
|
|
||||||
|
if self.self_attention:
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(query)
|
||||||
|
v = self.v_proj(query)
|
||||||
|
elif self.encoder_decoder_attention:
|
||||||
|
# encoder-decoder attention
|
||||||
|
q = self.q_proj(query)
|
||||||
|
if key is None:
|
||||||
|
assert value is None
|
||||||
|
k = v = None
|
||||||
|
else:
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(key)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert key is not None and value is not None
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(value)
|
||||||
|
q *= self.scaling
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
assert self.bias_v is not None
|
||||||
|
k = paddle.concat([k, self.bias_k.repeat(1, bsz, 1)], axis=0)
|
||||||
|
v = paddle.concat([v, self.bias_v.repeat(1, bsz, 1)], axis=0)
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = paddle.concat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = paddle.concat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
q = (
|
||||||
|
q.contiguous()
|
||||||
|
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if k is not None:
|
||||||
|
k = (
|
||||||
|
k.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
if v is not None:
|
||||||
|
v = (
|
||||||
|
v.contiguous()
|
||||||
|
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if saved_state is not None:
|
||||||
|
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||||
|
if "prev_key" in saved_state:
|
||||||
|
_prev_key = saved_state["prev_key"]
|
||||||
|
assert _prev_key is not None
|
||||||
|
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
k = prev_key
|
||||||
|
else:
|
||||||
|
assert k is not None
|
||||||
|
k = paddle.concat([prev_key, k], axis=1)
|
||||||
|
src_len = k.size(1)
|
||||||
|
if "prev_value" in saved_state:
|
||||||
|
_prev_value = saved_state["prev_value"]
|
||||||
|
assert _prev_value is not None
|
||||||
|
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
if static_kv:
|
||||||
|
v = prev_value
|
||||||
|
else:
|
||||||
|
assert v is not None
|
||||||
|
v = paddle.concat([prev_value, v], axis=1)
|
||||||
|
prev_key_padding_mask: Optional[Tensor] = None
|
||||||
|
if "prev_key_padding_mask" in saved_state:
|
||||||
|
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||||
|
assert k is not None and v is not None
|
||||||
|
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
prev_key_padding_mask=prev_key_padding_mask,
|
||||||
|
batch_size=bsz,
|
||||||
|
src_len=k.size(1),
|
||||||
|
static_kv=static_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||||
|
# In this branch incremental_state is never None
|
||||||
|
assert incremental_state is not None
|
||||||
|
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||||
|
assert k is not None
|
||||||
|
assert k.size(1) == src_len
|
||||||
|
|
||||||
|
# This is part of a workaround to get around fork/join parallelism
|
||||||
|
# not supporting Optional types.
|
||||||
|
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||||
|
key_padding_mask = None
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.size(0) == bsz
|
||||||
|
assert key_padding_mask.size(1) == src_len
|
||||||
|
|
||||||
|
if self.add_zero_attn:
|
||||||
|
assert v is not None
|
||||||
|
src_len += 1
|
||||||
|
k = paddle.concat([k, k.new_zeros((k.size(0), 1) + k.shape[2:])], axis=1)
|
||||||
|
v = paddle.concat([v, v.new_zeros((v.size(0), 1) + v.shape[2:])], axis=1)
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = paddle.concat(
|
||||||
|
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = paddle.concat(
|
||||||
|
[
|
||||||
|
key_padding_mask,
|
||||||
|
paddle.zeros(key_padding_mask.size(0), 1).type_as(
|
||||||
|
key_padding_mask
|
||||||
|
),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
attn_weights = paddle.bmm(q, k.transpose(1, 2))
|
||||||
|
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||||
|
|
||||||
|
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_weights += attn_mask
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
# don't attend to padding symbols
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
if not is_tpu:
|
||||||
|
attn_weights = attn_weights.masked_fill(
|
||||||
|
# key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(paddle.bool),
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||||
|
attn_weights = attn_weights.transpose(0, 2)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if before_softmax:
|
||||||
|
return attn_weights, v, position_bias
|
||||||
|
|
||||||
|
if position_bias is not None:
|
||||||
|
if self.gru_rel_pos == 1:
|
||||||
|
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
||||||
|
_B, _H, _L, __ = query_layer.shape
|
||||||
|
# gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
||||||
|
# _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
||||||
|
gate_a, gate_b = paddle.sigmoid(self.grep_linear(query_layer).view(
|
||||||
|
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, axis=-1)
|
||||||
|
|
||||||
|
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||||
|
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
||||||
|
|
||||||
|
position_bias = position_bias.view(attn_weights.shape)
|
||||||
|
|
||||||
|
attn_weights = attn_weights + position_bias
|
||||||
|
|
||||||
|
attn_weights_float = F.softmax(
|
||||||
|
attn_weights, dim=-1
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||||
|
attn_probs = self.dropout_module(attn_weights)
|
||||||
|
|
||||||
|
assert v is not None
|
||||||
|
# attn = torch.bmm(attn_probs, v)
|
||||||
|
attn = paddle.bmm(attn_probs, v)
|
||||||
|
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||||
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||||
|
attn = self.out_proj(attn)
|
||||||
|
attn_weights: Optional[Tensor] = None
|
||||||
|
if need_weights:
|
||||||
|
attn_weights = attn_weights_float.view(
|
||||||
|
bsz, self.num_heads, tgt_len, src_len
|
||||||
|
).transpose(1, 0)
|
||||||
|
if not need_head_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_weights = attn_weights.mean(dim=0)
|
||||||
|
|
||||||
|
return attn, attn_weights, position_bias
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_prev_key_padding_mask(
|
||||||
|
key_padding_mask: Optional[Tensor],
|
||||||
|
prev_key_padding_mask: Optional[Tensor],
|
||||||
|
batch_size: int,
|
||||||
|
src_len: int,
|
||||||
|
static_kv: bool,
|
||||||
|
) -> Optional[Tensor]:
|
||||||
|
# saved key padding masks have shape (bsz, seq_len)
|
||||||
|
if prev_key_padding_mask is not None and static_kv:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||||
|
# new_key_padding_mask = torch.cat(
|
||||||
|
# [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
||||||
|
# )
|
||||||
|
new_key_padding_mask = paddle.concat(
|
||||||
|
[prev_key_padding_mask.float(), key_padding_mask.float()], axis=1
|
||||||
|
)
|
||||||
|
# During incremental decoding, as the padding token enters and
|
||||||
|
# leaves the frame, there will be a time when prev or current
|
||||||
|
# is None
|
||||||
|
elif prev_key_padding_mask is not None:
|
||||||
|
if src_len > prev_key_padding_mask.size(1):
|
||||||
|
# filler = torch.zeros(
|
||||||
|
# (batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||||
|
# device=prev_key_padding_mask.device,
|
||||||
|
# )
|
||||||
|
filler = paddle.zeros(
|
||||||
|
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||||
|
device=prev_key_padding_mask.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# new_key_padding_mask = torch.cat(
|
||||||
|
# [prev_key_padding_mask.float(), filler.float()], dim=1
|
||||||
|
# )
|
||||||
|
new_key_padding_mask = paddle.concat(
|
||||||
|
[prev_key_padding_mask.float(), filler.float()], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask.float()
|
||||||
|
elif key_padding_mask is not None:
|
||||||
|
if src_len > key_padding_mask.size(1):
|
||||||
|
# filler = torch.zeros(
|
||||||
|
# (batch_size, src_len - key_padding_mask.size(1)),
|
||||||
|
# device=key_padding_mask.device,
|
||||||
|
# )
|
||||||
|
filler = paddle.zeros(
|
||||||
|
(batch_size, src_len - key_padding_mask.size(1)),
|
||||||
|
device=key_padding_mask.device,
|
||||||
|
)
|
||||||
|
# new_key_padding_mask = torch.cat(
|
||||||
|
# [filler.float(), key_padding_mask.float()], dim=1
|
||||||
|
# )
|
||||||
|
new_key_padding_mask = paddle.concat(
|
||||||
|
[filler.float(), key_padding_mask.float()], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = key_padding_mask.float()
|
||||||
|
else:
|
||||||
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
|
return new_key_padding_mask
|
||||||
|
|
||||||
|
def _get_input_buffer(
|
||||||
|
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
||||||
|
) -> Dict[str, Optional[Tensor]]:
|
||||||
|
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||||
|
return empty_result
|
||||||
|
|
||||||
|
def _set_input_buffer(
|
||||||
|
self,
|
||||||
|
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||||
|
buffer: Dict[str, Optional[Tensor]],
|
||||||
|
):
|
||||||
|
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||||
|
|
||||||
|
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||||
|
return attn_weights
|
@ -0,0 +1,251 @@
|
|||||||
|
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/processing/signal_processing.py)
|
||||||
|
"""
|
||||||
|
Low level signal processing utilities
|
||||||
|
Authors
|
||||||
|
* Peter Plantinga 2020
|
||||||
|
* Francois Grondin 2020
|
||||||
|
* William Aris 2020
|
||||||
|
* Samuele Cornell 2020
|
||||||
|
* Sarthak Yadav 2022
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
def blackman_window(window_length, periodic=True):
|
||||||
|
"""Blackman window function.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
window_length : int
|
||||||
|
Controlling the returned window size.
|
||||||
|
periodic : bool
|
||||||
|
Determines whether the returned window trims off the
|
||||||
|
last duplicate value from the symmetric window
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A 1-D tensor of size (window_length) containing the window
|
||||||
|
"""
|
||||||
|
if window_length == 0:
|
||||||
|
return []
|
||||||
|
if window_length == 1:
|
||||||
|
return paddle.ones([1])
|
||||||
|
if periodic:
|
||||||
|
window_length += 1
|
||||||
|
window = paddle.arange(window_length) * (np.pi / (window_length - 1))
|
||||||
|
window = 0.08 * paddle.cos(window * 4) - 0.5 * paddle.cos(window * 2) + 0.42
|
||||||
|
return window[:-1] if periodic else window
|
||||||
|
|
||||||
|
|
||||||
|
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
|
||||||
|
"""Compute amplitude of a batch of waveforms.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveform : tensor
|
||||||
|
The waveforms used for computing amplitude.
|
||||||
|
Shape should be `[time]` or `[batch, time]` or
|
||||||
|
`[batch, time, channels]`.
|
||||||
|
lengths : tensor
|
||||||
|
The lengths of the waveforms excluding the padding.
|
||||||
|
Shape should be a single dimension, `[batch]`.
|
||||||
|
amp_type : str
|
||||||
|
Whether to compute "avg" average or "peak" amplitude.
|
||||||
|
Choose between ["avg", "peak"].
|
||||||
|
scale : str
|
||||||
|
Whether to compute amplitude in "dB" or "linear" scale.
|
||||||
|
Choose between ["linear", "dB"].
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
The average amplitude of the waveforms.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> signal = paddle.sin(paddle.arange(16000.0)).unsqueeze(0)
|
||||||
|
>>> compute_amplitude(signal, signal.size(1))
|
||||||
|
tensor([[0.6366]])
|
||||||
|
"""
|
||||||
|
if len(waveforms.shape) == 1:
|
||||||
|
waveforms = waveforms.unsqueeze(0)
|
||||||
|
|
||||||
|
assert amp_type in ["avg", "peak"]
|
||||||
|
assert scale in ["linear", "dB"]
|
||||||
|
|
||||||
|
if amp_type == "avg":
|
||||||
|
if lengths is None:
|
||||||
|
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
|
||||||
|
else:
|
||||||
|
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
|
||||||
|
out = wav_sum / lengths
|
||||||
|
elif amp_type == "peak":
|
||||||
|
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)[0]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if scale == "linear":
|
||||||
|
return out
|
||||||
|
elif scale == "dB":
|
||||||
|
return paddle.clip(20 * paddle.log10(out), min=-80) # clamp zeros
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def convolve1d(
|
||||||
|
waveform,
|
||||||
|
kernel,
|
||||||
|
padding=0,
|
||||||
|
pad_type="constant",
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
use_fft=False,
|
||||||
|
rotation_index=0, ):
|
||||||
|
"""Use paddle.nn.functional to perform 1d padding and conv.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveform : tensor
|
||||||
|
The tensor to perform operations on.
|
||||||
|
kernel : tensor
|
||||||
|
The filter to apply during convolution.
|
||||||
|
padding : int or tuple
|
||||||
|
The padding (pad_left, pad_right) to apply.
|
||||||
|
If an integer is passed instead, this is passed
|
||||||
|
to the conv1d function and pad_type is ignored.
|
||||||
|
pad_type : str
|
||||||
|
The type of padding to use. Passed directly to
|
||||||
|
`paddle.nn.functional.pad`, see Paddle documentation
|
||||||
|
for available options.
|
||||||
|
stride : int
|
||||||
|
The number of units to move each time convolution is applied.
|
||||||
|
Passed to conv1d. Has no effect if `use_fft` is True.
|
||||||
|
groups : int
|
||||||
|
This option is passed to `conv1d` to split the input into groups for
|
||||||
|
convolution. Input channels should be divisible by the number of groups.
|
||||||
|
use_fft : bool
|
||||||
|
When `use_fft` is passed `True`, then compute the convolution in the
|
||||||
|
spectral domain using complex multiply. This is more efficient on CPU
|
||||||
|
when the size of the kernel is large (e.g. reverberation). WARNING:
|
||||||
|
Without padding, circular convolution occurs. This makes little
|
||||||
|
difference in the case of reverberation, but may make more difference
|
||||||
|
with different kernels.
|
||||||
|
rotation_index : int
|
||||||
|
This option only applies if `use_fft` is true. If so, the kernel is
|
||||||
|
rolled by this amount before convolution to shift the output location.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
The convolved waveform.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> from speechbrain.dataio.dataio import read_audio
|
||||||
|
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
|
||||||
|
>>> signal = signal.unsqueeze(0).unsqueeze(2)
|
||||||
|
>>> kernel = paddle.rand([1, 10, 1])
|
||||||
|
>>> signal = convolve1d(signal, kernel, padding=(9, 0))
|
||||||
|
"""
|
||||||
|
if len(waveform.shape) != 3:
|
||||||
|
raise ValueError("Convolve1D expects a 3-dimensional tensor")
|
||||||
|
|
||||||
|
# Move time dimension last, which pad and fft and conv expect.
|
||||||
|
waveform = waveform.transpose([0, 2, 1])
|
||||||
|
kernel = kernel.transpose([0, 2, 1])
|
||||||
|
# Padding can be a tuple (left_pad, right_pad) or an int
|
||||||
|
if isinstance(padding, tuple):
|
||||||
|
waveform = paddle.nn.functional.pad(
|
||||||
|
x=waveform, pad=padding, mode=pad_type, data_format='NCL')
|
||||||
|
|
||||||
|
# This approach uses FFT, which is more efficient if the kernel is large
|
||||||
|
if use_fft:
|
||||||
|
# Pad kernel to same length as signal, ensuring correct alignment
|
||||||
|
zero_length = waveform.shape[-1] - kernel.shape[-1]
|
||||||
|
|
||||||
|
# Handle case where signal is shorter
|
||||||
|
if zero_length < 0:
|
||||||
|
kernel = kernel[..., :zero_length]
|
||||||
|
zero_length = 0
|
||||||
|
|
||||||
|
# Perform rotation to ensure alignment
|
||||||
|
zeros = paddle.zeros(
|
||||||
|
[kernel.shape[0], kernel.shape[1], zero_length], dtype=kernel.dtype)
|
||||||
|
after_index = kernel[..., rotation_index:]
|
||||||
|
before_index = kernel[..., :rotation_index]
|
||||||
|
kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
|
||||||
|
|
||||||
|
# Multiply in frequency domain to convolve in time domain
|
||||||
|
import paddle.fft as fft
|
||||||
|
|
||||||
|
result = fft.rfft(waveform) * fft.rfft(kernel)
|
||||||
|
convolved = fft.irfft(result, n=waveform.shape[-1])
|
||||||
|
|
||||||
|
# Use the implementation given by paddle, which should be efficient on GPU
|
||||||
|
else:
|
||||||
|
convolved = paddle.nn.functional.conv1d(
|
||||||
|
x=waveform,
|
||||||
|
weight=kernel,
|
||||||
|
stride=stride,
|
||||||
|
groups=groups,
|
||||||
|
padding=padding if not isinstance(padding, tuple) else 0, )
|
||||||
|
|
||||||
|
# Return time dimension to the second dimension.
|
||||||
|
return convolved.transpose([0, 2, 1])
|
||||||
|
|
||||||
|
|
||||||
|
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
|
||||||
|
"""Returns a notch filter constructed from a high-pass and low-pass filter.
|
||||||
|
(from https://tomroelandts.com/articles/
|
||||||
|
how-to-create-simple-band-pass-and-band-reject-filters)
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
notch_freq : float
|
||||||
|
frequency to put notch as a fraction of the
|
||||||
|
sampling rate / 2. The range of possible inputs is 0 to 1.
|
||||||
|
filter_width : int
|
||||||
|
Filter width in samples. Longer filters have
|
||||||
|
smaller transition bands, but are more inefficient.
|
||||||
|
notch_width : float
|
||||||
|
Width of the notch, as a fraction of the sampling_rate / 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
assert 0 < notch_freq <= 1
|
||||||
|
assert filter_width % 2 != 0
|
||||||
|
pad = filter_width // 2
|
||||||
|
inputs = paddle.arange(filter_width) - pad
|
||||||
|
|
||||||
|
# Avoid frequencies that are too low
|
||||||
|
notch_freq += notch_width
|
||||||
|
|
||||||
|
# Define sinc function, avoiding division by zero
|
||||||
|
def sinc(x):
|
||||||
|
"Computes the sinc function."
|
||||||
|
|
||||||
|
def _sinc(x):
|
||||||
|
return paddle.sin(x) / x
|
||||||
|
|
||||||
|
# The zero is at the middle index
|
||||||
|
return paddle.concat(
|
||||||
|
[_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
|
||||||
|
|
||||||
|
# Compute a low-pass filter with cutoff frequency notch_freq.
|
||||||
|
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
|
||||||
|
hlpf *= blackman_window(filter_width)
|
||||||
|
hlpf /= paddle.sum(hlpf)
|
||||||
|
|
||||||
|
# Compute a high-pass filter with cutoff frequency notch_freq.
|
||||||
|
hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
|
||||||
|
hhpf *= blackman_window(filter_width)
|
||||||
|
hhpf /= -paddle.sum(hhpf)
|
||||||
|
hhpf[pad] += 1
|
||||||
|
|
||||||
|
# Adding filters creates notch filter
|
||||||
|
return (hlpf + hhpf).view(1, -1, 1)
|
@ -0,0 +1,901 @@
|
|||||||
|
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/processing/speech_augmentation.py)
|
||||||
|
"""Classes for mutating speech data for data augmentation.
|
||||||
|
This module provides classes that produce realistic distortions of speech
|
||||||
|
data for the purpose of training speech processing models. The list of
|
||||||
|
distortions includes adding noise, adding reverberation, changing speed,
|
||||||
|
and more. All the classes are of type `torch.nn.Module`. This gives the
|
||||||
|
possibility to have end-to-end differentiability and
|
||||||
|
backpropagate the gradient through them. In addition, all operations
|
||||||
|
are expected to be performed on the GPU (where available) for efficiency.
|
||||||
|
|
||||||
|
Authors
|
||||||
|
* Peter Plantinga 2020
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
from .signal_processing import compute_amplitude
|
||||||
|
from .signal_processing import convolve1d
|
||||||
|
from .signal_processing import notch_filter
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedPerturb(nn.Layer):
|
||||||
|
"""Slightly speed up or slow down an audio signal.
|
||||||
|
Resample the audio signal at a rate that is similar to the original rate,
|
||||||
|
to achieve a slightly slower or slightly faster signal. This technique is
|
||||||
|
outlined in the paper: "Audio Augmentation for Speech Recognition"
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
orig_freq : int
|
||||||
|
The frequency of the original signal.
|
||||||
|
speeds : list
|
||||||
|
The speeds that the signal should be changed to, as a percentage of the
|
||||||
|
original signal (i.e. `speeds` is divided by 100 to get a ratio).
|
||||||
|
perturb_prob : float
|
||||||
|
The chance that the batch will be speed-
|
||||||
|
perturbed. By default, every batch is perturbed.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> from speechbrain.dataio.dataio import read_audio
|
||||||
|
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
|
||||||
|
>>> perturbator = SpeedPerturb(orig_freq=16000, speeds=[90])
|
||||||
|
>>> clean = signal.unsqueeze(0)
|
||||||
|
>>> perturbed = perturbator(clean)
|
||||||
|
>>> clean.shape
|
||||||
|
paddle.shape([1, 52173])
|
||||||
|
>>> perturbed.shape
|
||||||
|
paddle.shape([1, 46956])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
orig_freq,
|
||||||
|
speeds=[90, 100, 110],
|
||||||
|
perturb_prob=1.0, ):
|
||||||
|
super().__init__()
|
||||||
|
self.orig_freq = orig_freq
|
||||||
|
self.speeds = speeds
|
||||||
|
self.perturb_prob = perturb_prob
|
||||||
|
|
||||||
|
# Initialize index of perturbation
|
||||||
|
self.samp_index = 0
|
||||||
|
# Initialize resamplers
|
||||||
|
self.resamplers = []
|
||||||
|
for speed in self.speeds:
|
||||||
|
config = {
|
||||||
|
"orig_freq": self.orig_freq,
|
||||||
|
"new_freq": self.orig_freq * speed // 100,
|
||||||
|
}
|
||||||
|
self.resamplers.append(Resample(**config))
|
||||||
|
|
||||||
|
def forward(self, waveform):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveforms : tensor
|
||||||
|
Shape should be `[batch, time]` or `[batch, time, channels]`.
|
||||||
|
lengths : tensor
|
||||||
|
Shape should be a single dimension, `[batch]`.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Don't perturb (return early) 1-`perturb_prob` portion of the batches
|
||||||
|
if paddle.rand([1]) > self.perturb_prob:
|
||||||
|
return waveform.clone()
|
||||||
|
# Perform a random perturbation
|
||||||
|
self.samp_index = paddle.randint(len(self.speeds), shape=(1, ))[0]
|
||||||
|
perturbed_waveform = self.resamplers[self.samp_index](waveform)
|
||||||
|
|
||||||
|
return perturbed_waveform
|
||||||
|
|
||||||
|
|
||||||
|
class Resample(nn.Layer):
|
||||||
|
"""This class resamples an audio signal using sinc-based interpolation.
|
||||||
|
|
||||||
|
It is a modification of the `resample` function from torchaudio
|
||||||
|
(https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html)
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
orig_freq : int
|
||||||
|
the sampling frequency of the input signal.
|
||||||
|
new_freq : int
|
||||||
|
the new sampling frequency after this operation is performed.
|
||||||
|
lowpass_filter_width : int
|
||||||
|
Controls the sharpness of the filter, larger numbers result in a
|
||||||
|
sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
orig_freq=16000,
|
||||||
|
new_freq=16000,
|
||||||
|
lowpass_filter_width=6, ):
|
||||||
|
super().__init__()
|
||||||
|
self.orig_freq = orig_freq
|
||||||
|
self.new_freq = new_freq
|
||||||
|
self.lowpass_filter_width = lowpass_filter_width
|
||||||
|
|
||||||
|
# Compute rate for striding
|
||||||
|
self._compute_strides()
|
||||||
|
assert self.orig_freq % self.conv_stride == 0
|
||||||
|
assert self.new_freq % self.conv_transpose_stride == 0
|
||||||
|
|
||||||
|
def _compute_strides(self):
|
||||||
|
"""Compute the phases in polyphase filter.
|
||||||
|
|
||||||
|
(almost directly from torchaudio.compliance.kaldi)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Compute new unit based on ratio of in/out frequencies
|
||||||
|
base_freq = math.gcd(self.orig_freq, self.new_freq)
|
||||||
|
input_samples_in_unit = self.orig_freq // base_freq
|
||||||
|
self.output_samples = self.new_freq // base_freq
|
||||||
|
|
||||||
|
# Store the appropriate stride based on the new units
|
||||||
|
self.conv_stride = input_samples_in_unit
|
||||||
|
self.conv_transpose_stride = self.output_samples
|
||||||
|
|
||||||
|
def forward(self, waveforms):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveforms : tensor
|
||||||
|
Shape should be `[batch, time]` or `[batch, time, channels]`.
|
||||||
|
lengths : tensor
|
||||||
|
Shape should be a single dimension, `[batch]`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not hasattr(self, "first_indices"):
|
||||||
|
self._indices_and_weights(waveforms)
|
||||||
|
|
||||||
|
# Don't do anything if the frequencies are the same
|
||||||
|
if self.orig_freq == self.new_freq:
|
||||||
|
return waveforms
|
||||||
|
unsqueezed = False
|
||||||
|
if len(waveforms.shape) == 2:
|
||||||
|
waveforms = waveforms.unsqueeze(1)
|
||||||
|
unsqueezed = True
|
||||||
|
elif len(waveforms.shape) == 3:
|
||||||
|
waveforms = waveforms.transpose([0, 2, 1])
|
||||||
|
else:
|
||||||
|
raise ValueError("Input must be 2 or 3 dimensions")
|
||||||
|
|
||||||
|
# Do resampling
|
||||||
|
resampled_waveform = self._perform_resample(waveforms)
|
||||||
|
|
||||||
|
if unsqueezed:
|
||||||
|
resampled_waveform = resampled_waveform.squeeze(1)
|
||||||
|
else:
|
||||||
|
resampled_waveform = resampled_waveform.transpose([0, 2, 1])
|
||||||
|
|
||||||
|
return resampled_waveform
|
||||||
|
|
||||||
|
def _perform_resample(self, waveforms):
|
||||||
|
"""Resamples the waveform at the new frequency.
|
||||||
|
|
||||||
|
This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a
|
||||||
|
LinearResample (resample a signal at linearly spaced intervals to
|
||||||
|
up/downsample a signal). LinearResample (LR) means that the output
|
||||||
|
signal is at linearly spaced intervals (i.e the output signal has a
|
||||||
|
frequency of `new_freq`). It uses sinc/bandlimited interpolation to
|
||||||
|
upsample/downsample the signal.
|
||||||
|
|
||||||
|
(almost directly from torchaudio.compliance.kaldi)
|
||||||
|
|
||||||
|
https://ccrma.stanford.edu/~jos/resample/
|
||||||
|
Theory_Ideal_Bandlimited_Interpolation.html
|
||||||
|
|
||||||
|
https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveforms : tensor
|
||||||
|
The batch of audio signals to resample.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
The waveforms at the new frequency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Compute output size and initialize
|
||||||
|
batch_size, num_channels, wave_len = waveforms.shape
|
||||||
|
window_size = self.weights.shape[1]
|
||||||
|
tot_output_samp = self._output_samples(wave_len)
|
||||||
|
resampled_waveform = paddle.zeros(
|
||||||
|
(batch_size, num_channels, tot_output_samp))
|
||||||
|
# self.weights = self.weights.to(waveforms.device)
|
||||||
|
|
||||||
|
# Check weights are on correct device
|
||||||
|
# if waveforms.device != self.weights.device:
|
||||||
|
# self.weights = self.weights.to(waveforms.device)
|
||||||
|
|
||||||
|
# eye size: (num_channels, num_channels, 1)
|
||||||
|
eye = paddle.eye(num_channels).unsqueeze(2)
|
||||||
|
|
||||||
|
# Iterate over the phases in the polyphase filter
|
||||||
|
for i in range(self.first_indices.shape[0]):
|
||||||
|
wave_to_conv = waveforms
|
||||||
|
first_index = int(self.first_indices[i].item())
|
||||||
|
if first_index >= 0:
|
||||||
|
# trim the signal as the filter will not be applied
|
||||||
|
# before the first_index
|
||||||
|
wave_to_conv = wave_to_conv[..., first_index:]
|
||||||
|
|
||||||
|
# pad the right of the signal to allow partial convolutions
|
||||||
|
# meaning compute values for partial windows (e.g. end of the
|
||||||
|
# window is outside the signal length)
|
||||||
|
max_index = (tot_output_samp - 1) // self.output_samples
|
||||||
|
end_index = max_index * self.conv_stride + window_size
|
||||||
|
current_wave_len = wave_len - first_index
|
||||||
|
right_padding = max(0, end_index + 1 - current_wave_len)
|
||||||
|
left_padding = max(0, -first_index)
|
||||||
|
wave_to_conv = paddle.nn.functional.pad(
|
||||||
|
wave_to_conv, (left_padding, right_padding), data_format='NCL')
|
||||||
|
conv_wave = paddle.nn.functional.conv1d(
|
||||||
|
x=wave_to_conv,
|
||||||
|
weight=self.weights[i].repeat(num_channels, 1, 1),
|
||||||
|
stride=self.conv_stride,
|
||||||
|
groups=num_channels, )
|
||||||
|
|
||||||
|
# we want conv_wave[:, i] to be at
|
||||||
|
# output[:, i + n*conv_transpose_stride]
|
||||||
|
dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
|
||||||
|
conv_wave, eye, stride=self.conv_transpose_stride)
|
||||||
|
|
||||||
|
# pad dilated_conv_wave so it reaches the output length if needed.
|
||||||
|
left_padding = i
|
||||||
|
previous_padding = left_padding + dilated_conv_wave.shape[-1]
|
||||||
|
right_padding = max(0, tot_output_samp - previous_padding)
|
||||||
|
dilated_conv_wave = paddle.nn.functional.pad(
|
||||||
|
dilated_conv_wave, (left_padding, right_padding),
|
||||||
|
data_format='NCL')
|
||||||
|
dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]
|
||||||
|
|
||||||
|
resampled_waveform += dilated_conv_wave
|
||||||
|
|
||||||
|
return resampled_waveform
|
||||||
|
|
||||||
|
def _output_samples(self, input_num_samp):
|
||||||
|
"""Based on LinearResample::GetNumOutputSamples.
|
||||||
|
|
||||||
|
LinearResample (LR) means that the output signal is at
|
||||||
|
linearly spaced intervals (i.e the output signal has a
|
||||||
|
frequency of ``new_freq``). It uses sinc/bandlimited
|
||||||
|
interpolation to upsample/downsample the signal.
|
||||||
|
|
||||||
|
(almost directly from torchaudio.compliance.kaldi)
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
input_num_samp : int
|
||||||
|
The number of samples in each example in the batch.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Number of samples in the output waveform.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
|
||||||
|
# where tick_freq is the least common multiple of samp_in and
|
||||||
|
# samp_out.
|
||||||
|
samp_in = int(self.orig_freq)
|
||||||
|
samp_out = int(self.new_freq)
|
||||||
|
|
||||||
|
tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out)
|
||||||
|
ticks_per_input_period = tick_freq // samp_in
|
||||||
|
|
||||||
|
# work out the number of ticks in the time interval
|
||||||
|
# [ 0, input_num_samp/samp_in ).
|
||||||
|
interval_length = input_num_samp * ticks_per_input_period
|
||||||
|
if interval_length <= 0:
|
||||||
|
return 0
|
||||||
|
ticks_per_output_period = tick_freq // samp_out
|
||||||
|
|
||||||
|
# Get the last output-sample in the closed interval,
|
||||||
|
# i.e. replacing [ ) with [ ]. Note: integer division rounds down.
|
||||||
|
# See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
|
||||||
|
# explanation of the notation.
|
||||||
|
last_output_samp = interval_length // ticks_per_output_period
|
||||||
|
|
||||||
|
# We need the last output-sample in the open interval, so if it
|
||||||
|
# takes us to the end of the interval exactly, subtract one.
|
||||||
|
if last_output_samp * ticks_per_output_period == interval_length:
|
||||||
|
last_output_samp -= 1
|
||||||
|
|
||||||
|
# First output-sample index is zero, so the number of output samples
|
||||||
|
# is the last output-sample plus one.
|
||||||
|
num_output_samp = last_output_samp + 1
|
||||||
|
|
||||||
|
return num_output_samp
|
||||||
|
|
||||||
|
def _indices_and_weights(self, waveforms):
|
||||||
|
"""Based on LinearResample::SetIndexesAndWeights
|
||||||
|
|
||||||
|
Retrieves the weights for resampling as well as the indices in which
|
||||||
|
they are valid. LinearResample (LR) means that the output signal is at
|
||||||
|
linearly spaced intervals (i.e the output signal has a frequency
|
||||||
|
of ``new_freq``). It uses sinc/bandlimited interpolation to
|
||||||
|
upsample/downsample the signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
- the place where each filter should start being applied
|
||||||
|
- the filters to be applied to the signal for resampling
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Lowpass filter frequency depends on smaller of two frequencies
|
||||||
|
min_freq = min(self.orig_freq, self.new_freq)
|
||||||
|
lowpass_cutoff = 0.99 * 0.5 * min_freq
|
||||||
|
|
||||||
|
assert lowpass_cutoff * 2 <= min_freq
|
||||||
|
window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
|
||||||
|
|
||||||
|
assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
|
||||||
|
output_t = paddle.arange(start=0.0, end=self.output_samples)
|
||||||
|
output_t /= self.new_freq
|
||||||
|
min_t = output_t - window_width
|
||||||
|
max_t = output_t + window_width
|
||||||
|
|
||||||
|
min_input_index = paddle.ceil(min_t * self.orig_freq)
|
||||||
|
max_input_index = paddle.floor(max_t * self.orig_freq)
|
||||||
|
num_indices = max_input_index - min_input_index + 1
|
||||||
|
|
||||||
|
max_weight_width = num_indices.max()
|
||||||
|
j = paddle.arange(max_weight_width)
|
||||||
|
input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0)
|
||||||
|
delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1)
|
||||||
|
|
||||||
|
weights = paddle.zeros_like(delta_t)
|
||||||
|
|
||||||
|
inside_window_indices = delta_t.abs() < (window_width)
|
||||||
|
# raised-cosine (Hanning) window with width `window_width`
|
||||||
|
weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
|
||||||
|
2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
|
||||||
|
delta_t[inside_window_indices]))
|
||||||
|
t_eq_zero_indices = delta_t == 0.0
|
||||||
|
t_not_eq_zero_indices = ~t_eq_zero_indices
|
||||||
|
|
||||||
|
# sinc filter function
|
||||||
|
weights[t_not_eq_zero_indices] *= paddle.sin(
|
||||||
|
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (
|
||||||
|
math.pi * delta_t[t_not_eq_zero_indices])
|
||||||
|
|
||||||
|
# limit of the function at t = 0
|
||||||
|
weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
|
||||||
|
|
||||||
|
# size (output_samples, max_weight_width)
|
||||||
|
weights /= self.orig_freq
|
||||||
|
|
||||||
|
self.first_indices = min_input_index
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
|
||||||
|
class DropFreq(nn.Layer):
|
||||||
|
"""This class drops a random frequency from the signal.
|
||||||
|
The purpose of this class is to teach models to learn to rely on all parts
|
||||||
|
of the signal, not just a few frequency bands.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
drop_freq_low : float
|
||||||
|
The low end of frequencies that can be dropped,
|
||||||
|
as a fraction of the sampling rate / 2.
|
||||||
|
drop_freq_high : float
|
||||||
|
The high end of frequencies that can be
|
||||||
|
dropped, as a fraction of the sampling rate / 2.
|
||||||
|
drop_count_low : int
|
||||||
|
The low end of number of frequencies that could be dropped.
|
||||||
|
drop_count_high : int
|
||||||
|
The high end of number of frequencies that could be dropped.
|
||||||
|
drop_width : float
|
||||||
|
The width of the frequency band to drop, as
|
||||||
|
a fraction of the sampling_rate / 2.
|
||||||
|
drop_prob : float
|
||||||
|
The probability that the batch of signals will have a frequency
|
||||||
|
dropped. By default, every batch has frequencies dropped.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> from speechbrain.dataio.dataio import read_audio
|
||||||
|
>>> dropper = DropFreq()
|
||||||
|
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
|
||||||
|
>>> dropped_signal = dropper(signal.unsqueeze(0))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
drop_freq_low=1e-14,
|
||||||
|
drop_freq_high=1,
|
||||||
|
drop_count_low=1,
|
||||||
|
drop_count_high=2,
|
||||||
|
drop_width=0.05,
|
||||||
|
drop_prob=1, ):
|
||||||
|
super().__init__()
|
||||||
|
self.drop_freq_low = drop_freq_low
|
||||||
|
self.drop_freq_high = drop_freq_high
|
||||||
|
self.drop_count_low = drop_count_low
|
||||||
|
self.drop_count_high = drop_count_high
|
||||||
|
self.drop_width = drop_width
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
|
def forward(self, waveforms):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveforms : tensor
|
||||||
|
Shape should be `[batch, time]` or `[batch, time, channels]`.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Don't drop (return early) 1-`drop_prob` portion of the batches
|
||||||
|
dropped_waveform = waveforms.clone()
|
||||||
|
if paddle.rand([1]) > self.drop_prob:
|
||||||
|
return dropped_waveform
|
||||||
|
|
||||||
|
# Add channels dimension
|
||||||
|
if len(waveforms.shape) == 2:
|
||||||
|
dropped_waveform = dropped_waveform.unsqueeze(-1)
|
||||||
|
|
||||||
|
# Pick number of frequencies to drop
|
||||||
|
drop_count = paddle.randint(
|
||||||
|
low=self.drop_count_low,
|
||||||
|
high=self.drop_count_high + 1,
|
||||||
|
shape=(1, ), )
|
||||||
|
|
||||||
|
# Filter parameters
|
||||||
|
filter_length = 101
|
||||||
|
pad = filter_length // 2
|
||||||
|
|
||||||
|
# Start with delta function
|
||||||
|
drop_filter = paddle.zeros([1, filter_length, 1])
|
||||||
|
drop_filter[0, pad, 0] = 1
|
||||||
|
|
||||||
|
if drop_count.shape == 0:
|
||||||
|
# Pick a frequency to drop
|
||||||
|
drop_range = self.drop_freq_high - self.drop_freq_low
|
||||||
|
drop_frequency = (
|
||||||
|
paddle.rand(drop_count) * drop_range + self.drop_freq_low)
|
||||||
|
# Subtract each frequency
|
||||||
|
for frequency in drop_frequency:
|
||||||
|
notch_kernel = notch_filter(
|
||||||
|
frequency,
|
||||||
|
filter_length,
|
||||||
|
self.drop_width, )
|
||||||
|
drop_filter = convolve1d(drop_filter, notch_kernel, pad)
|
||||||
|
|
||||||
|
# Apply filter
|
||||||
|
dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad)
|
||||||
|
|
||||||
|
# Remove channels dimension if added
|
||||||
|
return dropped_waveform.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
class DropChunk(nn.Layer):
|
||||||
|
"""This class drops portions of the input signal.
|
||||||
|
Using `DropChunk` as an augmentation strategy helps a models learn to rely
|
||||||
|
on all parts of the signal, since it can't expect a given part to be
|
||||||
|
present.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
drop_length_low : int
|
||||||
|
The low end of lengths for which to set the
|
||||||
|
signal to zero, in samples.
|
||||||
|
drop_length_high : int
|
||||||
|
The high end of lengths for which to set the
|
||||||
|
signal to zero, in samples.
|
||||||
|
drop_count_low : int
|
||||||
|
The low end of number of times that the signal
|
||||||
|
can be dropped to zero.
|
||||||
|
drop_count_high : int
|
||||||
|
The high end of number of times that the signal
|
||||||
|
can be dropped to zero.
|
||||||
|
drop_start : int
|
||||||
|
The first index for which dropping will be allowed.
|
||||||
|
drop_end : int
|
||||||
|
The last index for which dropping will be allowed.
|
||||||
|
drop_prob : float
|
||||||
|
The probability that the batch of signals will
|
||||||
|
have a portion dropped. By default, every batch
|
||||||
|
has portions dropped.
|
||||||
|
noise_factor : float
|
||||||
|
The factor relative to average amplitude of an utterance
|
||||||
|
to use for scaling the white noise inserted. 1 keeps
|
||||||
|
the average amplitude the same, while 0 inserts all 0's.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> from speechbrain.dataio.dataio import read_audio
|
||||||
|
>>> dropper = DropChunk(drop_start=100, drop_end=200, noise_factor=0.)
|
||||||
|
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
|
||||||
|
>>> signal = signal.unsqueeze(0) # [batch, time, channels]
|
||||||
|
>>> length = paddle.ones([1])
|
||||||
|
>>> dropped_signal = dropper(signal, length)
|
||||||
|
>>> float(dropped_signal[:, 150])
|
||||||
|
0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
drop_length_low=100,
|
||||||
|
drop_length_high=1000,
|
||||||
|
drop_count_low=1,
|
||||||
|
drop_count_high=10,
|
||||||
|
drop_start=0,
|
||||||
|
drop_end=None,
|
||||||
|
drop_prob=1,
|
||||||
|
noise_factor=0.0, ):
|
||||||
|
super().__init__()
|
||||||
|
self.drop_length_low = drop_length_low
|
||||||
|
self.drop_length_high = drop_length_high
|
||||||
|
self.drop_count_low = drop_count_low
|
||||||
|
self.drop_count_high = drop_count_high
|
||||||
|
self.drop_start = drop_start
|
||||||
|
self.drop_end = drop_end
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
self.noise_factor = noise_factor
|
||||||
|
|
||||||
|
# Validate low < high
|
||||||
|
if drop_length_low > drop_length_high:
|
||||||
|
raise ValueError("Low limit must not be more than high limit")
|
||||||
|
if drop_count_low > drop_count_high:
|
||||||
|
raise ValueError("Low limit must not be more than high limit")
|
||||||
|
|
||||||
|
# Make sure the length doesn't exceed end - start
|
||||||
|
if drop_end is not None and drop_end >= 0:
|
||||||
|
if drop_start > drop_end:
|
||||||
|
raise ValueError("Low limit must not be more than high limit")
|
||||||
|
|
||||||
|
drop_range = drop_end - drop_start
|
||||||
|
self.drop_length_low = min(drop_length_low, drop_range)
|
||||||
|
self.drop_length_high = min(drop_length_high, drop_range)
|
||||||
|
|
||||||
|
def forward(self, waveforms, lengths):
|
||||||
|
"""
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveforms : tensor
|
||||||
|
Shape should be `[batch, time]` or `[batch, time, channels]`.
|
||||||
|
lengths : tensor
|
||||||
|
Shape should be a single dimension, `[batch]`.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor of shape `[batch, time]` or
|
||||||
|
`[batch, time, channels]`
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Reading input list
|
||||||
|
lengths = (lengths * waveforms.shape[1]).long()
|
||||||
|
batch_size = waveforms.shape[0]
|
||||||
|
dropped_waveform = waveforms.clone()
|
||||||
|
|
||||||
|
# Don't drop (return early) 1-`drop_prob` portion of the batches
|
||||||
|
if paddle.rand([1]) > self.drop_prob:
|
||||||
|
return dropped_waveform
|
||||||
|
|
||||||
|
# Store original amplitude for computing white noise amplitude
|
||||||
|
clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1))
|
||||||
|
|
||||||
|
# Pick a number of times to drop
|
||||||
|
drop_times = paddle.randint(
|
||||||
|
low=self.drop_count_low,
|
||||||
|
high=self.drop_count_high + 1,
|
||||||
|
shape=(batch_size, ), )
|
||||||
|
|
||||||
|
# Iterate batch to set mask
|
||||||
|
for i in range(batch_size):
|
||||||
|
if drop_times[i] == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Pick lengths
|
||||||
|
length = paddle.randint(
|
||||||
|
low=self.drop_length_low,
|
||||||
|
high=self.drop_length_high + 1,
|
||||||
|
shape=(drop_times[i], ), )
|
||||||
|
|
||||||
|
# Compute range of starting locations
|
||||||
|
start_min = self.drop_start
|
||||||
|
if start_min < 0:
|
||||||
|
start_min += lengths[i]
|
||||||
|
start_max = self.drop_end
|
||||||
|
if start_max is None:
|
||||||
|
start_max = lengths[i]
|
||||||
|
if start_max < 0:
|
||||||
|
start_max += lengths[i]
|
||||||
|
start_max = max(0, start_max - length.max())
|
||||||
|
|
||||||
|
# Pick starting locations
|
||||||
|
start = paddle.randint(
|
||||||
|
low=start_min,
|
||||||
|
high=start_max + 1,
|
||||||
|
shape=(drop_times[i], ), )
|
||||||
|
|
||||||
|
end = start + length
|
||||||
|
|
||||||
|
# Update waveform
|
||||||
|
if not self.noise_factor:
|
||||||
|
for j in range(drop_times[i]):
|
||||||
|
dropped_waveform[i, start[j]:end[j]] = 0.0
|
||||||
|
else:
|
||||||
|
# Uniform distribution of -2 to +2 * avg amplitude should
|
||||||
|
# preserve the average for normalization
|
||||||
|
noise_max = 2 * clean_amplitude[i] * self.noise_factor
|
||||||
|
for j in range(drop_times[i]):
|
||||||
|
# zero-center the noise distribution
|
||||||
|
noise_vec = paddle.rand([length[j]])
|
||||||
|
noise_vec = 2 * noise_max * noise_vec - noise_max
|
||||||
|
dropped_waveform[i, start[j]:end[j]] = noise_vec
|
||||||
|
|
||||||
|
return dropped_waveform
|
||||||
|
|
||||||
|
|
||||||
|
class SpecAugment(paddle.nn.Layer):
|
||||||
|
"""An implementation of the SpecAugment algorithm.
|
||||||
|
Reference:
|
||||||
|
https://arxiv.org/abs/1904.08779
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
time_warp : bool
|
||||||
|
Whether applying time warping.
|
||||||
|
time_warp_window : int
|
||||||
|
Time warp window.
|
||||||
|
time_warp_mode : str
|
||||||
|
Interpolation mode for time warping (default "bicubic").
|
||||||
|
freq_mask : bool
|
||||||
|
Whether applying freq mask.
|
||||||
|
freq_mask_width : int or tuple
|
||||||
|
Freq mask width range.
|
||||||
|
n_freq_mask : int
|
||||||
|
Number of freq mask.
|
||||||
|
time_mask : bool
|
||||||
|
Whether applying time mask.
|
||||||
|
time_mask_width : int or tuple
|
||||||
|
Time mask width range.
|
||||||
|
n_time_mask : int
|
||||||
|
Number of time mask.
|
||||||
|
replace_with_zero : bool
|
||||||
|
If True, replace masked value with 0, else replace masked value with mean of the input tensor.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> aug = SpecAugment()
|
||||||
|
>>> a = paddle.rand([8, 120, 80])
|
||||||
|
>>> a = aug(a)
|
||||||
|
>>> print(a.shape)
|
||||||
|
paddle.Size([8, 120, 80])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
time_warp=True,
|
||||||
|
time_warp_window=5,
|
||||||
|
time_warp_mode="bicubic",
|
||||||
|
freq_mask=True,
|
||||||
|
freq_mask_width=(0, 20),
|
||||||
|
n_freq_mask=2,
|
||||||
|
time_mask=True,
|
||||||
|
time_mask_width=(0, 100),
|
||||||
|
n_time_mask=2,
|
||||||
|
replace_with_zero=True, ):
|
||||||
|
super().__init__()
|
||||||
|
assert (
|
||||||
|
time_warp or freq_mask or time_mask
|
||||||
|
), "at least one of time_warp, time_mask, or freq_mask should be applied"
|
||||||
|
|
||||||
|
self.apply_time_warp = time_warp
|
||||||
|
self.time_warp_window = time_warp_window
|
||||||
|
self.time_warp_mode = time_warp_mode
|
||||||
|
|
||||||
|
self.freq_mask = freq_mask
|
||||||
|
if isinstance(freq_mask_width, int):
|
||||||
|
freq_mask_width = (0, freq_mask_width)
|
||||||
|
self.freq_mask_width = freq_mask_width
|
||||||
|
self.n_freq_mask = n_freq_mask
|
||||||
|
|
||||||
|
self.time_mask = time_mask
|
||||||
|
if isinstance(time_mask_width, int):
|
||||||
|
time_mask_width = (0, time_mask_width)
|
||||||
|
self.time_mask_width = time_mask_width
|
||||||
|
self.n_time_mask = n_time_mask
|
||||||
|
|
||||||
|
self.replace_with_zero = replace_with_zero
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Takes in input a tensors and returns an augmented one."""
|
||||||
|
if self.apply_time_warp:
|
||||||
|
x = self.time_warp(x)
|
||||||
|
if self.freq_mask:
|
||||||
|
x = self.mask_along_axis(x, dim=2)
|
||||||
|
if self.time_mask:
|
||||||
|
x = self.mask_along_axis(x, dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def time_warp(self, x):
|
||||||
|
"""Time warping with paddle.nn.functional.interpolate"""
|
||||||
|
original_size = x.shape
|
||||||
|
window = self.time_warp_window
|
||||||
|
|
||||||
|
# 2d interpolation requires 4D or higher dimension tensors
|
||||||
|
# x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
|
||||||
|
if x.dim() == 3:
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
|
||||||
|
time = x.shape[2]
|
||||||
|
if time - window <= window:
|
||||||
|
return x.view(*original_size)
|
||||||
|
|
||||||
|
# compute center and corresponding window
|
||||||
|
c = paddle.randint(window, time - window, (1, ))[0]
|
||||||
|
w = paddle.randint(c - window, c + window, (1, ))[0] + 1
|
||||||
|
|
||||||
|
left = paddle.nn.functional.interpolate(
|
||||||
|
x[:, :, :c],
|
||||||
|
(w, x.shape[3]),
|
||||||
|
mode=self.time_warp_mode,
|
||||||
|
align_corners=True, )
|
||||||
|
right = paddle.nn.functional.interpolate(
|
||||||
|
x[:, :, c:],
|
||||||
|
(time - w, x.shape[3]),
|
||||||
|
mode=self.time_warp_mode,
|
||||||
|
align_corners=True, )
|
||||||
|
|
||||||
|
x[:, :, :w] = left
|
||||||
|
x[:, :, w:] = right
|
||||||
|
return x.view(*original_size)
|
||||||
|
|
||||||
|
def mask_along_axis(self, x, dim):
|
||||||
|
"""Mask along time or frequency axis.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
x : tensor
|
||||||
|
Input tensor.
|
||||||
|
dim : int
|
||||||
|
Corresponding dimension to mask.
|
||||||
|
"""
|
||||||
|
original_size = x.shape
|
||||||
|
if x.dim() == 4:
|
||||||
|
x = x.view(-1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
|
batch, time, fea = x.shape
|
||||||
|
|
||||||
|
if dim == 1:
|
||||||
|
D = time
|
||||||
|
n_mask = self.n_time_mask
|
||||||
|
width_range = self.time_mask_width
|
||||||
|
else:
|
||||||
|
D = fea
|
||||||
|
n_mask = self.n_freq_mask
|
||||||
|
width_range = self.freq_mask_width
|
||||||
|
|
||||||
|
mask_len = paddle.randint(width_range[0], width_range[1],
|
||||||
|
(batch, n_mask)).unsqueeze(2)
|
||||||
|
|
||||||
|
mask_pos = paddle.randint(0, max(1, D - mask_len.max()),
|
||||||
|
(batch, n_mask)).unsqueeze(2)
|
||||||
|
|
||||||
|
# compute masks
|
||||||
|
arange = paddle.arange(end=D).view(1, 1, -1)
|
||||||
|
mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
|
||||||
|
mask = mask.any(axis=1)
|
||||||
|
|
||||||
|
if dim == 1:
|
||||||
|
mask = mask.unsqueeze(2)
|
||||||
|
else:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
if self.replace_with_zero:
|
||||||
|
val = 0.0
|
||||||
|
else:
|
||||||
|
val = x.mean()
|
||||||
|
# same to x.masked_fill_(mask, val)
|
||||||
|
y = paddle.full(x.shape, val, x.dtype)
|
||||||
|
x = paddle.where(mask, y, x)
|
||||||
|
return x.view(*original_size)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeDomainSpecAugment(nn.Layer):
|
||||||
|
"""A time-domain approximation of the SpecAugment algorithm.
|
||||||
|
This augmentation module implements three augmentations in
|
||||||
|
the time-domain.
|
||||||
|
1. Drop chunks of the audio (zero amplitude or white noise)
|
||||||
|
2. Drop frequency bands (with band-drop filters)
|
||||||
|
3. Speed peturbation (via resampling to slightly different rate)
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
perturb_prob : float from 0 to 1
|
||||||
|
The probability that a batch will have speed perturbation applied.
|
||||||
|
drop_freq_prob : float from 0 to 1
|
||||||
|
The probability that a batch will have frequencies dropped.
|
||||||
|
drop_chunk_prob : float from 0 to 1
|
||||||
|
The probability that a batch will have chunks dropped.
|
||||||
|
speeds : list of ints
|
||||||
|
A set of different speeds to use to perturb each batch.
|
||||||
|
See ``speechbrain.processing.speech_augmentation.SpeedPerturb``
|
||||||
|
sample_rate : int
|
||||||
|
Sampling rate of the input waveforms.
|
||||||
|
drop_freq_count_low : int
|
||||||
|
Lowest number of frequencies that could be dropped.
|
||||||
|
drop_freq_count_high : int
|
||||||
|
Highest number of frequencies that could be dropped.
|
||||||
|
drop_chunk_count_low : int
|
||||||
|
Lowest number of chunks that could be dropped.
|
||||||
|
drop_chunk_count_high : int
|
||||||
|
Highest number of chunks that could be dropped.
|
||||||
|
drop_chunk_length_low : int
|
||||||
|
Lowest length of chunks that could be dropped.
|
||||||
|
drop_chunk_length_high : int
|
||||||
|
Highest length of chunks that could be dropped.
|
||||||
|
drop_chunk_noise_factor : float
|
||||||
|
The noise factor used to scale the white noise inserted, relative to
|
||||||
|
the average amplitude of the utterance. Default 0 (no noise inserted).
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> inputs = paddle.randn([10, 16000])
|
||||||
|
>>> feature_maker = TimeDomainSpecAugment(speeds=[80])
|
||||||
|
>>> feats = feature_maker(inputs, paddle.ones(10))
|
||||||
|
>>> feats.shape
|
||||||
|
paddle.shape([10, 12800])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
perturb_prob=1.0,
|
||||||
|
drop_freq_prob=1.0,
|
||||||
|
drop_chunk_prob=1.0,
|
||||||
|
speeds=[95, 100, 105],
|
||||||
|
sample_rate=16000,
|
||||||
|
drop_freq_count_low=0,
|
||||||
|
drop_freq_count_high=3,
|
||||||
|
drop_chunk_count_low=0,
|
||||||
|
drop_chunk_count_high=5,
|
||||||
|
drop_chunk_length_low=1000,
|
||||||
|
drop_chunk_length_high=2000,
|
||||||
|
drop_chunk_noise_factor=0, ):
|
||||||
|
super().__init__()
|
||||||
|
self.speed_perturb = SpeedPerturb(
|
||||||
|
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
|
||||||
|
self.drop_freq = DropFreq(
|
||||||
|
drop_prob=drop_freq_prob,
|
||||||
|
drop_count_low=drop_freq_count_low,
|
||||||
|
drop_count_high=drop_freq_count_high, )
|
||||||
|
self.drop_chunk = DropChunk(
|
||||||
|
drop_prob=drop_chunk_prob,
|
||||||
|
drop_count_low=drop_chunk_count_low,
|
||||||
|
drop_count_high=drop_chunk_count_high,
|
||||||
|
drop_length_low=drop_chunk_length_low,
|
||||||
|
drop_length_high=drop_chunk_length_high,
|
||||||
|
noise_factor=drop_chunk_noise_factor, )
|
||||||
|
|
||||||
|
def forward(self, waveforms, lengths):
|
||||||
|
"""Returns the distorted waveforms.
|
||||||
|
Arguments
|
||||||
|
---------
|
||||||
|
waveforms : tensor
|
||||||
|
The waveforms to distort
|
||||||
|
"""
|
||||||
|
# Augmentation
|
||||||
|
with paddle.no_grad():
|
||||||
|
waveforms = self.speed_perturb(waveforms)
|
||||||
|
waveforms = self.drop_freq(waveforms)
|
||||||
|
waveforms = self.drop_chunk(waveforms, lengths)
|
||||||
|
return waveforms
|
@ -0,0 +1,323 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
|
||||||
|
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
|
||||||
|
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
|
||||||
|
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
|
||||||
|
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
|
||||||
|
from paddlespeech.s2t.utils.utility import log_add
|
||||||
|
|
||||||
|
from .wavlm_paddle import WavLM, WavLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
class WavLMASR(nn.Layer):
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
super().__init__()
|
||||||
|
init_type = config.get("init_type", None)
|
||||||
|
with DefaultInitializerContext(init_type):
|
||||||
|
self.config = config
|
||||||
|
wavlm_config = WavLMConfig(config)
|
||||||
|
wavlm = WavLM(wavlm_config)
|
||||||
|
|
||||||
|
self.normalize_wav = config.normalize_wav
|
||||||
|
self.output_norm = config.output_norm
|
||||||
|
if hasattr(config, 'spec_augment'):
|
||||||
|
self.spec_augment = SpecAugment(**config.spec_augment)
|
||||||
|
|
||||||
|
if config.freeze_wavlm:
|
||||||
|
wavlm.eval()
|
||||||
|
for parm in wavlm.parameters():
|
||||||
|
parm.trainable = False
|
||||||
|
self.wavlm = wavlm
|
||||||
|
self.enc = VanillaNN(**config.enc)
|
||||||
|
self.ctc = CTC(**config.ctc,
|
||||||
|
odim=config.output_dim,
|
||||||
|
batch_average=False,
|
||||||
|
reduction='mean')
|
||||||
|
|
||||||
|
def forward(self, wav, wavs_lens_rate, target, target_lens):
|
||||||
|
if self.normalize_wav:
|
||||||
|
wav = F.layer_norm(wav, wav.shape)
|
||||||
|
|
||||||
|
# Extract wav2vec output
|
||||||
|
out = self.wavlm(wav)
|
||||||
|
# We normalize the output if required
|
||||||
|
if self.output_norm:
|
||||||
|
out = F.layer_norm(out, out.shape)
|
||||||
|
|
||||||
|
if self.training and hasattr(self.config, 'spec_augment'):
|
||||||
|
feats = self.spec_augment(out)
|
||||||
|
else:
|
||||||
|
feats = out
|
||||||
|
|
||||||
|
x = self.enc(feats)
|
||||||
|
# x = feats
|
||||||
|
|
||||||
|
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
|
||||||
|
target_lens = target_lens.astype(paddle.int64)
|
||||||
|
# target = target.astype(paddle.int32)
|
||||||
|
ctc_loss = self.ctc(x, x_lens, target, target_lens)
|
||||||
|
|
||||||
|
return ctc_loss
|
||||||
|
|
||||||
|
@paddle.no_grad()
|
||||||
|
def decode(self,
|
||||||
|
feats: paddle.Tensor,
|
||||||
|
text_feature: Dict[str, int],
|
||||||
|
decoding_method: str,
|
||||||
|
beam_size: int,
|
||||||
|
tokenizer: str=None,
|
||||||
|
sb_pipeline=False):
|
||||||
|
batch_size = feats.shape[0]
|
||||||
|
|
||||||
|
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
|
||||||
|
print(
|
||||||
|
f"decoding mode {decoding_method} must be running with batch_size == 1"
|
||||||
|
)
|
||||||
|
print(f"current batch_size is {batch_size}")
|
||||||
|
|
||||||
|
if decoding_method == 'ctc_greedy_search':
|
||||||
|
if tokenizer is None and sb_pipeline is False:
|
||||||
|
hyps = self.ctc_greedy_search(feats)
|
||||||
|
res = [text_feature.defeaturize(hyp) for hyp in hyps]
|
||||||
|
res_tokenids = [hyp for hyp in hyps]
|
||||||
|
else:
|
||||||
|
if sb_pipeline is True:
|
||||||
|
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
|
||||||
|
else:
|
||||||
|
hyps = self.ctc_greedy_search(feats)
|
||||||
|
res = []
|
||||||
|
res_tokenids = []
|
||||||
|
for sequence in hyps:
|
||||||
|
# Decode token terms to words
|
||||||
|
predicted_tokens = text_feature.convert_ids_to_tokens(
|
||||||
|
sequence)
|
||||||
|
tmp_res = []
|
||||||
|
tmp_res_tokenids = []
|
||||||
|
for c in predicted_tokens:
|
||||||
|
if c == "[CLS]":
|
||||||
|
continue
|
||||||
|
elif c == "[SEP]" or c == "[PAD]":
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
tmp_res.append(c)
|
||||||
|
tmp_res_tokenids.append(text_feature.vocab[c])
|
||||||
|
res.append(''.join(tmp_res))
|
||||||
|
res_tokenids.append(tmp_res_tokenids)
|
||||||
|
|
||||||
|
# ctc_prefix_beam_search and attention_rescoring only return one
|
||||||
|
# result in List[int], change it to List[List[int]] for compatible
|
||||||
|
# with other batch decoding mode
|
||||||
|
elif decoding_method == 'ctc_prefix_beam_search':
|
||||||
|
assert feats.shape[0] == 1
|
||||||
|
if tokenizer is None and sb_pipeline is False:
|
||||||
|
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
||||||
|
res = [text_feature.defeaturize(hyp)]
|
||||||
|
res_tokenids = [hyp]
|
||||||
|
else:
|
||||||
|
if sb_pipeline is True:
|
||||||
|
hyp = self.ctc_prefix_beam_search(
|
||||||
|
feats.unsqueeze(-1), beam_size)
|
||||||
|
else:
|
||||||
|
hyp = self.ctc_prefix_beam_search(feats, beam_size)
|
||||||
|
res = []
|
||||||
|
res_tokenids = []
|
||||||
|
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
|
||||||
|
tmp_res = []
|
||||||
|
tmp_res_tokenids = []
|
||||||
|
for c in predicted_tokens:
|
||||||
|
if c == "[CLS]":
|
||||||
|
continue
|
||||||
|
elif c == "[SEP]" or c == "[PAD]":
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
tmp_res.append(c)
|
||||||
|
tmp_res_tokenids.append(text_feature.vocab[c])
|
||||||
|
res.append(''.join(tmp_res))
|
||||||
|
res_tokenids.append(tmp_res_tokenids)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"WavLM not support decoding method: {decoding_method}")
|
||||||
|
|
||||||
|
return res, res_tokenids
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
model = cls(config)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def ctc_greedy_search(self, wav) -> List[List[int]]:
|
||||||
|
""" Apply CTC greedy search
|
||||||
|
Args:
|
||||||
|
speech (paddle.Tensor): (batch, max_len)
|
||||||
|
speech_length (paddle.Tensor): (batch, )
|
||||||
|
Returns:
|
||||||
|
List[List[int]]: best path result
|
||||||
|
"""
|
||||||
|
batch_size = wav.shape[0]
|
||||||
|
wav = wav[:, :, 0]
|
||||||
|
if self.normalize_wav:
|
||||||
|
wav = F.layer_norm(wav, wav.shape[1:])
|
||||||
|
# Extract wavlm output
|
||||||
|
out = self.wavlm(wav)
|
||||||
|
# We normalize the output if required
|
||||||
|
if self.output_norm:
|
||||||
|
out = F.layer_norm(out, out.shape[1:])
|
||||||
|
feats = out
|
||||||
|
x = self.enc(feats)
|
||||||
|
x_lens = x.shape[1]
|
||||||
|
ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
|
||||||
|
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
|
||||||
|
topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen)
|
||||||
|
|
||||||
|
hyps = [hyp.tolist() for hyp in topk_index]
|
||||||
|
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
def _ctc_prefix_beam_search(
|
||||||
|
self,
|
||||||
|
wav,
|
||||||
|
beam_size,
|
||||||
|
blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
|
||||||
|
""" CTC prefix beam search inner implementation
|
||||||
|
Args:
|
||||||
|
speech (paddle.Tensor): (batch, max_len, feat_dim)
|
||||||
|
speech_length (paddle.Tensor): (batch, )
|
||||||
|
beam_size (int): beam size for beam search
|
||||||
|
decoding_chunk_size (int): decoding chunk for dynamic chunk
|
||||||
|
trained model.
|
||||||
|
<0: for decoding, use full chunk.
|
||||||
|
>0: for decoding, use fixed chunk size as set.
|
||||||
|
0: used for training, it's prohibited here
|
||||||
|
simulate_streaming (bool): whether do encoder forward in a
|
||||||
|
streaming fashion
|
||||||
|
Returns:
|
||||||
|
List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
|
||||||
|
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
|
||||||
|
it will be used for rescoring in attention rescoring mode
|
||||||
|
"""
|
||||||
|
wav = wav[:, :, 0]
|
||||||
|
|
||||||
|
if self.normalize_wav:
|
||||||
|
wav = F.layer_norm(wav, wav.shape[1:])
|
||||||
|
# Extract wavlm output
|
||||||
|
out = self.wavlm(wav)
|
||||||
|
# We normalize the output if required
|
||||||
|
if self.output_norm:
|
||||||
|
out = F.layer_norm(out, out.shape[1:])
|
||||||
|
feats = out
|
||||||
|
|
||||||
|
x = self.enc(feats)
|
||||||
|
maxlen = x.shape[1]
|
||||||
|
ctc_probs = self.ctc.log_softmax(x) # (1, maxlen, vocab_size)
|
||||||
|
ctc_probs = ctc_probs.squeeze(0)
|
||||||
|
|
||||||
|
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
|
||||||
|
# blank_ending_score and none_blank_ending_score in ln domain
|
||||||
|
cur_hyps = [(tuple(), (0.0, -float('inf')))]
|
||||||
|
# 2. CTC beam search step by step
|
||||||
|
for t in range(0, maxlen):
|
||||||
|
logp = ctc_probs[t] # (vocab_size,)
|
||||||
|
# key: prefix, value (pb, pnb), default value(-inf, -inf)
|
||||||
|
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
|
||||||
|
# 2.1 First beam prune: select topk best
|
||||||
|
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
|
||||||
|
for s in top_k_index:
|
||||||
|
s = s.item()
|
||||||
|
ps = logp[s].item()
|
||||||
|
for prefix, (pb, pnb) in cur_hyps:
|
||||||
|
last = prefix[-1] if len(prefix) > 0 else None
|
||||||
|
if s == blank_id: # blank
|
||||||
|
n_pb, n_pnb = next_hyps[prefix]
|
||||||
|
n_pb = log_add([n_pb, pb + ps, pnb + ps])
|
||||||
|
next_hyps[prefix] = (n_pb, n_pnb)
|
||||||
|
elif s == last:
|
||||||
|
# Update *ss -> *s;
|
||||||
|
n_pb, n_pnb = next_hyps[prefix]
|
||||||
|
n_pnb = log_add([n_pnb, pnb + ps])
|
||||||
|
next_hyps[prefix] = (n_pb, n_pnb)
|
||||||
|
# Update *s-s -> *ss, - is for blank
|
||||||
|
n_prefix = prefix + (s, )
|
||||||
|
n_pb, n_pnb = next_hyps[n_prefix]
|
||||||
|
n_pnb = log_add([n_pnb, pb + ps])
|
||||||
|
next_hyps[n_prefix] = (n_pb, n_pnb)
|
||||||
|
else:
|
||||||
|
n_prefix = prefix + (s, )
|
||||||
|
n_pb, n_pnb = next_hyps[n_prefix]
|
||||||
|
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
|
||||||
|
next_hyps[n_prefix] = (n_pb, n_pnb)
|
||||||
|
|
||||||
|
# 2.2 Second beam prune
|
||||||
|
next_hyps = sorted(
|
||||||
|
next_hyps.items(),
|
||||||
|
key=lambda x: log_add(list(x[1])),
|
||||||
|
reverse=True)
|
||||||
|
cur_hyps = next_hyps[:beam_size]
|
||||||
|
|
||||||
|
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
def ctc_prefix_beam_search(self, wav, beam_size) -> List[int]:
|
||||||
|
""" Apply CTC prefix beam search
|
||||||
|
Args:
|
||||||
|
speech (paddle.Tensor): (batch, max_len, feat_dim)
|
||||||
|
speech_length (paddle.Tensor): (batch, )
|
||||||
|
beam_size (int): beam size for beam search
|
||||||
|
decoding_chunk_size (int): decoding chunk for dynamic chunk
|
||||||
|
trained model.
|
||||||
|
<0: for decoding, use full chunk.
|
||||||
|
>0: for decoding, use fixed chunk size as set.
|
||||||
|
0: used for training, it's prohibited here
|
||||||
|
simulate_streaming (bool): whether do encoder forward in a
|
||||||
|
streaming fashion
|
||||||
|
Returns:
|
||||||
|
List[int]: CTC prefix beam search nbest results
|
||||||
|
"""
|
||||||
|
hyps = self._ctc_prefix_beam_search(wav, beam_size)
|
||||||
|
return hyps[0][0]
|
||||||
|
|
||||||
|
|
||||||
|
class WavLMBase(nn.Layer):
|
||||||
|
"""WavLM model"""
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
super().__init__()
|
||||||
|
wavlm_config = WavLMConfig(config)
|
||||||
|
wavlm = WavLM(wavlm_config)
|
||||||
|
self.wavlm = wavlm
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, configs: dict):
|
||||||
|
"""init model.
|
||||||
|
Args:
|
||||||
|
configs (dict): config dict.
|
||||||
|
Raises:
|
||||||
|
ValueError: raise when using not support encoder type.
|
||||||
|
Returns:
|
||||||
|
nn.Layer: WavLMBase
|
||||||
|
"""
|
||||||
|
model = cls(configs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(self, wav):
|
||||||
|
out = self.wavlm(wav)
|
||||||
|
return out
|
@ -0,0 +1,757 @@
|
|||||||
|
# --------------------------------------------------------
|
||||||
|
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle.nn import LayerNorm
|
||||||
|
from paddle import Tensor
|
||||||
|
from .modules.modules import (
|
||||||
|
MultiheadAttention,
|
||||||
|
SamePad,
|
||||||
|
init_bert_params,
|
||||||
|
get_activation_fn,
|
||||||
|
TransposeLast,
|
||||||
|
GLU_Linear,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_mask_indices(
|
||||||
|
shape: Tuple[int, int],
|
||||||
|
padding_mask: Optional[Tensor],
|
||||||
|
mask_prob: float,
|
||||||
|
mask_length: int,
|
||||||
|
mask_type: str = "static",
|
||||||
|
mask_other: float = 0.0,
|
||||||
|
min_masks: int = 0,
|
||||||
|
no_overlap: bool = False,
|
||||||
|
min_space: int = 0,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Computes random mask spans for a given shape
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: the the shape for which to compute masks.
|
||||||
|
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||||
|
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||||
|
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||||
|
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||||
|
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||||
|
mask_type: how to compute mask lengths
|
||||||
|
static = fixed size
|
||||||
|
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
||||||
|
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
||||||
|
poisson = sample from possion distribution with lambda = mask length
|
||||||
|
min_masks: minimum number of masked spans
|
||||||
|
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
||||||
|
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
||||||
|
"""
|
||||||
|
|
||||||
|
bsz, all_sz = shape
|
||||||
|
mask = np.full((bsz, all_sz), False)
|
||||||
|
|
||||||
|
all_num_mask = int(
|
||||||
|
# add a random number for probabilistic rounding
|
||||||
|
mask_prob * all_sz / float(mask_length)
|
||||||
|
+ np.random.rand()
|
||||||
|
)
|
||||||
|
|
||||||
|
all_num_mask = max(min_masks, all_num_mask)
|
||||||
|
|
||||||
|
mask_idcs = []
|
||||||
|
for i in range(bsz):
|
||||||
|
if padding_mask is not None:
|
||||||
|
sz = all_sz - padding_mask[i].long().sum().item()
|
||||||
|
num_mask = int(
|
||||||
|
# add a random number for probabilistic rounding
|
||||||
|
mask_prob * sz / float(mask_length)
|
||||||
|
+ np.random.rand()
|
||||||
|
)
|
||||||
|
num_mask = max(min_masks, num_mask)
|
||||||
|
else:
|
||||||
|
sz = all_sz
|
||||||
|
num_mask = all_num_mask
|
||||||
|
|
||||||
|
if mask_type == "static":
|
||||||
|
lengths = np.full(num_mask, mask_length)
|
||||||
|
elif mask_type == "uniform":
|
||||||
|
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
||||||
|
elif mask_type == "normal":
|
||||||
|
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
||||||
|
lengths = [max(1, int(round(x))) for x in lengths]
|
||||||
|
elif mask_type == "poisson":
|
||||||
|
lengths = np.random.poisson(mask_length, size=num_mask)
|
||||||
|
lengths = [int(round(x)) for x in lengths]
|
||||||
|
else:
|
||||||
|
raise Exception("unknown mask selection " + mask_type)
|
||||||
|
|
||||||
|
if sum(lengths) == 0:
|
||||||
|
lengths[0] = min(mask_length, sz - 1)
|
||||||
|
|
||||||
|
if no_overlap:
|
||||||
|
mask_idc = []
|
||||||
|
|
||||||
|
def arrange(s, e, length, keep_length):
|
||||||
|
span_start = np.random.randint(s, e - length)
|
||||||
|
mask_idc.extend(span_start + i for i in range(length))
|
||||||
|
|
||||||
|
new_parts = []
|
||||||
|
if span_start - s - min_space >= keep_length:
|
||||||
|
new_parts.append((s, span_start - min_space + 1))
|
||||||
|
if e - span_start - keep_length - min_space > keep_length:
|
||||||
|
new_parts.append((span_start + length + min_space, e))
|
||||||
|
return new_parts
|
||||||
|
|
||||||
|
parts = [(0, sz)]
|
||||||
|
min_length = min(lengths)
|
||||||
|
for length in sorted(lengths, reverse=True):
|
||||||
|
lens = np.fromiter(
|
||||||
|
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
||||||
|
np.int,
|
||||||
|
)
|
||||||
|
l_sum = np.sum(lens)
|
||||||
|
if l_sum == 0:
|
||||||
|
break
|
||||||
|
probs = lens / np.sum(lens)
|
||||||
|
c = np.random.choice(len(parts), p=probs)
|
||||||
|
s, e = parts.pop(c)
|
||||||
|
parts.extend(arrange(s, e, length, min_length))
|
||||||
|
mask_idc = np.asarray(mask_idc)
|
||||||
|
else:
|
||||||
|
min_len = min(lengths)
|
||||||
|
if sz - min_len <= num_mask:
|
||||||
|
min_len = sz - num_mask - 1
|
||||||
|
|
||||||
|
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||||
|
|
||||||
|
mask_idc = np.asarray(
|
||||||
|
[
|
||||||
|
mask_idc[j] + offset
|
||||||
|
for j in range(len(mask_idc))
|
||||||
|
for offset in range(lengths[j])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||||
|
|
||||||
|
min_len = min([len(m) for m in mask_idcs])
|
||||||
|
for i, mask_idc in enumerate(mask_idcs):
|
||||||
|
if len(mask_idc) > min_len:
|
||||||
|
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||||
|
mask[i, mask_idc] = True
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
class WavLMConfig:
|
||||||
|
def __init__(self, cfg=None):
|
||||||
|
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
||||||
|
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||||
|
|
||||||
|
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||||
|
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||||
|
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||||
|
self.activation_fn: str = "gelu" # activation function to use
|
||||||
|
|
||||||
|
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||||
|
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
||||||
|
self.conv_bias: bool = False # include bias in conv encoder
|
||||||
|
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
||||||
|
|
||||||
|
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
||||||
|
|
||||||
|
# dropouts
|
||||||
|
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||||
|
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||||
|
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||||
|
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
||||||
|
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||||
|
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
||||||
|
|
||||||
|
# masking
|
||||||
|
self.mask_length: int = 10 # mask length
|
||||||
|
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
||||||
|
self.mask_selection: str = "static" # how to choose mask length
|
||||||
|
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
|
||||||
|
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
||||||
|
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
||||||
|
|
||||||
|
# channel masking
|
||||||
|
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
||||||
|
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
||||||
|
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
||||||
|
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
||||||
|
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
||||||
|
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||||
|
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||||
|
|
||||||
|
# relative position embedding
|
||||||
|
self.relative_position_embedding: bool = True # apply relative position embedding
|
||||||
|
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||||
|
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||||
|
self.gru_rel_pos: bool = True # apply gated relative position embedding
|
||||||
|
|
||||||
|
if cfg is not None:
|
||||||
|
self.update(cfg)
|
||||||
|
|
||||||
|
def update(self, cfg: dict):
|
||||||
|
self.__dict__.update(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class WavLM(nn.Layer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: WavLMConfig,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
logger.info(f"WavLM Config: {cfg.__dict__}")
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
feature_enc_layers = eval(cfg.conv_feature_layers)
|
||||||
|
self.embed = feature_enc_layers[-1][0]
|
||||||
|
|
||||||
|
self.feature_extractor = ConvFeatureExtractionModel(
|
||||||
|
conv_layers=feature_enc_layers,
|
||||||
|
dropout=0.0,
|
||||||
|
mode=cfg.extractor_mode,
|
||||||
|
conv_bias=cfg.conv_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.post_extract_proj = (
|
||||||
|
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||||
|
if self.embed != cfg.encoder_embed_dim
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mask_prob = cfg.mask_prob
|
||||||
|
self.mask_selection = cfg.mask_selection
|
||||||
|
self.mask_other = cfg.mask_other
|
||||||
|
self.mask_length = cfg.mask_length
|
||||||
|
self.no_mask_overlap = cfg.no_mask_overlap
|
||||||
|
self.mask_min_space = cfg.mask_min_space
|
||||||
|
|
||||||
|
self.mask_channel_prob = cfg.mask_channel_prob
|
||||||
|
self.mask_channel_selection = cfg.mask_channel_selection
|
||||||
|
self.mask_channel_other = cfg.mask_channel_other
|
||||||
|
self.mask_channel_length = cfg.mask_channel_length
|
||||||
|
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
||||||
|
self.mask_channel_min_space = cfg.mask_channel_min_space
|
||||||
|
|
||||||
|
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||||
|
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
||||||
|
|
||||||
|
self.feature_grad_mult = cfg.feature_grad_mult
|
||||||
|
|
||||||
|
self.mask_emb = self.create_parameter(
|
||||||
|
shape=[cfg.encoder_embed_dim],
|
||||||
|
default_initializer=nn.initializer.Uniform(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoder = TransformerEncoder(cfg)
|
||||||
|
self.layer_norm = LayerNorm(self.embed)
|
||||||
|
|
||||||
|
def apply_mask(self, x, padding_mask):
|
||||||
|
B, T, C = x.shape
|
||||||
|
if self.mask_prob > 0:
|
||||||
|
mask_indices = compute_mask_indices(
|
||||||
|
(B, T),
|
||||||
|
padding_mask,
|
||||||
|
self.mask_prob,
|
||||||
|
self.mask_length,
|
||||||
|
self.mask_selection,
|
||||||
|
self.mask_other,
|
||||||
|
min_masks=2,
|
||||||
|
no_overlap=self.no_mask_overlap,
|
||||||
|
min_space=self.mask_min_space,
|
||||||
|
)
|
||||||
|
# mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
||||||
|
mask_indices = paddle.to_tensor(mask_indices, dtype='int64')
|
||||||
|
x[mask_indices] = self.mask_emb
|
||||||
|
else:
|
||||||
|
mask_indices = None
|
||||||
|
|
||||||
|
if self.mask_channel_prob > 0:
|
||||||
|
mask_channel_indices = compute_mask_indices(
|
||||||
|
(B, C),
|
||||||
|
None,
|
||||||
|
self.mask_channel_prob,
|
||||||
|
self.mask_channel_length,
|
||||||
|
self.mask_channel_selection,
|
||||||
|
self.mask_channel_other,
|
||||||
|
no_overlap=self.no_mask_channel_overlap,
|
||||||
|
min_space=self.mask_channel_min_space,
|
||||||
|
)
|
||||||
|
mask_channel_indices = (
|
||||||
|
# torch.from_numpy(mask_channel_indices)
|
||||||
|
paddle.to_tensor(mask_channel_indices, dtype='int64')
|
||||||
|
.to(x.device)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(-1, T, -1)
|
||||||
|
)
|
||||||
|
x[mask_channel_indices] = 0
|
||||||
|
|
||||||
|
return x, mask_indices
|
||||||
|
|
||||||
|
def forward_padding_mask(
|
||||||
|
self, features: Tensor, padding_mask: Tensor,
|
||||||
|
) -> Tensor:
|
||||||
|
extra = padding_mask.size(1) % features.size(1)
|
||||||
|
if extra > 0:
|
||||||
|
padding_mask = padding_mask[:, :-extra]
|
||||||
|
padding_mask = padding_mask.view(
|
||||||
|
padding_mask.size(0), features.size(1), -1
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask.all(-1)
|
||||||
|
return padding_mask
|
||||||
|
|
||||||
|
def extract_features(
|
||||||
|
self,
|
||||||
|
source: Tensor,
|
||||||
|
padding_mask: Optional[Tensor] = None,
|
||||||
|
mask: bool = False,
|
||||||
|
ret_conv: bool = False,
|
||||||
|
output_layer: Optional[int] = None,
|
||||||
|
ret_layer_results: bool = False,
|
||||||
|
):
|
||||||
|
|
||||||
|
if self.feature_grad_mult > 0:
|
||||||
|
features = self.feature_extractor(source)
|
||||||
|
# if self.feature_grad_mult != 1.0:
|
||||||
|
# features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||||
|
else:
|
||||||
|
# with torch.no_grad():
|
||||||
|
with paddle.no_grad():
|
||||||
|
features = self.feature_extractor(source)
|
||||||
|
|
||||||
|
features = features.transpose([0, 2, 1]) # [1, 49, 512]
|
||||||
|
features = self.layer_norm(features)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||||
|
|
||||||
|
if self.post_extract_proj is not None:
|
||||||
|
features = self.post_extract_proj(features)
|
||||||
|
# [1, 49, 768]
|
||||||
|
features = self.dropout_input(features)
|
||||||
|
|
||||||
|
if mask:
|
||||||
|
x, mask_indices = self.apply_mask(
|
||||||
|
features, padding_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = features
|
||||||
|
|
||||||
|
# feature: (B, T, D), float
|
||||||
|
# target: (B, T), long
|
||||||
|
# x: (B, T, D), float
|
||||||
|
# padding_mask: (B, T), bool
|
||||||
|
# mask_indices: (B, T), bool
|
||||||
|
|
||||||
|
x, layer_results = self.encoder(
|
||||||
|
x,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
layer=None if output_layer is None else output_layer - 1
|
||||||
|
)
|
||||||
|
# print(f"Debugging: x.shape: {x.shape}, x.mean(): {x.mean()}, x.std(): {x.std()}")
|
||||||
|
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
||||||
|
|
||||||
|
feature = res["features"] if ret_conv else res["x"]
|
||||||
|
if ret_layer_results:
|
||||||
|
feature = (feature, res["layer_results"])
|
||||||
|
return feature, res["padding_mask"]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.extract_features(x)[0]
|
||||||
|
|
||||||
|
|
||||||
|
class ConvFeatureExtractionModel(nn.Layer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
conv_layers: List[Tuple[int, int, int]],
|
||||||
|
dropout: float = 0.0,
|
||||||
|
mode: str = "default",
|
||||||
|
conv_bias: bool = False,
|
||||||
|
conv_type: str = "default"
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert mode in {"default", "layer_norm"}
|
||||||
|
|
||||||
|
def block(
|
||||||
|
n_in,
|
||||||
|
n_out,
|
||||||
|
k,
|
||||||
|
stride,
|
||||||
|
is_layer_norm=False,
|
||||||
|
is_group_norm=False,
|
||||||
|
conv_bias=False,
|
||||||
|
):
|
||||||
|
def make_conv():
|
||||||
|
conv = nn.Conv1D(n_in, n_out, k, stride=stride, bias_attr=conv_bias,
|
||||||
|
weight_attr=nn.initializer.KaimingNormal())
|
||||||
|
# nn.init.kaiming_normal_(conv.weight)
|
||||||
|
return conv
|
||||||
|
|
||||||
|
assert (
|
||||||
|
is_layer_norm and is_group_norm
|
||||||
|
) == False, "layer norm and group norm are exclusive"
|
||||||
|
|
||||||
|
if is_layer_norm:
|
||||||
|
return nn.Sequential(
|
||||||
|
make_conv(),
|
||||||
|
nn.Dropout(p=dropout),
|
||||||
|
nn.Sequential(
|
||||||
|
TransposeLast(),
|
||||||
|
nn.LayerNorm(normalized_shape=dim, epsilon=1e-5),
|
||||||
|
TransposeLast(),
|
||||||
|
),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
elif is_group_norm:
|
||||||
|
return nn.Sequential(
|
||||||
|
make_conv(),
|
||||||
|
nn.Dropout(p=dropout),
|
||||||
|
nn.GroupNorm(num_groups=dim, num_channels=dim, epsilon=1e-5),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
||||||
|
|
||||||
|
self.conv_type = conv_type
|
||||||
|
if self.conv_type == "default":
|
||||||
|
in_d = 1
|
||||||
|
self.conv_layers = nn.LayerList()
|
||||||
|
for i, cl in enumerate(conv_layers):
|
||||||
|
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
||||||
|
(dim, k, stride) = cl
|
||||||
|
|
||||||
|
self.conv_layers.append(
|
||||||
|
block(
|
||||||
|
in_d,
|
||||||
|
dim,
|
||||||
|
k,
|
||||||
|
stride,
|
||||||
|
is_layer_norm=mode == "layer_norm",
|
||||||
|
is_group_norm=mode == "default" and i == 0,
|
||||||
|
conv_bias=conv_bias,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
in_d = dim
|
||||||
|
elif self.conv_type == "conv2d":
|
||||||
|
in_d = 1
|
||||||
|
self.conv_layers = nn.LayerList()
|
||||||
|
for i, cl in enumerate(conv_layers):
|
||||||
|
assert len(cl) == 3
|
||||||
|
(dim, k, stride) = cl
|
||||||
|
|
||||||
|
self.conv_layers.append(
|
||||||
|
paddle.nn.Conv2D(in_d, dim, k, stride)
|
||||||
|
)
|
||||||
|
self.conv_layers.append(paddle.nn.ReLU())
|
||||||
|
in_d = dim
|
||||||
|
elif self.conv_type == "custom":
|
||||||
|
in_d = 1
|
||||||
|
idim = 80
|
||||||
|
self.conv_layers = nn.LayerList()
|
||||||
|
for i, cl in enumerate(conv_layers):
|
||||||
|
assert len(cl) == 3
|
||||||
|
(dim, k, stride) = cl
|
||||||
|
self.conv_layers.append(
|
||||||
|
paddle.nn.Conv2D(in_d, dim, k, stride, padding=1)
|
||||||
|
)
|
||||||
|
self.conv_layers.append(
|
||||||
|
paddle.nn.LayerNorm([dim, idim])
|
||||||
|
)
|
||||||
|
self.conv_layers.append(paddle.nn.ReLU())
|
||||||
|
in_d = dim
|
||||||
|
if (i + 1) % 2 == 0:
|
||||||
|
self.conv_layers.append(
|
||||||
|
paddle.nn.MaxPool2D(2, stride=2, ceil_mode=True)
|
||||||
|
)
|
||||||
|
idim = int(math.ceil(idim / 2))
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
|
||||||
|
# BxT -> BxCxT
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
if self.conv_type == "custom":
|
||||||
|
for conv in self.conv_layers:
|
||||||
|
if isinstance(conv, nn.LayerNorm):
|
||||||
|
x = x.transpose([0, 2, 1])
|
||||||
|
x = conv(x).transpose([0, 2, 1])
|
||||||
|
else:
|
||||||
|
x = conv(x)
|
||||||
|
x = x.transpose([0, 1, 3, 2]).contiguous()
|
||||||
|
x = x.view(x.size(0), -1, x.size(-1))
|
||||||
|
else:
|
||||||
|
for conv in self.conv_layers:
|
||||||
|
x = conv(x)
|
||||||
|
if self.conv_type == "conv2d":
|
||||||
|
b, c, t, f = x.size()
|
||||||
|
# x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
||||||
|
x = x.transpose([0, 1, 3, 2]).contiguous().view(b, c * f, t)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Layer):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dropout = args.dropout
|
||||||
|
self.embedding_dim = args.encoder_embed_dim
|
||||||
|
dropout = 0
|
||||||
|
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
||||||
|
|
||||||
|
|
||||||
|
self.pos_conv = nn.Conv1D(
|
||||||
|
self.embedding_dim,
|
||||||
|
self.embedding_dim,
|
||||||
|
kernel_size=args.conv_pos,
|
||||||
|
padding=args.conv_pos // 2,
|
||||||
|
groups=args.conv_pos_groups,
|
||||||
|
weight_attr=nn.initializer.Normal(mean=0, std=std),
|
||||||
|
bias_attr=True
|
||||||
|
)
|
||||||
|
# nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||||
|
# nn.init.constant_(self.pos_conv.bias, 0)
|
||||||
|
|
||||||
|
# self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||||
|
# self.pos_conv.weight_g = self.pos_conv.weight_g.unsqueeze(0).unsqueeze(0)
|
||||||
|
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||||
|
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||||
|
|
||||||
|
if hasattr(args, "relative_position_embedding"):
|
||||||
|
self.relative_position_embedding = args.relative_position_embedding
|
||||||
|
self.num_buckets = args.num_buckets
|
||||||
|
self.max_distance = args.max_distance
|
||||||
|
else:
|
||||||
|
self.relative_position_embedding = False
|
||||||
|
self.num_buckets = 0
|
||||||
|
self.max_distance = 0
|
||||||
|
|
||||||
|
self.layers = nn.LayerList(
|
||||||
|
[
|
||||||
|
TransformerSentenceEncoderLayer(
|
||||||
|
embedding_dim=self.embedding_dim,
|
||||||
|
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||||
|
num_attention_heads=args.encoder_attention_heads,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=args.attention_dropout,
|
||||||
|
activation_dropout=args.activation_dropout,
|
||||||
|
activation_fn=args.activation_fn,
|
||||||
|
layer_norm_first=args.layer_norm_first,
|
||||||
|
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
|
||||||
|
num_buckets=self.num_buckets,
|
||||||
|
max_distance=self.max_distance,
|
||||||
|
gru_rel_pos=args.gru_rel_pos,
|
||||||
|
)
|
||||||
|
for i in range(args.encoder_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layer_norm_first = args.layer_norm_first
|
||||||
|
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
self.layerdrop = args.encoder_layerdrop
|
||||||
|
|
||||||
|
# self.apply(init_bert_params)
|
||||||
|
|
||||||
|
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
||||||
|
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
||||||
|
# print("x.shape", x.shape)
|
||||||
|
if self.layer_norm_first and layer is None:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
return x, layer_results
|
||||||
|
|
||||||
|
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
x[padding_mask] = 0
|
||||||
|
|
||||||
|
x_conv = self.pos_conv(x.transpose([0, 2, 1]))
|
||||||
|
x_conv = x_conv.transpose([0, 2, 1])
|
||||||
|
x += x_conv
|
||||||
|
if not self.layer_norm_first:
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
# B x T x C -> T x B x C
|
||||||
|
# x = x.transpose(0, 1)
|
||||||
|
x = x.transpose([1, 0, 2])
|
||||||
|
|
||||||
|
|
||||||
|
layer_results = []
|
||||||
|
z = None
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
r = None
|
||||||
|
pos_bias = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
dropout_probability = np.random.random()
|
||||||
|
if not self.training or (dropout_probability > self.layerdrop):
|
||||||
|
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,self_attn_mask=streaming_mask, pos_bias=pos_bias)
|
||||||
|
if tgt_layer is not None:
|
||||||
|
layer_results.append((x, z))
|
||||||
|
if i == tgt_layer:
|
||||||
|
r = x
|
||||||
|
break
|
||||||
|
|
||||||
|
if r is not None:
|
||||||
|
x = r
|
||||||
|
|
||||||
|
# T x B x C -> B x T x C
|
||||||
|
# x = x.transpose(0, 1)
|
||||||
|
x = x.transpose([1, 0, 2])
|
||||||
|
|
||||||
|
return x, layer_results
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerSentenceEncoderLayer(nn.Layer):
|
||||||
|
"""
|
||||||
|
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: float = 768,
|
||||||
|
ffn_embedding_dim: float = 3072,
|
||||||
|
num_attention_heads: float = 8,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
attention_dropout: float = 0.1,
|
||||||
|
activation_dropout: float = 0.1,
|
||||||
|
activation_fn: str = "relu",
|
||||||
|
layer_norm_first: bool = False,
|
||||||
|
has_relative_attention_bias: bool = True,
|
||||||
|
num_buckets: int = 0,
|
||||||
|
max_distance: int = 0,
|
||||||
|
rescale_init: bool = False,
|
||||||
|
gru_rel_pos: bool = True,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
# Initialize parameters
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
|
||||||
|
# Initialize blocks
|
||||||
|
self.activation_name = activation_fn
|
||||||
|
self.activation_fn = get_activation_fn(activation_fn)
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
self.embedding_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
dropout=attention_dropout,
|
||||||
|
self_attention=True,
|
||||||
|
has_relative_attention_bias=has_relative_attention_bias,
|
||||||
|
num_buckets=num_buckets,
|
||||||
|
max_distance=max_distance,
|
||||||
|
rescale_init=rescale_init,
|
||||||
|
gru_rel_pos=gru_rel_pos,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||||
|
self.dropout3 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.layer_norm_first = layer_norm_first
|
||||||
|
|
||||||
|
# layer norm associated with the self attention layer
|
||||||
|
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
||||||
|
else:
|
||||||
|
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||||
|
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||||
|
|
||||||
|
# layer norm associated with the position wise feed-forward NN
|
||||||
|
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
self_attn_mask: Tensor = None,
|
||||||
|
self_attn_padding_mask: Tensor = None,
|
||||||
|
need_weights: bool = False,
|
||||||
|
pos_bias=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
LayerNorm is applied either before or after the self-attention/ffn
|
||||||
|
modules similar to the original Transformer imlementation.
|
||||||
|
"""
|
||||||
|
residual = x
|
||||||
|
if self.layer_norm_first:
|
||||||
|
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=False,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual + x
|
||||||
|
else:
|
||||||
|
x, attn, pos_bias = self.self_attn(
|
||||||
|
query=x,
|
||||||
|
key=x,
|
||||||
|
value=x,
|
||||||
|
key_padding_mask=self_attn_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=self_attn_mask,
|
||||||
|
position_bias=pos_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
x = self.self_attn_layer_norm(x)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.activation_name == "glu":
|
||||||
|
x = self.fc1(x)
|
||||||
|
else:
|
||||||
|
x = self.activation_fn(self.fc1(x))
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.dropout3(x)
|
||||||
|
x = residual + x
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
|
||||||
|
return x, attn, pos_bias
|
Loading…
Reference in new issue