add wav2vev2_zh aishell recipe, and speechbrain dataloader. (#2916)
parent
66a9cf8ebc
commit
047092de8e
@ -0,0 +1,3 @@
|
||||
process:
|
||||
# use raw audio
|
||||
- type: wav_process
|
@ -0,0 +1,101 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml)
|
||||
|
||||
# ############################################################################
|
||||
# Model: CTC-wav2vec2
|
||||
# Encoder: wav2vec2
|
||||
# Decoder: -
|
||||
# Tokens: Char
|
||||
# losses: CTC
|
||||
# Training: AISHELL-1
|
||||
# Authors: Yingzhi WANG 2022
|
||||
# ############################################################################
|
||||
|
||||
output_folder: !ref data
|
||||
cer_file: !ref <output_folder>/cer.txt
|
||||
save_folder: !ref <output_folder>/save
|
||||
train_log: !ref <output_folder>/train_log.txt
|
||||
|
||||
# Data files
|
||||
data_folder: data/aishell # e,g./path/to/aishell
|
||||
|
||||
skip_prep: False
|
||||
ckpt_interval_minutes: 15 # save checkpoint every N min
|
||||
train_data: !ref <output_folder>/train.csv
|
||||
valid_data: !ref <output_folder>/dev.csv
|
||||
test_data: !ref <output_folder>/test.csv
|
||||
|
||||
wav2vec2_hub: TencentGameMate/chinese-wav2vec2-large
|
||||
|
||||
# Training parameters
|
||||
number_of_epochs: 80
|
||||
lr: 1.0
|
||||
lr_wav2vec: 0.0001
|
||||
sorting: ascending
|
||||
auto_mix_prec: False
|
||||
sample_rate: 16000
|
||||
|
||||
# With data_parallel batch_size is split into N jobs
|
||||
# With DDP batch_size is multiplied by N jobs
|
||||
# Must be 8 per GPU to fit 32GB of VRAM
|
||||
batch_size: 5
|
||||
test_batch_size: 1 # need set to 1 when decoding
|
||||
|
||||
dynamic_batching: False
|
||||
dynamic_batch_sampler:
|
||||
feats_hop_size: 0.01
|
||||
max_batch_len: 15 # in terms of "duration" in annotations by default, second here
|
||||
left_bucket_len: 200 # old implementation attributs
|
||||
multiplier: 1.1 # old implementation attributs
|
||||
shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
|
||||
num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1
|
||||
batch_ordering: ascending
|
||||
|
||||
num_workers: 6
|
||||
|
||||
# Dataloader options
|
||||
train_dataloader_opts:
|
||||
batch_size: !ref <batch_size>
|
||||
num_workers: !ref <num_workers>
|
||||
valid_dataloader_opts:
|
||||
batch_size: !ref <test_batch_size>
|
||||
num_workers: !ref <num_workers>
|
||||
test_dataloader_opts:
|
||||
batch_size: !ref <test_batch_size>
|
||||
num_workers: !ref <num_workers>
|
||||
|
||||
wav2vec_output_dim: 1024
|
||||
dnn_neurons: 1024
|
||||
freeze_wav2vec: False
|
||||
dropout: 0.15
|
||||
|
||||
tokenizer: !apply:transformers.BertTokenizer.from_pretrained
|
||||
pretrained_model_name_or_path: bert-base-chinese
|
||||
# bert-base-chinese tokens length
|
||||
output_neurons: 21128
|
||||
|
||||
# Decoding parameters
|
||||
# Be sure that the bos and eos index match with the BPEs ones
|
||||
blank_index: 0
|
||||
|
||||
# AISHELL-1 has spaces between words in the transcripts,
|
||||
# which Chinese writing normally does not do.
|
||||
# If remove_spaces, spaces are removed
|
||||
# from the transcript before computing CER.
|
||||
# (e.g., 祝 可爱 的 你 —> 祝可爱的你)
|
||||
remove_spaces: True
|
||||
split_tokens: !apply:operator.not_ [!ref <remove_spaces>]
|
@ -0,0 +1,4 @@
|
||||
decode_batch_size: 1
|
||||
error_rate_type: cer
|
||||
decoding_method: ctc_greedy_search # 'ctc_greedy_search', 'ctc_prefix_beam_search'
|
||||
beam_size: 10
|
@ -0,0 +1,167 @@
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
freeze_wav2vec2: False
|
||||
normalize_wav: True
|
||||
output_norm: True
|
||||
init_type: 'kaiming_uniform' # !Warning: need to convergence
|
||||
enc:
|
||||
input_shape: 1024
|
||||
dnn_blocks: 3
|
||||
dnn_neurons: 1024
|
||||
activation: True
|
||||
normalization: True
|
||||
dropout_rate: [0.15, 0.15, 0.0]
|
||||
ctc:
|
||||
enc_n_units: 1024
|
||||
blank_id: 0
|
||||
dropout_rate: 0.0
|
||||
|
||||
audio_augment:
|
||||
speeds: [90, 100, 110]
|
||||
|
||||
spec_augment:
|
||||
time_warp: True
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
freq_mask: True
|
||||
n_freq_mask: 2
|
||||
time_mask: True
|
||||
n_time_mask: 2
|
||||
replace_with_zero: False
|
||||
freq_mask_width: 30
|
||||
time_mask_width: 40
|
||||
wav2vec2_params_path: exp/wav2vec2/chinese-wav2vec2-large.pdparams
|
||||
|
||||
|
||||
############################################
|
||||
# Wav2Vec2.0 #
|
||||
############################################
|
||||
# vocab_size: 1000000
|
||||
hidden_size: 1024
|
||||
num_hidden_layers: 24
|
||||
num_attention_heads: 16
|
||||
intermediate_size: 4096
|
||||
hidden_act: gelu
|
||||
hidden_dropout: 0.1
|
||||
activation_dropout: 0.0
|
||||
attention_dropout: 0.1
|
||||
feat_proj_dropout: 0.1
|
||||
feat_quantizer_dropout: 0.0
|
||||
final_dropout: 0.0
|
||||
layerdrop: 0.1
|
||||
initializer_range: 0.02
|
||||
layer_norm_eps: 1e-5
|
||||
feat_extract_norm: layer
|
||||
feat_extract_activation: gelu
|
||||
conv_dim: [512, 512, 512, 512, 512, 512, 512]
|
||||
conv_stride: [5, 2, 2, 2, 2, 2, 2]
|
||||
conv_kernel: [10, 3, 3, 3, 3, 2, 2]
|
||||
conv_bias: True
|
||||
num_conv_pos_embeddings: 128
|
||||
num_conv_pos_embedding_groups: 16
|
||||
do_stable_layer_norm: True
|
||||
apply_spec_augment: False
|
||||
mask_channel_length: 10
|
||||
mask_channel_min_space: 1
|
||||
mask_channel_other: 0.0
|
||||
mask_channel_prob: 0.0
|
||||
mask_channel_selection: static
|
||||
mask_feature_length: 10
|
||||
mask_feature_min_masks: 0
|
||||
mask_feature_prob: 0.0
|
||||
mask_time_length: 10
|
||||
mask_time_min_masks: 2
|
||||
mask_time_min_space: 1
|
||||
mask_time_other: 0.0
|
||||
mask_time_prob: 0.075
|
||||
mask_time_selection: static
|
||||
num_codevectors_per_group: 320
|
||||
num_codevector_groups: 2
|
||||
contrastive_logits_temperature: 0.1
|
||||
num_negatives: 100
|
||||
codevector_dim: 256
|
||||
proj_codevector_dim: 256
|
||||
diversity_loss_weight: 0.1
|
||||
use_weighted_layer_sum: False
|
||||
# pad_token_id: 0
|
||||
# bos_token_id: 1
|
||||
# eos_token_id: 2
|
||||
add_adapter: False
|
||||
adapter_kernel_size: 3
|
||||
adapter_stride: 2
|
||||
num_adapter_layers: 3
|
||||
output_hidden_size: None
|
||||
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
vocab_filepath: data/lang_char/vocab.txt
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
|
||||
unit_type: 'char'
|
||||
mean_std_filepath:
|
||||
preprocess_config: conf/preprocess.yaml
|
||||
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
|
||||
batch_size: 5 # Different batch_size may cause large differences in results
|
||||
maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
|
||||
maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced
|
||||
minibatches: 0 # for debug
|
||||
batch_count: auto
|
||||
batch_bins: 0
|
||||
batch_frames_in: 0
|
||||
batch_frames_out: 0
|
||||
batch_frames_inout: 0
|
||||
num_workers: 6
|
||||
subsampling_factor: 1
|
||||
num_encs: 1
|
||||
dist_sampler: True
|
||||
shortest_first: True
|
||||
return_lens_rate: True
|
||||
|
||||
###########################################
|
||||
# use speechbrain dataloader #
|
||||
###########################################
|
||||
use_sb_pipeline: True # whether use speechbrain pipeline. Default is True.
|
||||
sb_pipeline_conf: conf/train_with_wav2vec.yaml
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 80
|
||||
accum_grad: 1
|
||||
global_grad_clip: 5.0
|
||||
|
||||
model_optim: adadelta
|
||||
model_optim_conf:
|
||||
lr: 1.0
|
||||
weight_decay: 0.0
|
||||
rho: 0.95
|
||||
epsilon: 1.0e-8
|
||||
|
||||
wav2vec2_optim: adam
|
||||
wav2vec2_optim_conf:
|
||||
lr: 0.0001
|
||||
weight_decay: 0.0
|
||||
|
||||
model_scheduler: newbobscheduler
|
||||
model_scheduler_conf:
|
||||
improvement_threshold: 0.0025
|
||||
annealing_factor: 0.8
|
||||
patient: 0
|
||||
wav2vec2_scheduler: newbobscheduler
|
||||
wav2vec2_scheduler_conf:
|
||||
improvement_threshold: 0.0025
|
||||
annealing_factor: 0.9
|
||||
patient: 0
|
||||
log_interval: 1
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
@ -0,0 +1,129 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023
|
||||
# (https://github.com/speechbrain/speechbrain/blob/develop/recipes/AISHELL-1/aishell_prepare.py)
|
||||
import argparse
|
||||
import csv
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
|
||||
from paddlespeech.s2t.models.wav2vec2.io.dataio import read_audio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--data_folder",
|
||||
default=DATA_HOME + "/Aishell",
|
||||
type=str,
|
||||
help="Directory to save the dataset. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--save_folder",
|
||||
default="data/",
|
||||
type=str,
|
||||
help="Filepath prefix for output manifests. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--skip_prep",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="If True, skip data preparation. (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def prepare_aishell(data_folder, save_folder, skip_prep=False):
|
||||
"""
|
||||
This function prepares the AISHELL-1 dataset.
|
||||
If the folder does not exist, the zip file will be extracted. If the zip file does not exist, it will be downloaded.
|
||||
data_folder : path to AISHELL-1 dataset.
|
||||
save_folder: path where to store the manifest csv files.
|
||||
skip_prep: If True, skip data preparation.
|
||||
"""
|
||||
if skip_prep:
|
||||
return
|
||||
|
||||
# Create filename-to-transcript dictionary
|
||||
filename2transcript = {}
|
||||
with open(
|
||||
os.path.join(data_folder,
|
||||
"data_aishell/transcript/aishell_transcript_v0.8.txt"),
|
||||
"r", ) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
key = line.split()[0]
|
||||
value = " ".join(line.split()[1:])
|
||||
filename2transcript[key] = value
|
||||
|
||||
splits = [
|
||||
"train",
|
||||
"dev",
|
||||
"test",
|
||||
]
|
||||
ID_start = 0 # needed to have a unique ID for each audio
|
||||
for split in splits:
|
||||
new_filename = os.path.join(save_folder, split) + ".csv"
|
||||
if os.path.exists(new_filename):
|
||||
continue
|
||||
logger.info("Preparing %s..." % new_filename)
|
||||
|
||||
csv_output = [["ID", "duration", "wav", "transcript"]]
|
||||
entry = []
|
||||
|
||||
all_wavs = glob.glob(
|
||||
os.path.join(data_folder, "data_aishell/wav") + "/" + split +
|
||||
"/*/*.wav")
|
||||
for i in range(len(all_wavs)):
|
||||
filename = all_wavs[i].split("/")[-1].split(".wav")[0]
|
||||
if filename not in filename2transcript:
|
||||
continue
|
||||
signal = read_audio(all_wavs[i])
|
||||
duration = signal.shape[0] / 16000
|
||||
transcript_ = filename2transcript[filename]
|
||||
csv_line = [
|
||||
ID_start + i,
|
||||
str(duration),
|
||||
all_wavs[i],
|
||||
transcript_,
|
||||
]
|
||||
entry.append(csv_line)
|
||||
|
||||
csv_output = csv_output + entry
|
||||
|
||||
with open(new_filename, mode="w") as csv_f:
|
||||
csv_writer = csv.writer(
|
||||
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||
for line in csv_output:
|
||||
csv_writer.writerow(line)
|
||||
|
||||
msg = "\t%s successfully created!" % (new_filename)
|
||||
logger.info(msg)
|
||||
|
||||
ID_start += len(all_wavs)
|
||||
|
||||
|
||||
def main():
|
||||
if args.data_folder.startswith('~'):
|
||||
args.data_folder = os.path.expanduser(args.data_folder)
|
||||
|
||||
prepare_aishell(args.data_folder, args.save_folder, skip_prep=False)
|
||||
|
||||
print("Data csv prepare done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,101 @@
|
||||
#!/bin/bash
|
||||
|
||||
stage=-1
|
||||
stop_stage=-1
|
||||
dict_dir=data/lang_char
|
||||
|
||||
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
|
||||
|
||||
mkdir -p data
|
||||
mkdir -p ${dict_dir}
|
||||
TARGET_DIR=${MAIN_ROOT}/dataset
|
||||
mkdir -p ${TARGET_DIR}
|
||||
|
||||
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||
# download data, generate manifests
|
||||
python3 ${TARGET_DIR}/aishell/aishell.py \
|
||||
--manifest_prefix="data/manifest" \
|
||||
--target_dir="${TARGET_DIR}/aishell"
|
||||
|
||||
#generate csv file for speechbrain dataloader
|
||||
python3 local/aishell_prepare.py \
|
||||
--data_folder="${TARGET_DIR}/aishell" \
|
||||
--save_folder="data/"
|
||||
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Prepare Aishell failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for dataset in train dev test; do
|
||||
mv data/manifest.${dataset} data/manifest.${dataset}.raw
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
# compute mean and stddev for normalizer
|
||||
num_workers=$(nproc)
|
||||
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
|
||||
--manifest_path="data/manifest.train.raw" \
|
||||
--spectrum_type="fbank" \
|
||||
--feat_dim=80 \
|
||||
--delta_delta=false \
|
||||
--stride_ms=10 \
|
||||
--window_ms=25 \
|
||||
--sample_rate=16000 \
|
||||
--use_dB_normalization=False \
|
||||
--num_samples=-1 \
|
||||
--num_workers=${num_workers} \
|
||||
--output_path="data/mean_std.json"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Compute mean and stddev failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# download data, generate manifests
|
||||
# build vocabulary
|
||||
python3 ${MAIN_ROOT}/utils/build_vocab.py \
|
||||
--unit_type="char" \
|
||||
--count_threshold=0 \
|
||||
--vocab_path="${dict_dir}/vocab.txt" \
|
||||
--manifest_paths "data/manifest.train.raw"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Build vocabulary failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# format manifest with tokenids, vocab size
|
||||
for dataset in train dev test; do
|
||||
{
|
||||
python3 ${MAIN_ROOT}/utils/format_data.py \
|
||||
--cmvn_path "data/mean_std.json" \
|
||||
--unit_type "char" \
|
||||
--vocab_path="${dict_dir}/vocab.txt" \
|
||||
--manifest_path="data/manifest.${dataset}.raw" \
|
||||
--output_path="data/manifest.${dataset}"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Formt mnaifest failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
} &
|
||||
done
|
||||
wait
|
||||
fi
|
||||
echo "Aishell data preparation done."
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
mkdir -p exp/wav2vec2
|
||||
echo "Pretrained wav2vec2 model download"
|
||||
wget -P exp/wav2vec2 https://paddlespeech.bj.bcebos.com/wav2vec/chinese-wav2vec2-large.pdparams
|
||||
fi
|
||||
|
||||
exit 0
|
||||
|
@ -0,0 +1,84 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
expdir=exp
|
||||
datadir=data
|
||||
|
||||
train_set=train_960
|
||||
recog_set="test-clean test-other dev-clean dev-other"
|
||||
recog_set="test-clean"
|
||||
|
||||
config_path=$1
|
||||
decode_config_path=$2
|
||||
ckpt_prefix=$3
|
||||
|
||||
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
|
||||
# download language model
|
||||
#bash local/download_lm_en.sh
|
||||
#if [ $? -ne 0 ]; then
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
python3 utils/format_rsl.py \
|
||||
--origin_ref data/manifest.test.raw \
|
||||
--trans_ref data/manifest.test.text
|
||||
|
||||
|
||||
for type in ctc_greedy_search; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=1
|
||||
python3 -u ${BIN_DIR}/test.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--decode_cfg ${decode_config_path} \
|
||||
--result_file ${ckpt_prefix}.${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decode.decoding_method ${type} \
|
||||
--opts decode.decode_batch_size ${batch_size}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
python3 utils/format_rsl.py \
|
||||
--origin_hyp ${ckpt_prefix}.${type}.rsl \
|
||||
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
|
||||
|
||||
python3 utils/compute-wer.py --char=1 --v=1 \
|
||||
data/manifest.test.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
|
||||
echo "decoding ${type} done."
|
||||
done
|
||||
|
||||
for type in ctc_prefix_beam_search; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=1
|
||||
python3 -u ${BIN_DIR}/test.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--decode_cfg ${decode_config_path} \
|
||||
--result_file ${ckpt_prefix}.${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decode.decoding_method ${type} \
|
||||
--opts decode.decode_batch_size ${batch_size}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
python3 utils/format_rsl.py \
|
||||
--origin_hyp ${ckpt_prefix}.${type}.rsl \
|
||||
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
|
||||
|
||||
python3 utils/compute-wer.py --char=1 --v=1 \
|
||||
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
|
||||
echo "decoding ${type} done."
|
||||
done
|
||||
|
||||
echo "Finished"
|
||||
|
||||
exit 0
|
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 4 ];then
|
||||
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
decode_config_path=$2
|
||||
ckpt_prefix=$3
|
||||
audio_file=$4
|
||||
|
||||
mkdir -p data
|
||||
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f ${audio_file} ]; then
|
||||
echo "Plase input the right audio_file path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
chunk_mode=false
|
||||
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
|
||||
chunk_mode=true
|
||||
fi
|
||||
|
||||
# download language model
|
||||
#bash local/download_lm_ch.sh
|
||||
#if [ $? -ne 0 ]; then
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
for type in ctc_greedy_search; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=1
|
||||
output_dir=${ckpt_prefix}
|
||||
mkdir -p ${output_dir}
|
||||
python3 -u ${BIN_DIR}/test_wav.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--decode_cfg ${decode_config_path} \
|
||||
--result_file ${output_dir}/${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decode.decoding_method ${type} \
|
||||
--opts decode.decode_batch_size ${batch_size} \
|
||||
--audio_file ${audio_file}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
exit 0
|
@ -0,0 +1,59 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# -lt 2 ] && [ $# -gt 3 ];then
|
||||
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
ckpt_name=$2
|
||||
resume=$3
|
||||
ips=$4
|
||||
|
||||
if [ ! $ips ];then
|
||||
ips_config=
|
||||
else
|
||||
ips_config="--ips="${ips}
|
||||
fi
|
||||
|
||||
mkdir -p exp
|
||||
|
||||
# seed may break model convergence
|
||||
seed=2
|
||||
if [ ${seed} != 0 ]; then
|
||||
export FLAGS_cudnn_deterministic=True
|
||||
fi
|
||||
|
||||
# export FLAGS_cudnn_exhaustive_search=true
|
||||
# export FLAGS_conv_workspace_size_limit=4000
|
||||
# export FLAGS_allocator_strategy=naive_best_fit
|
||||
|
||||
if [ ${ngpu} == 0 ]; then
|
||||
python3 -u ${BIN_DIR}/train.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--output exp/${ckpt_name} \
|
||||
--seed ${seed} \
|
||||
--resume ${resume}
|
||||
else
|
||||
python3 -m paddle.distributed.launch --log_dir=${ckpt_name} --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--output exp/${ckpt_name} \
|
||||
--seed ${seed} \
|
||||
--resume ${resume}
|
||||
fi
|
||||
|
||||
if [ ${seed} != 0 ]; then
|
||||
unset FLAGS_cudnn_deterministic
|
||||
fi
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -0,0 +1,15 @@
|
||||
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||
|
||||
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
|
||||
export LC_ALL=C
|
||||
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
|
||||
|
||||
|
||||
MODEL=wav2vec2
|
||||
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin
|
@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
. ./cmd.sh || exit 1;
|
||||
|
||||
gpus=0,1,2,3
|
||||
stage=0
|
||||
stop_stage=4
|
||||
conf_path=conf/wav2vec2ASR.yaml
|
||||
ips= #xx.xx.xx.xx,xx.xx.xx.xx
|
||||
decode_conf_path=conf/tuning/decode.yaml
|
||||
avg_num=1
|
||||
resume= # xx e.g. 30
|
||||
export FLAGS_cudnn_deterministic=1
|
||||
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
|
||||
audio_file=data/demo_002_en.wav
|
||||
|
||||
avg_ckpt=avg_${avg_num}
|
||||
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
|
||||
echo "checkpoint name ${ckpt}"git revert -v
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
# prepare data
|
||||
bash ./local/data.sh || exit -1
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# train model, all `ckpt` under `exp` dir
|
||||
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# avg n best model
|
||||
avg.sh last exp/${ckpt}/checkpoints ${avg_num}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# greedy search decoder
|
||||
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
|
||||
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# test a single .wav file
|
||||
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
|
||||
fi
|
@ -0,0 +1 @@
|
||||
../../../utils
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,107 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/batch.py)
|
||||
"""Batch collation
|
||||
|
||||
Authors
|
||||
* Aku Rouhe 2020
|
||||
"""
|
||||
import collections
|
||||
|
||||
import paddle
|
||||
|
||||
from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
|
||||
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate
|
||||
|
||||
PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
|
||||
|
||||
|
||||
class PaddedBatch:
|
||||
"""Collate_fn when examples are dicts and have variable-length sequences.
|
||||
|
||||
Different elements in the examples get matched by key.
|
||||
All numpy tensors get converted to paddle.Tensor
|
||||
Then, by default, all paddle.Tensor valued elements get padded and support
|
||||
collective pin_memory() and to() calls.
|
||||
Regular Python data types are just collected in a list.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
examples : list
|
||||
List of example dicts, as produced by Dataloader.
|
||||
padded_keys : list, None
|
||||
(Optional) List of keys to pad on. If None, pad all paddle.Tensors
|
||||
device_prep_keys : list, None
|
||||
(Optional) Only these keys participate in collective memory pinning and moving with
|
||||
to().
|
||||
If None, defaults to all items with paddle.Tensor values.
|
||||
padding_func : callable, optional
|
||||
Called with a list of tensors to be padded together. Needs to return
|
||||
two tensors: the padded data, and another tensor for the data lengths.
|
||||
padding_kwargs : dict
|
||||
(Optional) Extra kwargs to pass to padding_func. E.G. mode, value
|
||||
nonpadded_stack : bool
|
||||
Whether to apply Tensor stacking on values that didn't get padded.
|
||||
This stacks if it can, but doesn't error out if it cannot.
|
||||
Default:True, usually does the right thing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
examples,
|
||||
padded_keys=None,
|
||||
device_prep_keys=None,
|
||||
padding_func=batch_pad_right,
|
||||
padding_kwargs={},
|
||||
nonpadded_stack=True, ):
|
||||
self.__length = len(examples)
|
||||
self.__keys = list(examples[0].keys())
|
||||
self.__padded_keys = []
|
||||
self.__device_prep_keys = []
|
||||
for key in self.__keys:
|
||||
values = [example[key] for example in examples]
|
||||
# Default convert usually does the right thing (numpy2tensor etc.)
|
||||
values = paddle.to_tensor(values)
|
||||
|
||||
if (padded_keys is not None and key in padded_keys) or (
|
||||
padded_keys is None and
|
||||
isinstance(values[0], paddle.Tensor)):
|
||||
# Padding and PaddedData
|
||||
self.__padded_keys.append(key)
|
||||
padded = PaddedData(*padding_func(values, **padding_kwargs))
|
||||
setattr(self, key, padded)
|
||||
else:
|
||||
if nonpadded_stack:
|
||||
values = mod_default_collate(values)
|
||||
setattr(self, key, values)
|
||||
if (device_prep_keys is not None and key in device_prep_keys) or (
|
||||
device_prep_keys is None and
|
||||
isinstance(values[0], paddle.Tensor)):
|
||||
self.__device_prep_keys.append(key)
|
||||
|
||||
def __len__(self):
|
||||
return self.__length
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.__keys:
|
||||
return getattr(self, key)
|
||||
else:
|
||||
raise KeyError(f"Batch doesn't have key: {key}")
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterates over the different elements of the batch.
|
||||
"""
|
||||
return iter((getattr(self, key) for key in self.__keys))
|
@ -0,0 +1,488 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/utils/data_pipeline.py)
|
||||
"""A pipeline for data transformations.
|
||||
|
||||
Author:
|
||||
* Aku Rouhe
|
||||
"""
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
|
||||
from paddlespeech.s2t.io.speechbrain.depgraph import DependencyGraph
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticItem:
|
||||
"""Data class that represents a static item.
|
||||
|
||||
Static items are in-memory items so they don't need to be computed
|
||||
dynamically.
|
||||
"""
|
||||
|
||||
key: str
|
||||
|
||||
|
||||
class DynamicItem:
|
||||
"""Essentially represents a data transformation function.
|
||||
|
||||
A DynamicItem takes some arguments and computes its value dynamically when
|
||||
called. A straight-forward use-case is to load something from disk
|
||||
dynamically; take the path and provide the loaded data.
|
||||
|
||||
Instances of this class are often created implicitly via the
|
||||
@takes and @provides decorators or otherwise from specifying the taken and
|
||||
provided arguments and the function.
|
||||
|
||||
A counterpart is the GeneratorDynamicItem, which should be used for
|
||||
generator functions.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
takes : list
|
||||
The keys of the items that this needs to compute its output.
|
||||
func : callable
|
||||
The function that is used to compute the output.
|
||||
provides : list
|
||||
The keys that this provides.
|
||||
"""
|
||||
|
||||
def __init__(self, takes=[], func=None, provides=[]):
|
||||
self.takes = takes
|
||||
self.func = func
|
||||
self.provides = provides
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.func(*args)
|
||||
|
||||
# The next methods are more about supporting GeneratorDynamicItems
|
||||
def next_takes(self):
|
||||
"""The next argkeys to provide to this, when called."""
|
||||
# Regular function DynamicItems always just need the same set of args
|
||||
return self.takes
|
||||
|
||||
def next_provides(self):
|
||||
"""The next keys that this provides, when called."""
|
||||
# Regular function DynamicItems always just provide the same set of keys
|
||||
return self.provides
|
||||
|
||||
def provided_in_order(self):
|
||||
"""Assuming that this may need to be called multiple times; which keys
|
||||
does it provide at that call. Returns a list, with len equal to the
|
||||
number of times that this may be called."""
|
||||
# Regular function DynamicItems are only called once:
|
||||
return [self.provides]
|
||||
|
||||
def reset(self):
|
||||
"""Signals that this will not be called any more times on this pipeline
|
||||
call."""
|
||||
# Regular function DynamicItems don't need special resets.
|
||||
pass
|
||||
|
||||
|
||||
class GeneratorDynamicItem(DynamicItem):
|
||||
"""Essentially represents a multi-step data transformation.
|
||||
|
||||
This is the generator function counterpart for DynamicItem (which should be
|
||||
used for regular functions).
|
||||
|
||||
A GeneratorDynamicItem first takes some arguments and then uses those in
|
||||
multiple steps to incrementally compute some values when called.
|
||||
|
||||
A typical use-case is a pipeline of transformations on data: e.g. taking in
|
||||
text as a string, and first a tokenized version, and then on the second
|
||||
call providing an integer-encoded version. This can be used even though the
|
||||
integer-encoder needs to be trained on the first outputs.
|
||||
|
||||
The main benefit is to be able to define the pipeline in a clear function,
|
||||
even if parts of the pipeline depend on others for their initialization.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Doesn't generate electricity, only stores the currently active
|
||||
# generator:
|
||||
self.current_generator = None
|
||||
self.num_provided_items = 0
|
||||
|
||||
def __call__(self, *args):
|
||||
if self.num_provided_items == len(self.provides):
|
||||
raise RuntimeError("DynamicItemPipeline called too many times!")
|
||||
if not self.current_generator:
|
||||
self.current_generator = self.func(*args)
|
||||
# NOTE: Not supporting sending new values to the pipeline.
|
||||
out = next(self.current_generator)
|
||||
self.num_provided_items += 1
|
||||
return out
|
||||
|
||||
def next_takes(self):
|
||||
"""The next argkeys to provide to this, when called."""
|
||||
if not self.current_generator:
|
||||
return self.takes
|
||||
else:
|
||||
return []
|
||||
|
||||
def next_provides(self):
|
||||
"""The next keys that this provides, when called."""
|
||||
keys = self.provides[self.num_provided_items]
|
||||
# Support multiple yielded values like:
|
||||
# @yields("wav_read", ["left_ch", "right_ch"])
|
||||
if isinstance(keys, str):
|
||||
return [keys]
|
||||
else:
|
||||
return keys
|
||||
|
||||
def provided_in_order(self):
|
||||
"""Assuming that this may need to be called multiple times; which keys
|
||||
does it provide at that call. Returns a list, with len equal to the
|
||||
number of times that this may be called."""
|
||||
in_order = []
|
||||
for keys in self.provides:
|
||||
# Support multiple yielded values like:
|
||||
# @provides("wav_read", ["left_ch", "right_ch"])
|
||||
if isinstance(keys, str):
|
||||
in_order.append([keys])
|
||||
else:
|
||||
in_order.append(keys)
|
||||
return in_order
|
||||
|
||||
def reset(self):
|
||||
"""Signals that this will not be called any more times on this pipeline
|
||||
call."""
|
||||
if self.current_generator is not None:
|
||||
self.current_generator.close()
|
||||
self.current_generator = None
|
||||
self.num_provided_items = 0
|
||||
|
||||
|
||||
def takes(*argkeys):
|
||||
"""Decorator which makes a DynamicItem and specifies its argkeys.
|
||||
|
||||
If the wrapped object is a generator function (has a yield statement),
|
||||
Creates a GeneratorDynamicItem. If the object is already a DynamicItem,
|
||||
just specifies the argkeys for that. Otherwise creates a new regular
|
||||
DynamicItem, with argkeys specified.
|
||||
|
||||
The args are always passed to the function at the start. Generators could
|
||||
support sending new arguments, but for such use cases, simply create a new
|
||||
dynamic item. The GeneratorDynamicItem class is meant for pipelines which
|
||||
take in an input and transform it in multiple ways, where the intermediate
|
||||
representations may be needed for e.g. fitting a BPE segmenter.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> @takes("text")
|
||||
... def tokenize(text):
|
||||
... return text.strip().lower().split()
|
||||
>>> tokenize.provides = ["tokenized"]
|
||||
>>> tokenize('\tThis Example gets tokenized')
|
||||
['this', 'example', 'gets', 'tokenized']
|
||||
"""
|
||||
|
||||
def decorator(obj):
|
||||
"""Decorator definition."""
|
||||
if isinstance(obj, DynamicItem):
|
||||
if obj.takes:
|
||||
raise ValueError("Can't overwrite DynamicItem.takes")
|
||||
obj.takes = argkeys
|
||||
return obj
|
||||
elif inspect.isgeneratorfunction(obj):
|
||||
return GeneratorDynamicItem(takes=argkeys, func=obj)
|
||||
else:
|
||||
return DynamicItem(takes=argkeys, func=obj)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
takes_decorator = takes # Just for DataPipeline.add_dynamic_item
|
||||
|
||||
|
||||
def provides(*output_keys):
|
||||
"""Decorator which makes a DynamicItem and specifies what keys it provides.
|
||||
|
||||
If the wrapped object is a generator function (has a yield statement),
|
||||
Creates a GeneratorDynamicItem. If the object is already a DynamicItem,
|
||||
just specifies the provided keys for that. Otherwise creates a new regular
|
||||
DynamicItem, with provided keys specified.
|
||||
|
||||
NOTE
|
||||
----
|
||||
The behavior is slightly different for generators and regular functions, if
|
||||
many output keys are specified, e.g. @provides("signal", "mfcc"). Regular
|
||||
functions should return a tuple with len equal to len(output_keys), while
|
||||
generators should yield the items one by one.
|
||||
|
||||
>>> @provides("signal", "feat")
|
||||
... def read_feat():
|
||||
... wav = [.1,.2,-.1]
|
||||
... feat = [s**2 for s in wav]
|
||||
... return wav, feat
|
||||
>>> @provides("signal", "feat")
|
||||
... def read_feat():
|
||||
... wav = [.1,.2,-.1]
|
||||
... yield wav
|
||||
... feat = [s**2 for s in wav]
|
||||
... yield feat
|
||||
|
||||
If multiple keys are yielded at once, write e.g.,
|
||||
|
||||
>>> @provides("wav_read", ["left_channel", "right_channel"])
|
||||
... def read_multi_channel():
|
||||
... wav = [[.1,.2,-.1],[.2,.1,-.1]]
|
||||
... yield wav
|
||||
... yield wav[0], wav[1]
|
||||
|
||||
"""
|
||||
|
||||
def decorator(obj):
|
||||
"""Decorator definition."""
|
||||
if isinstance(obj, DynamicItem):
|
||||
if obj.provides:
|
||||
raise ValueError("Can't overwrite DynamicItem provides-list.")
|
||||
obj.provides = output_keys
|
||||
return obj
|
||||
elif inspect.isgeneratorfunction(obj):
|
||||
return GeneratorDynamicItem(func=obj, provides=output_keys)
|
||||
else:
|
||||
return DynamicItem(func=obj, provides=output_keys)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
provides_decorator = provides # Just for DataPipeline.add_dynamic_item
|
||||
|
||||
|
||||
class DataPipeline:
|
||||
"""Organises data transformations into a pipeline.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> pipeline = DataPipeline(
|
||||
... static_data_keys=["text"],
|
||||
... dynamic_items=[
|
||||
... {"func": lambda x: x.lower(), "takes": "text", "provides": "foo"},
|
||||
... {"func": lambda x: x[::-1], "takes": "foo", "provides": "bar"},
|
||||
... ],
|
||||
... output_keys=["bar"],
|
||||
... )
|
||||
>>> pipeline({"text": "Test"})
|
||||
{'bar': 'tset'}
|
||||
"""
|
||||
|
||||
def __init__(self, static_data_keys, dynamic_items=[], output_keys=[]):
|
||||
self.dg = DependencyGraph()
|
||||
self._exec_order = None
|
||||
self.key_to_node = {}
|
||||
self.unaccounted_keys = {}
|
||||
self.dynamic_items = []
|
||||
self.output_mapping = {}
|
||||
self.add_static_keys(static_data_keys)
|
||||
self.add_dynamic_items(dynamic_items)
|
||||
self.set_output_keys(output_keys)
|
||||
|
||||
def add_static_keys(self, static_keys):
|
||||
"""Informs the pipeline about static items.
|
||||
|
||||
Static items are the ones provided to __call__ as data.
|
||||
"""
|
||||
for key in static_keys:
|
||||
node_id = self.dg.add_node(data=StaticItem(key=key))
|
||||
self.key_to_node[key] = node_id
|
||||
|
||||
def add_dynamic_items(self, dynamic_items):
|
||||
"""Add multiple dynamic items at once."""
|
||||
for item in dynamic_items:
|
||||
try:
|
||||
self.add_dynamic_item(**item)
|
||||
except TypeError:
|
||||
self.add_dynamic_item(item)
|
||||
|
||||
def add_dynamic_item(self, func, takes=None, provides=None):
|
||||
"""Adds a dynamic item to the Pipeline.
|
||||
|
||||
Two calling conventions. For DynamicItem objects, just use:
|
||||
add_dynamic_item(dynamic_item)
|
||||
But otherwise, should use:
|
||||
add_dynamic_item(func, takes, provides)
|
||||
|
||||
Arguments
|
||||
---------
|
||||
func : callable, DynamicItem
|
||||
If a DynamicItem is given, adds that directly. Otherwise a
|
||||
DynamicItem is created, and this specifies the callable to use. If
|
||||
a generator function is given, then create a GeneratorDynamicItem.
|
||||
Otherwise creates a normal DynamicItem.
|
||||
takes : list, str
|
||||
List of keys. When func is called, each key is resolved to
|
||||
either an entry in the data or the output of another dynamic_item.
|
||||
The func is then called with these as positional arguments,
|
||||
in the same order as specified here.
|
||||
A single key can be given as a bare string.
|
||||
provides : str, list
|
||||
For regular functions, the key or list of keys that it provides.
|
||||
If you give a generator function, key or list of keys that it
|
||||
yields, in order. Also see the provides decorator.
|
||||
A single key can be given as a bare string.
|
||||
"""
|
||||
if isinstance(func, DynamicItem):
|
||||
if takes is not None or provides is not None:
|
||||
raise ValueError("If providing a DynamicItem directly, don't "
|
||||
"specify takes or provides")
|
||||
else:
|
||||
self._add_dynamic_item_object(func)
|
||||
return
|
||||
if isinstance(takes, str):
|
||||
takes = [takes]
|
||||
if isinstance(provides, str):
|
||||
provides = [provides]
|
||||
di = takes_decorator(*takes)(provides_decorator(*provides)(func))
|
||||
self._add_dynamic_item_object(di)
|
||||
|
||||
def _add_dynamic_item_object(self, obj):
|
||||
"""Internally adds the object.
|
||||
|
||||
There is a node in the dependency graph for each call of the
|
||||
DynamicItem. Each call may return multiple keys and depend on multiple
|
||||
keys. An internal dict maps key to the id of the node that produces it.
|
||||
"""
|
||||
if not obj.provides:
|
||||
raise ValueError("Won't add redundant dynamic item which doesn't "
|
||||
"provide anything.")
|
||||
depended = []
|
||||
for key in obj.takes:
|
||||
# Might not be accounted for, yet:
|
||||
if key not in self.key_to_node:
|
||||
dependee_keys = self.unaccounted_keys.setdefault(key, [])
|
||||
dependee_keys.extend(obj.next_provides())
|
||||
else:
|
||||
depended.append(self.key_to_node[key])
|
||||
for provided in obj.provided_in_order():
|
||||
node_id = self.dg.add_node(data=obj)
|
||||
for key in provided:
|
||||
self.key_to_node[key] = node_id
|
||||
# This key may also be unaccounted for, so account for it now:
|
||||
if key in self.unaccounted_keys:
|
||||
for dependee_key in self.unaccounted_keys[key]:
|
||||
dependee_node = self.key_to_node[dependee_key]
|
||||
self.dg.add_edge(dependee_node, node_id)
|
||||
del self.unaccounted_keys[key] # Now accounted for!
|
||||
for dep_id in depended:
|
||||
self.dg.add_edge(node_id, dep_id)
|
||||
# Next call will depend on this call:
|
||||
depended = [node_id]
|
||||
# Keep a reference to the item in this object, as well:
|
||||
self.dynamic_items.append(obj)
|
||||
|
||||
def set_output_keys(self, keys):
|
||||
"""Use this to change the output keys.
|
||||
|
||||
Also re-evaluates execution order.
|
||||
So if you request different outputs, some parts of the
|
||||
data pipeline may be skipped.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
keys : dict, list, None
|
||||
List of keys (str) to produce in output.
|
||||
|
||||
If a dict is given; it is used to map internal keys to output keys.
|
||||
From the output_keys dict key:value pairs the key appears outside,
|
||||
and value is the internal key.
|
||||
"""
|
||||
self.output_mapping = self._output_keys_to_mapping(keys)
|
||||
self._exec_order = None
|
||||
|
||||
@staticmethod
|
||||
def _output_keys_to_mapping(keys):
|
||||
# Ensure a mapping (accept a list for convenience, too)
|
||||
if keys is None:
|
||||
output_mapping = {}
|
||||
elif isinstance(keys, dict):
|
||||
output_mapping = keys
|
||||
else:
|
||||
output_mapping = {key: key for key in keys}
|
||||
return output_mapping
|
||||
|
||||
def compute_outputs(self, data):
|
||||
"""
|
||||
Arguments
|
||||
---------
|
||||
data : dict
|
||||
Dictionary with data entries by key.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
With the keys that were set.
|
||||
"""
|
||||
if self._exec_order is None:
|
||||
self._prepare_run(data)
|
||||
return self._compute(data, self._exec_order, self.output_mapping)
|
||||
|
||||
def compute_specific(self, keys, data):
|
||||
"""Compute output of specific item, without changing output_keys."""
|
||||
output_mapping = self._output_keys_to_mapping(keys)
|
||||
order = self.dg.get_evaluation_order(
|
||||
selected_keys=self.get_selected_node_ids(keys))
|
||||
return self._compute(data, order, output_mapping)
|
||||
|
||||
def _compute(self, data, order, output_mapping):
|
||||
if self.unaccounted_keys:
|
||||
MSG = "These keys are still unaccounted for in the data pipeline: "
|
||||
MSG += ", ".join(self.unaccounted_keys)
|
||||
raise RuntimeError(MSG)
|
||||
intermediate = {}
|
||||
for node_id, edges, item in order:
|
||||
if isinstance(item, StaticItem):
|
||||
# Static item in data.
|
||||
# Just check that key is found.
|
||||
try:
|
||||
data[item.key]
|
||||
continue
|
||||
except KeyError:
|
||||
raise KeyError(f"Expected key {item.key} in data!")
|
||||
# A dynamic item, which we should compute:
|
||||
args = [
|
||||
data[argkey] if argkey in data else intermediate[argkey]
|
||||
for argkey in item.next_takes()
|
||||
]
|
||||
# This needs to be called BEFORE the dynamic item is called.
|
||||
provided_keys = item.next_provides()
|
||||
values = item(*args) # Call the DynamicItem to produce output
|
||||
# If there is just one output value, wrap in a list so that
|
||||
# it can be zipped as well:
|
||||
if len(provided_keys) == 1:
|
||||
values = [values]
|
||||
intermediate.update(zip(provided_keys, values))
|
||||
for dynamic_item in self.dynamic_items:
|
||||
dynamic_item.reset()
|
||||
return {
|
||||
outkey: data[inkey] if inkey in data else intermediate[inkey]
|
||||
for outkey, inkey in output_mapping.items()
|
||||
}
|
||||
|
||||
def get_selected_node_ids(self, selected_keys):
|
||||
"""Translates selected keys to dependency graph keys."""
|
||||
return [self.key_to_node[key] for key in selected_keys]
|
||||
|
||||
def __call__(self, data):
|
||||
return self.compute_outputs(data)
|
||||
|
||||
def _prepare_run(self, data):
|
||||
self._exec_order = list(
|
||||
self.dg.get_evaluation_order(
|
||||
self.get_selected_node_ids(self.output_mapping.values())))
|
@ -0,0 +1,177 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/utils/data_utils.py)
|
||||
import collections.abc
|
||||
import csv
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import urllib.request
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import tqdm
|
||||
|
||||
|
||||
def batch_pad_right(array: list, mode="constant", value=0):
|
||||
"""Given a list of paddle tensors it batches them together by padding to the right
|
||||
on each dimension in order to get same length for all.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : list
|
||||
List of tensor we wish to pad together.
|
||||
mode : str
|
||||
Padding mode see numpy.pad documentation.
|
||||
value : float
|
||||
Padding value see numpy.pad documentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
batched : numpy array
|
||||
Padded numpy array.
|
||||
valid_vals : list
|
||||
List containing proportion for each dimension of original, non-padded values.
|
||||
|
||||
"""
|
||||
|
||||
if not len(array):
|
||||
raise IndexError("Tensors list must not be empty")
|
||||
|
||||
if len(array) == 1:
|
||||
# if there is only one tensor in the batch we simply unsqueeze it.
|
||||
return np.expand_dims(array[0], 0), np.array([1.0], dtype="float32")
|
||||
if not (any(
|
||||
[array[i].ndim == array[0].ndim for i in range(1, len(array))])):
|
||||
raise IndexError("All array must have same number of dimensions")
|
||||
|
||||
# FIXME we limit the support here: we allow padding of only the first dimension
|
||||
# need to remove this when feat extraction is updated to handle multichannel.
|
||||
max_shape = []
|
||||
for dim in range(array[0].ndim):
|
||||
if dim != 0:
|
||||
if not all(
|
||||
[x.shape[dim] == array[0].shape[dim] for x in array[1:]]):
|
||||
raise EnvironmentError(
|
||||
"Tensors should have same dimensions except for the first one"
|
||||
)
|
||||
max_shape.append(max([x.shape[dim] for x in array]))
|
||||
|
||||
batched = []
|
||||
valid = []
|
||||
for t in array:
|
||||
# for each tensor we apply pad_right_to
|
||||
padded, valid_percent = pad_right_to(
|
||||
t, max_shape, mode=mode, value=value)
|
||||
batched.append(padded)
|
||||
valid.append(valid_percent[0])
|
||||
|
||||
batched = np.stack(batched)
|
||||
|
||||
return batched, np.array(valid, dtype="float32")
|
||||
|
||||
|
||||
np_str_obj_array_pattern = re.compile(r"[SaUO]")
|
||||
|
||||
|
||||
def pad_right_to(
|
||||
array: np.ndarray,
|
||||
target_shape: (list, tuple),
|
||||
mode="constant",
|
||||
value=0, ):
|
||||
"""
|
||||
This function takes a numpy of arbitrary shape and pads it to target
|
||||
shape by appending values on the right.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : input numpy array
|
||||
Input tensor whose dimension we need to pad.
|
||||
target_shape : (list, tuple)
|
||||
Target shape we want for the target tensor its len must be equal to tensor.ndim
|
||||
mode : str
|
||||
Pad mode, please refer to numpy.pad documentation.
|
||||
value : float
|
||||
Pad value, please refer to numpy.pad documentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
array : numpy array
|
||||
Padded numpy array.
|
||||
valid_vals : list
|
||||
List containing proportion for each dimension of original, non-padded values.
|
||||
"""
|
||||
assert len(target_shape) == array.ndim
|
||||
pads = [] # this contains the abs length of the padding for each dimension.
|
||||
valid_vals = [] # this contains the relative lengths for each dimension.
|
||||
i = len(target_shape) - 1 # iterating over target_shape ndims
|
||||
j = 0
|
||||
while i >= 0:
|
||||
assert (target_shape[i] >= array.shape[i]
|
||||
), "Target shape must be >= original shape for every dim"
|
||||
pads.extend([0, target_shape[i] - array.shape[i]])
|
||||
valid_vals.append(array.shape[j] / target_shape[j])
|
||||
i -= 1
|
||||
j += 1
|
||||
array = np.pad(array, pads, mode, constant_values=(value, value))
|
||||
|
||||
return array, valid_vals
|
||||
|
||||
|
||||
def mod_default_collate(batch):
|
||||
"""Makes a tensor from list of batch values.
|
||||
|
||||
Note that this doesn't need to zip(*) values together
|
||||
as PaddedBatch connects them already (by key).
|
||||
|
||||
Here the idea is not to error out.
|
||||
"""
|
||||
elem = batch[0]
|
||||
elem_type = type(elem)
|
||||
if isinstance(elem, paddle.Tensor):
|
||||
out = None
|
||||
try:
|
||||
if paddle.io.get_worker_info() is not None:
|
||||
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = elem.storage()._new_shared(numel)
|
||||
out = elem.new(storage)
|
||||
return paddle.stack(batch, 0, name=out)
|
||||
except RuntimeError: # Unequal size:
|
||||
return batch
|
||||
elif (elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and
|
||||
elem_type.__name__ != "string_"):
|
||||
try:
|
||||
if (elem_type.__name__ == "ndarray" or
|
||||
elem_type.__name__ == "memmap"):
|
||||
# array of string classes and object
|
||||
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
||||
return batch
|
||||
return mod_default_collate(
|
||||
[paddle.to_tensor(b, dtype=b.dtype) for b in batch])
|
||||
elif elem.shape == (): # scalars
|
||||
return paddle.to_tensor(batch, dtype=batch.dtype)
|
||||
except RuntimeError: # Unequal size
|
||||
return batch
|
||||
elif isinstance(elem, float):
|
||||
return paddle.to_tensor(batch, dtype=paddle.float64)
|
||||
elif isinstance(elem, int):
|
||||
return paddle.to_tensor(batch, dtype=paddle.int64)
|
||||
else:
|
||||
return batch
|
@ -0,0 +1,845 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/dataio.py)
|
||||
"""
|
||||
Data reading and writing.
|
||||
|
||||
Authors
|
||||
* Mirco Ravanelli 2020
|
||||
* Aku Rouhe 2020
|
||||
* Ju-Chieh Chou 2020
|
||||
* Samuele Cornell 2020
|
||||
* Abdel HEBA 2020
|
||||
"""
|
||||
import csv
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import soundfile
|
||||
logger = logging.getLogger(__name__)
|
||||
import paddle
|
||||
|
||||
|
||||
def load_data_json(json_path, replacements={}):
|
||||
"""Loads JSON and recursively formats string values.
|
||||
|
||||
Arguments
|
||||
----------
|
||||
json_path : str
|
||||
Path to CSV file.
|
||||
replacements : dict
|
||||
(Optional dict), e.g., {"data_folder": "/home/PaddleSpeech/data"}.
|
||||
This is used to recursively format all string values in the data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
JSON data with replacements applied.
|
||||
|
||||
|
||||
"""
|
||||
with open(json_path, "r") as f:
|
||||
out_json = json.load(f)
|
||||
_recursive_format(out_json, replacements)
|
||||
return out_json
|
||||
|
||||
|
||||
def _recursive_format(data, replacements):
|
||||
# Data: dict or list, replacements : dict
|
||||
# Replaces string keys in replacements by their values
|
||||
# at all levels of data (in str values)
|
||||
# Works in-place.
|
||||
if isinstance(data, dict):
|
||||
for key, item in data.items():
|
||||
if isinstance(item, dict) or isinstance(item, list):
|
||||
_recursive_format(item, replacements)
|
||||
elif isinstance(item, str):
|
||||
data[key] = item.format_map(replacements)
|
||||
# If not dict, list or str, do nothing
|
||||
if isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict) or isinstance(item, list):
|
||||
_recursive_format(item, replacements)
|
||||
elif isinstance(item, str):
|
||||
data[i] = item.format_map(replacements)
|
||||
# If not dict, list or str, do nothing
|
||||
|
||||
|
||||
def load_data_csv(csv_path, replacements={}):
|
||||
"""Loads CSV and formats string values.
|
||||
|
||||
Uses the legacy CSV data format, where the CSV must have an
|
||||
'ID' field.
|
||||
If there is a field called duration, it is interpreted as a float.
|
||||
The rest of the fields are left as they are (legacy _format and _opts fields
|
||||
are not used to load the data in any special way).
|
||||
|
||||
Bash-like string replacements with $to_replace are supported.
|
||||
|
||||
Arguments
|
||||
----------
|
||||
csv_path : str
|
||||
Path to CSV file.
|
||||
replacements : dict
|
||||
(Optional dict), e.g., {"data_folder": "/home/PaddleSpeech/data"}
|
||||
This is used to recursively format all string values in the data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
CSV data with replacements applied.
|
||||
"""
|
||||
|
||||
with open(csv_path, newline="") as csvfile:
|
||||
result = {}
|
||||
reader = csv.DictReader(csvfile, skipinitialspace=True)
|
||||
variable_finder = re.compile(r"\$([\w.]+)")
|
||||
for row in reader:
|
||||
# ID:
|
||||
try:
|
||||
data_id = row["ID"]
|
||||
del row["ID"] # This is used as a key in result, instead.
|
||||
except KeyError:
|
||||
raise KeyError("CSV has to have an 'ID' field, with unique ids"
|
||||
" for all data points")
|
||||
if data_id in result:
|
||||
raise ValueError(f"Duplicate id: {data_id}")
|
||||
# Replacements:
|
||||
for key, value in row.items():
|
||||
try:
|
||||
row[key] = variable_finder.sub(
|
||||
lambda match: str(replacements[match[1]]), value)
|
||||
except KeyError:
|
||||
raise KeyError(f"The item {value} requires replacements "
|
||||
"which were not supplied.")
|
||||
# Duration:
|
||||
if "duration" in row:
|
||||
row["duration"] = float(row["duration"])
|
||||
result[data_id] = row
|
||||
return result
|
||||
|
||||
|
||||
def read_audio(waveforms_obj):
|
||||
"""General audio loading, based on a custom notation.
|
||||
|
||||
Expected use case is in conjunction with Datasets
|
||||
specified by JSON.
|
||||
|
||||
The custom notation:
|
||||
|
||||
The annotation can be just a path to a file:
|
||||
"/path/to/wav1.wav"
|
||||
|
||||
Or can specify more options in a dict:
|
||||
{"file": "/path/to/wav2.wav",
|
||||
"start": 8000,
|
||||
"stop": 16000
|
||||
}
|
||||
|
||||
Arguments
|
||||
----------
|
||||
waveforms_obj : str, dict
|
||||
Audio reading annotation, see above for format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
paddle.Tensor
|
||||
Audio tensor with shape: (samples, ).
|
||||
"""
|
||||
if isinstance(waveforms_obj, str):
|
||||
audio, _ = soundfile.read(waveforms_obj, dtype="float32")
|
||||
return audio
|
||||
|
||||
path = waveforms_obj["file"]
|
||||
start = waveforms_obj.get("start", 0)
|
||||
# Default stop to start -> if not specified, num_frames becomes 0
|
||||
stop = waveforms_obj.get("stop", start)
|
||||
num_frames = stop - start
|
||||
audio, fs = soundfile.read(
|
||||
path, start=start, stop=start + num_frames, dtype="float32")
|
||||
return audio
|
||||
|
||||
|
||||
def read_audio_multichannel(waveforms_obj):
|
||||
"""General audio loading, based on a custom notation.
|
||||
|
||||
Expected use case is in conjunction with Datasets
|
||||
specified by JSON.
|
||||
|
||||
The custom notation:
|
||||
|
||||
The annotation can be just a path to a file:
|
||||
"/path/to/wav1.wav"
|
||||
|
||||
Multiple (possibly multi-channel) files can be specified, as long as they
|
||||
have the same length:
|
||||
{"files": [
|
||||
"/path/to/wav1.wav",
|
||||
"/path/to/wav2.wav"
|
||||
]
|
||||
}
|
||||
|
||||
Or you can specify a single file more succinctly:
|
||||
{"files": "/path/to/wav2.wav"}
|
||||
|
||||
Offset number samples and stop number samples also can be specified to read
|
||||
only a segment within the files.
|
||||
{"files": [
|
||||
"/path/to/wav1.wav",
|
||||
"/path/to/wav2.wav"
|
||||
]
|
||||
"start": 8000
|
||||
"stop": 16000
|
||||
}
|
||||
|
||||
Arguments
|
||||
----------
|
||||
waveforms_obj : str, dict
|
||||
Audio reading annotation, see above for format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
paddle.Tensor
|
||||
Audio tensor with shape: (samples, ).
|
||||
"""
|
||||
if isinstance(waveforms_obj, str):
|
||||
audio, _ = soundfile.read(waveforms_obj, dtype="float32")
|
||||
audio = paddle.to_tensor(audio)
|
||||
return audio
|
||||
|
||||
files = waveforms_obj["files"]
|
||||
if not isinstance(files, list):
|
||||
files = [files]
|
||||
|
||||
waveforms = []
|
||||
start = waveforms_obj.get("start", 0)
|
||||
# Default stop to start -> if not specified, num_frames becomes 0
|
||||
stop = waveforms_obj.get("stop", start - 1)
|
||||
num_frames = stop - start
|
||||
for f in files:
|
||||
audio, fs = soundfile.read(
|
||||
path, start=start, stop=start + num_frames, dtype="float32")
|
||||
audio = paddle.to_tensor(audio)
|
||||
waveforms.append(audio)
|
||||
|
||||
out = paddle.concat(waveforms, 0)
|
||||
return out
|
||||
|
||||
|
||||
def write_audio(filepath, audio, samplerate):
|
||||
"""Write audio on disk. It is basically a wrapper to support saving
|
||||
audio signals in format (audio, channels).
|
||||
|
||||
Arguments
|
||||
---------
|
||||
filepath: path
|
||||
Path where to save the audio file.
|
||||
audio : paddle.Tensor
|
||||
Audio file in the expected format (signal, channels).
|
||||
samplerate: int
|
||||
Sample rate (e.g., 16000).
|
||||
|
||||
"""
|
||||
if len(audio.shape) == 2:
|
||||
audio = audio.transpose([1, 0])
|
||||
elif len(audio.shape) == 1:
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
soundfile.write(filepath, audio, samplerate)
|
||||
|
||||
|
||||
def load_pickle(pickle_path):
|
||||
"""Utility function for loading .pkl pickle files.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
pickle_path : str
|
||||
Path to pickle file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : object
|
||||
Python object loaded from pickle.
|
||||
"""
|
||||
with open(pickle_path, "rb") as f:
|
||||
out = pickle.load(f)
|
||||
return out
|
||||
|
||||
|
||||
def to_floatTensor(x: (list, tuple, np.ndarray)):
|
||||
"""
|
||||
Arguments
|
||||
---------
|
||||
x : (list, tuple, np.ndarray)
|
||||
Input data to be converted to paddle float.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tensor : paddle.tensor
|
||||
Data now in paddle.tensor float datatype.
|
||||
"""
|
||||
return paddle.to_tensor(x, dtype='float32')
|
||||
|
||||
|
||||
def to_doubleTensor(x: (list, tuple, np.ndarray)):
|
||||
"""
|
||||
Arguments
|
||||
---------
|
||||
x : (list, tuple, np.ndarray)
|
||||
Input data to be converted to paddle double.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tensor : paddle.tensor
|
||||
Data now in paddle.tensor double datatype.
|
||||
"""
|
||||
return paddle.to_tensor(x, dtype='float64')
|
||||
|
||||
|
||||
def to_longTensor(x: (list, tuple, np.ndarray)):
|
||||
"""
|
||||
Arguments
|
||||
---------
|
||||
x : (list, tuple, np.ndarray)
|
||||
Input data to be converted to paddle long.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tensor : paddle.tensor
|
||||
Data now in paddle.tensor long datatype.
|
||||
"""
|
||||
return paddle.to_tensor(x, dtype='int64')
|
||||
|
||||
|
||||
def convert_index_to_lab(batch, ind2lab):
|
||||
"""Convert a batch of integer IDs to string labels.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
batch : list
|
||||
List of lists, a batch of sequences.
|
||||
ind2lab : dict
|
||||
Mapping from integer IDs to labels.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
List of lists, same size as batch, with labels from ind2lab.
|
||||
|
||||
"""
|
||||
return [[ind2lab[int(index)] for index in seq] for seq in batch]
|
||||
|
||||
|
||||
def relative_time_to_absolute(batch, relative_lens, rate):
|
||||
"""Converts relative length to the absolute duration.
|
||||
|
||||
Operates on batch level.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
batch : paddle.tensor
|
||||
Sequences to determine the duration for.
|
||||
relative_lens : paddle.tensor
|
||||
The relative length of each sequence in batch. The longest sequence in
|
||||
the batch needs to have relative length 1.0.
|
||||
rate : float
|
||||
The rate at which sequence elements occur in real-world time. Sample
|
||||
rate, if batch is raw wavs (recommended) or 1/frame_shift if batch is
|
||||
features. This has to have 1/s as the unit.
|
||||
|
||||
Returns
|
||||
------:
|
||||
paddle.tensor
|
||||
Duration of each sequence in seconds.
|
||||
|
||||
"""
|
||||
max_len = batch.shape[1]
|
||||
durations = paddle.round(relative_lens * max_len) / rate
|
||||
return durations
|
||||
|
||||
|
||||
class IterativeCSVWriter:
|
||||
"""Write CSV files a line at a time.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
outstream : file-object
|
||||
A writeable stream
|
||||
data_fields : list
|
||||
List of the optional keys to write. Each key will be expanded,
|
||||
producing three fields: key, key_format, key_opts.
|
||||
"""
|
||||
|
||||
def __init__(self, outstream, data_fields, defaults={}):
|
||||
self._outstream = outstream
|
||||
self.fields = ["ID", "duration"] + self._expand_data_fields(data_fields)
|
||||
self.defaults = defaults
|
||||
self._outstream.write(",".join(self.fields))
|
||||
|
||||
def set_default(self, field, value):
|
||||
"""Sets a default value for the given CSV field.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
field : str
|
||||
A field in the CSV.
|
||||
value
|
||||
The default value.
|
||||
"""
|
||||
if field not in self.fields:
|
||||
raise ValueError(f"{field} is not a field in this CSV!")
|
||||
self.defaults[field] = value
|
||||
|
||||
def write(self, *args, **kwargs):
|
||||
"""Writes one data line into the CSV.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
*args
|
||||
Supply every field with a value in positional form OR.
|
||||
**kwargs
|
||||
Supply certain fields by key. The ID field is mandatory for all
|
||||
lines, but others can be left empty.
|
||||
"""
|
||||
if args and kwargs:
|
||||
raise ValueError(
|
||||
"Use either positional fields or named fields, but not both.")
|
||||
if args:
|
||||
if len(args) != len(self.fields):
|
||||
raise ValueError("Need consistent fields")
|
||||
to_write = [str(arg) for arg in args]
|
||||
if kwargs:
|
||||
if "ID" not in kwargs:
|
||||
raise ValueError("I'll need to see some ID")
|
||||
full_vals = self.defaults.copy()
|
||||
full_vals.update(kwargs)
|
||||
to_write = [str(full_vals.get(field, "")) for field in self.fields]
|
||||
self._outstream.write("\n")
|
||||
self._outstream.write(",".join(to_write))
|
||||
|
||||
def write_batch(self, *args, **kwargs):
|
||||
"""Writes a batch of lines into the CSV.
|
||||
|
||||
Here each argument should be a list with the same length.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
*args
|
||||
Supply every field with a value in positional form OR.
|
||||
**kwargs
|
||||
Supply certain fields by key. The ID field is mandatory for all
|
||||
lines, but others can be left empty.
|
||||
"""
|
||||
if args and kwargs:
|
||||
raise ValueError(
|
||||
"Use either positional fields or named fields, but not both.")
|
||||
if args:
|
||||
if len(args) != len(self.fields):
|
||||
raise ValueError("Need consistent fields")
|
||||
for arg_row in zip(*args):
|
||||
self.write(*arg_row)
|
||||
if kwargs:
|
||||
if "ID" not in kwargs:
|
||||
raise ValueError("I'll need to see some ID")
|
||||
keys = kwargs.keys()
|
||||
for value_row in zip(*kwargs.values()):
|
||||
kwarg_row = dict(zip(keys, value_row))
|
||||
self.write(**kwarg_row)
|
||||
|
||||
@staticmethod
|
||||
def _expand_data_fields(data_fields):
|
||||
expanded = []
|
||||
for data_field in data_fields:
|
||||
expanded.append(data_field)
|
||||
expanded.append(data_field + "_format")
|
||||
expanded.append(data_field + "_opts")
|
||||
return expanded
|
||||
|
||||
|
||||
def write_txt_file(data, filename, sampling_rate=None):
|
||||
"""Write data in text format.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data : str, list, paddle.tensor, numpy.ndarray
|
||||
The data to write in the text file.
|
||||
filename : str
|
||||
Path to file where to write the data.
|
||||
sampling_rate : None
|
||||
Not used, just here for interface compatibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
del sampling_rate # Not used.
|
||||
# Check if the path of filename exists
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
with open(filename, "w") as fout:
|
||||
if isinstance(data, paddle.Tensor):
|
||||
data = data.tolist()
|
||||
if isinstance(data, np.ndarray):
|
||||
data = data.tolist()
|
||||
if isinstance(data, list):
|
||||
for line in data:
|
||||
print(line, file=fout)
|
||||
if isinstance(data, str):
|
||||
print(data, file=fout)
|
||||
|
||||
|
||||
def write_stdout(data, filename=None, sampling_rate=None):
|
||||
"""Write data to standard output.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data : str, list, paddle.Tensor, numpy.ndarray
|
||||
The data to write in the text file.
|
||||
filename : None
|
||||
Not used, just here for compatibility.
|
||||
sampling_rate : None
|
||||
Not used, just here for compatibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
# Managing paddle.Tensor
|
||||
if isinstance(data, paddle.Tensor):
|
||||
data = data.tolist()
|
||||
# Managing np.ndarray
|
||||
if isinstance(data, np.ndarray):
|
||||
data = data.tolist()
|
||||
if isinstance(data, list):
|
||||
for line in data:
|
||||
print(line)
|
||||
if isinstance(data, str):
|
||||
print(data)
|
||||
|
||||
|
||||
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
||||
"""Creates a binary mask for each sequence.
|
||||
Arguments
|
||||
---------
|
||||
length : LongTensor
|
||||
Containing the length of each sequence in the batch. Must be 1D.
|
||||
max_len : int
|
||||
Max length for the mask, also the size of the second dimension.
|
||||
dtype : dtype, default: None
|
||||
The dtype of the generated mask.
|
||||
device: device, default: None
|
||||
The device to put the mask variable.
|
||||
|
||||
Returns
|
||||
-------
|
||||
mask : tensor
|
||||
The binary mask.
|
||||
|
||||
"""
|
||||
assert len(length.shape) == 1
|
||||
|
||||
if max_len is None:
|
||||
max_len = length.max().long().item() # using arange to generate mask
|
||||
mask = paddle.arange(
|
||||
max_len, dtype=length.dtype).expand(
|
||||
[len(length), max_len]) < length.unsqueeze(1)
|
||||
|
||||
if dtype is None:
|
||||
dtype = length.dtype
|
||||
|
||||
if device is None:
|
||||
device = length.device
|
||||
|
||||
mask = paddle.to_tensor(mask, dtype=dtype)
|
||||
return mask
|
||||
|
||||
|
||||
def read_kaldi_lab(kaldi_ali, kaldi_lab_opts):
|
||||
"""Read labels in kaldi format.
|
||||
|
||||
Uses kaldi IO.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
kaldi_ali : str
|
||||
Path to directory where kaldi alignments are stored.
|
||||
kaldi_lab_opts : str
|
||||
A string that contains the options for reading the kaldi alignments.
|
||||
|
||||
Returns
|
||||
-------
|
||||
lab : dict
|
||||
A dictionary containing the labels.
|
||||
|
||||
Note
|
||||
----
|
||||
This depends on kaldi-io-for-python. Install it separately.
|
||||
See: https://github.com/vesis84/kaldi-io-for-python
|
||||
```
|
||||
"""
|
||||
# EXTRA TOOLS
|
||||
try:
|
||||
import kaldi_io
|
||||
except ImportError:
|
||||
raise ImportError("Could not import kaldi_io. Install it to use this.")
|
||||
# Reading the Kaldi labels
|
||||
lab = {
|
||||
k: v
|
||||
for k, v in kaldi_io.read_vec_int_ark(
|
||||
"gunzip -c " + kaldi_ali + "/ali*.gz | " + kaldi_lab_opts + " " +
|
||||
kaldi_ali + "/final.mdl ark:- ark:-|")
|
||||
}
|
||||
return lab
|
||||
|
||||
|
||||
def get_md5(file):
|
||||
"""Get the md5 checksum of an input file.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
file : str
|
||||
Path to file for which compute the checksum.
|
||||
|
||||
Returns
|
||||
-------
|
||||
md5
|
||||
Checksum for the given filepath.
|
||||
"""
|
||||
# Lets read stuff in 64kb chunks!
|
||||
BUF_SIZE = 65536
|
||||
md5 = hashlib.md5()
|
||||
# Computing md5
|
||||
with open(file, "rb") as f:
|
||||
while True:
|
||||
data = f.read(BUF_SIZE)
|
||||
if not data:
|
||||
break
|
||||
md5.update(data)
|
||||
return md5.hexdigest()
|
||||
|
||||
|
||||
def save_md5(files, out_file):
|
||||
"""Saves the md5 of a list of input files as a pickled dict into a file.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
files : list
|
||||
List of input files from which we will compute the md5.
|
||||
outfile : str
|
||||
The path where to store the output pkl file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
# Initialization of the dictionary
|
||||
md5_dict = {}
|
||||
# Computing md5 for all the files in the list
|
||||
for file in files:
|
||||
md5_dict[file] = get_md5(file)
|
||||
# Saving dictionary in pkl format
|
||||
save_pkl(md5_dict, out_file)
|
||||
|
||||
|
||||
def save_pkl(obj, file):
|
||||
"""Save an object in pkl format.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
obj : object
|
||||
Object to save in pkl format
|
||||
file : str
|
||||
Path to the output file
|
||||
sampling_rate : int
|
||||
Sampling rate of the audio file, TODO: this is not used?
|
||||
|
||||
"""
|
||||
with open(file, "wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
|
||||
def load_pkl(file):
|
||||
"""Loads a pkl file.
|
||||
|
||||
For an example, see `save_pkl`.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
file : str
|
||||
Path to the input pkl file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The loaded object.
|
||||
"""
|
||||
|
||||
# Deals with the situation where two processes are trying
|
||||
# to access the same label dictionary by creating a lock
|
||||
count = 100
|
||||
while count > 0:
|
||||
if os.path.isfile(file + ".lock"):
|
||||
time.sleep(1)
|
||||
count -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
try:
|
||||
open(file + ".lock", "w").close()
|
||||
with open(file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
finally:
|
||||
if os.path.isfile(file + ".lock"):
|
||||
os.remove(file + ".lock")
|
||||
|
||||
|
||||
def prepend_bos_token(label, bos_index):
|
||||
"""Create labels with <bos> token at the beginning.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
label : IntTensor
|
||||
Containing the original labels. Must be of size: [batch_size, max_length].
|
||||
bos_index : int
|
||||
The index for <bos> token.
|
||||
|
||||
Returns
|
||||
-------
|
||||
new_label : tensor
|
||||
The new label with <bos> at the beginning.
|
||||
|
||||
"""
|
||||
new_label = label.long().clone()
|
||||
batch_size = label.shape[0]
|
||||
|
||||
bos = new_label.new_zeros(batch_size, 1).fill_(bos_index)
|
||||
new_label = paddle.concat([bos, new_label], axis=1)
|
||||
return new_label
|
||||
|
||||
|
||||
def append_eos_token(label, length, eos_index):
|
||||
"""Create labels with <eos> token appended.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
label : IntTensor
|
||||
Containing the original labels. Must be of size: [batch_size, max_length]
|
||||
length : LongTensor
|
||||
Containing the original length of each label sequences. Must be 1D.
|
||||
eos_index : int
|
||||
The index for <eos> token.
|
||||
|
||||
Returns
|
||||
-------
|
||||
new_label : tensor
|
||||
The new label with <eos> appended.
|
||||
|
||||
"""
|
||||
new_label = paddle.to_tensor(label, dtype="int32").clone()
|
||||
batch_size = label.shape[0]
|
||||
|
||||
pad = paddle.zeros([batch_size, 1], dtype=new_label.dtype)
|
||||
|
||||
new_label = paddle.concat([new_label, pad], dim=1)
|
||||
new_label[paddle.arange(batch_size), paddle.to_tensor(
|
||||
length, dtype="int64")] = eos_index
|
||||
return new_label
|
||||
|
||||
|
||||
def merge_char(sequences, space="_"):
|
||||
"""Merge characters sequences into word sequences.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
sequences : list
|
||||
Each item contains a list, and this list contains a character sequence.
|
||||
space : string
|
||||
The token represents space. Default: _
|
||||
|
||||
Returns
|
||||
-------
|
||||
The list contains word sequences for each sentence.
|
||||
|
||||
"""
|
||||
results = []
|
||||
for seq in sequences:
|
||||
words = "".join(seq).split(space)
|
||||
results.append(words)
|
||||
return results
|
||||
|
||||
|
||||
def merge_csvs(data_folder, csv_lst, merged_csv):
|
||||
"""Merging several csv files into one file.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data_folder : string
|
||||
The folder to store csv files to be merged and after merging.
|
||||
csv_lst : list
|
||||
Filenames of csv file to be merged.
|
||||
merged_csv : string
|
||||
The filename to write the merged csv file.
|
||||
|
||||
"""
|
||||
write_path = os.path.join(data_folder, merged_csv)
|
||||
if os.path.isfile(write_path):
|
||||
logger.info("Skipping merging. Completed in previous run.")
|
||||
with open(os.path.join(data_folder, csv_lst[0])) as f:
|
||||
header = f.readline()
|
||||
lines = []
|
||||
for csv_file in csv_lst:
|
||||
with open(os.path.join(data_folder, csv_file)) as f:
|
||||
for i, line in enumerate(f):
|
||||
if i == 0:
|
||||
# Checking header
|
||||
if line != header:
|
||||
raise ValueError("Different header for "
|
||||
f"{csv_lst[0]} and {csv}.")
|
||||
continue
|
||||
lines.append(line)
|
||||
with open(write_path, "w") as f:
|
||||
f.write(header)
|
||||
for line in lines:
|
||||
f.write(line)
|
||||
logger.info(f"{write_path} is created.")
|
||||
|
||||
|
||||
def split_word(sequences, space="_"):
|
||||
"""Split word sequences into character sequences.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
sequences : list
|
||||
Each item contains a list, and this list contains a words sequence.
|
||||
space : string
|
||||
The token represents space. Default: _
|
||||
|
||||
Returns
|
||||
-------
|
||||
The list contains word sequences for each sentence.
|
||||
|
||||
"""
|
||||
results = []
|
||||
for seq in sequences:
|
||||
chars = list(space.join(seq))
|
||||
results.append(chars)
|
||||
return results
|
@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/dataloader.py)
|
||||
"""Paddle compatible DataLoaders
|
||||
|
||||
Essentially we extend Paddle DataLoader by adding the ability to save the
|
||||
data loading state, so that a checkpoint may be saved in the middle of an
|
||||
epoch.
|
||||
|
||||
Authors:
|
||||
* Aku Rouhe 2020
|
||||
"""
|
||||
import collections
|
||||
import functools
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import paddle
|
||||
from paddle.io import DataLoader
|
||||
|
||||
from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
|
||||
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate
|
||||
from paddlespeech.s2t.io.speechbrain.dataset import DynamicItemDataset
|
||||
from paddlespeech.s2t.io.speechbrain.sampler import ReproducibleRandomSampler
|
||||
PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
|
||||
import numpy
|
||||
|
||||
|
||||
class Wav2vec2DataLoader(DataLoader):
|
||||
def __init__(self,
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
sampler=None,
|
||||
batch_sampler=None,
|
||||
num_workers=0,
|
||||
collate_fn=None,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
timeout=0,
|
||||
worker_init_fn=None,
|
||||
multiprocessing_context=None,
|
||||
generator=None):
|
||||
if isinstance(dataset[0], (tuple, list)):
|
||||
return_list = True
|
||||
else:
|
||||
return_list = False
|
||||
|
||||
super().__init__(
|
||||
dataset,
|
||||
feed_list=None,
|
||||
places=None,
|
||||
return_list=return_list,
|
||||
batch_sampler=batch_sampler,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=num_workers,
|
||||
use_buffer_reader=True,
|
||||
use_shared_memory=False,
|
||||
timeout=timeout,
|
||||
worker_init_fn=worker_init_fn)
|
||||
if sampler is not None:
|
||||
self.batch_sampler.sampler = sampler
|
||||
|
||||
|
||||
def PaddedBatch(
|
||||
examples,
|
||||
padded_keys=None,
|
||||
device_prep_keys=None,
|
||||
padding_func=batch_pad_right,
|
||||
padding_kwargs={},
|
||||
nonpadded_stack=True, ):
|
||||
__length = len(examples)
|
||||
__keys = list(examples[0].keys())
|
||||
__padded_keys = []
|
||||
__device_prep_keys = []
|
||||
res = {}
|
||||
for key in __keys:
|
||||
values = [example[key] for example in examples]
|
||||
# Default convert usually does the right thing (numpy2tensor etc.)
|
||||
# values = default_convert(values)
|
||||
if (padded_keys is not None and key in padded_keys) or (
|
||||
padded_keys is None and isinstance(values[0], numpy.ndarray)):
|
||||
# Padding and PaddedData
|
||||
__padded_keys.append(key)
|
||||
|
||||
padded = PaddedData(*padding_func(values, **padding_kwargs))
|
||||
res[key] = padded
|
||||
else:
|
||||
# Default collate usually does the right thing
|
||||
# (convert lists of equal sized tensors to batch tensors, etc.)
|
||||
if nonpadded_stack:
|
||||
values = mod_default_collate(values)
|
||||
res[key] = values
|
||||
if (device_prep_keys is not None and key in device_prep_keys) or (
|
||||
device_prep_keys is None and
|
||||
isinstance(values[0], paddle.Tensor)):
|
||||
__device_prep_keys.append(key)
|
||||
return res
|
||||
|
||||
|
||||
def make_dataloader(dataset, stage, **loader_kwargs):
|
||||
"""Makes a basic DataLoader.
|
||||
|
||||
For DynamicItemDatasets (which return dicts), use
|
||||
PaddedBatch as the default collate_fn.
|
||||
|
||||
Shuffling gets implemented by ReproducibleRandomSampler.
|
||||
|
||||
If the Dataset is not an IterableDataset, the DataLoader
|
||||
is a SaveableDataLoader.
|
||||
|
||||
If the Dataset is a webdataset.dataset.Composable, set default
|
||||
batch_size = None.
|
||||
|
||||
Can also loop over the underlying dataloader continuously,
|
||||
and stop iterations at nominal epoch lengths.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
dataset : Dataset
|
||||
The dataset to make a DataLoader for.
|
||||
looped_nominal_epoch : None, int
|
||||
If an integer is given, loop the underlying DataLoader infinitely and
|
||||
set a nominal epoch length in batches (or whatever the DataLoader
|
||||
yields).
|
||||
**loader_kwargs : dict
|
||||
Keyword args to DataLoader, see Paddle DataLoader for
|
||||
options.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DataLoader
|
||||
If looped_nominal_epoch is None
|
||||
LoopedLoader
|
||||
If looped_nominal_epoch is not None
|
||||
"""
|
||||
# PaddedBatch as default collation for DynamicItemDataset
|
||||
if "collate_fn" not in loader_kwargs and isinstance(dataset,
|
||||
DynamicItemDataset):
|
||||
loader_kwargs["collate_fn"] = PaddedBatch
|
||||
# Reproducible random sampling
|
||||
if loader_kwargs.get("shuffle", False):
|
||||
if loader_kwargs.get("sampler") is not None:
|
||||
raise ValueError("Cannot specify both shuffle=True and a "
|
||||
"sampler in loader_kwargs")
|
||||
sampler = ReproducibleRandomSampler(dataset)
|
||||
loader_kwargs["sampler"] = sampler
|
||||
# Should delete shuffle because you can't set both Sampler and
|
||||
# shuffle
|
||||
# NOTE: the dict of loader options may get used elsewhere!
|
||||
# However, this del doesn't touch those because loader_kwargs comes
|
||||
# from a **kwargs dict.
|
||||
del loader_kwargs["shuffle"]
|
||||
# Create the loader
|
||||
dataloader = Wav2vec2DataLoader(dataset, **loader_kwargs)
|
||||
return dataloader
|
@ -0,0 +1,371 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/dataset.py)
|
||||
import contextlib
|
||||
import copy
|
||||
import logging
|
||||
from types import MethodType
|
||||
|
||||
from paddle.io import Dataset
|
||||
|
||||
from paddlespeech.s2t.io.speechbrain.data_pipeline import DataPipeline
|
||||
from paddlespeech.s2t.io.speechbrain.dataio import load_data_csv
|
||||
from paddlespeech.s2t.io.speechbrain.dataio import load_data_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamicItemDataset(Dataset):
|
||||
"""Dataset that reads, wrangles, and produces dicts.
|
||||
|
||||
Each data point dict provides some items (by key), for example, a path to a
|
||||
wavefile with the key "wav_file". When a data point is fetched from this
|
||||
Dataset, more items are produced dynamically, based on pre-existing items
|
||||
and other dynamic created items. For example, a dynamic item could take the
|
||||
wavfile path and load the audio from the disk.
|
||||
|
||||
The dynamic items can depend on other dynamic items: a suitable evaluation
|
||||
order is used automatically, as long as there are no circular dependencies.
|
||||
|
||||
A specified list of keys is collected in the output dict. These can be items
|
||||
in the original data or dynamic items. If some dynamic items are not
|
||||
requested, nor depended on by other requested items, they won't be computed.
|
||||
So for example if a user simply wants to iterate over the text, the
|
||||
time-consuming audio loading can be skipped.
|
||||
|
||||
About the format:
|
||||
Takes a dict of dicts as the collection of data points to read/wrangle.
|
||||
The top level keys are data point IDs.
|
||||
Each data point (example) dict should have the same keys, corresponding to
|
||||
different items in that data point.
|
||||
|
||||
Altogether the data collection could look like this:
|
||||
|
||||
>>> data = {
|
||||
... "spk1utt1": {
|
||||
... "wav_file": "/path/to/spk1utt1.wav",
|
||||
... "text": "hello world",
|
||||
... "speaker": "spk1",
|
||||
... },
|
||||
... "spk1utt2": {
|
||||
... "wav_file": "/path/to/spk1utt2.wav",
|
||||
... "text": "how are you world",
|
||||
... "speaker": "spk1",
|
||||
... }
|
||||
... }
|
||||
|
||||
NOTE
|
||||
----
|
||||
The top-level key, the data point id, is implicitly added as an item
|
||||
in the data point, with the key "id"
|
||||
|
||||
Each dynamic item is configured by three things: a key, a func, and a list
|
||||
of argkeys. The key should be unique among all the items (dynamic or not) in
|
||||
each data point. The func is any callable, and it returns the dynamic item's
|
||||
value. The callable is called with the values of other items as specified
|
||||
by the argkeys list (as positional args, passed in the order specified by
|
||||
argkeys).
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data : dict
|
||||
Dictionary containing single data points (e.g. utterances).
|
||||
dynamic_items : list, optional
|
||||
Configuration for the dynamic items produced when fetching an example.
|
||||
List of DynamicItems or dicts with the format::
|
||||
func: <callable> # To be called
|
||||
takes: <list> # key or list of keys of args this takes
|
||||
provides: key # key or list of keys that this provides
|
||||
output_keys : dict, list, optional
|
||||
List of keys (either directly available in data or dynamic items)
|
||||
to include in the output dict when data points are fetched.
|
||||
|
||||
If a dict is given; it is used to map internal keys to output keys.
|
||||
From the output_keys dict key:value pairs the key appears outside,
|
||||
and value is the internal key.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data,
|
||||
dynamic_items=[],
|
||||
output_keys=[], ):
|
||||
self.data = data
|
||||
self.data_ids = list(self.data.keys())
|
||||
static_keys = list(self.data[self.data_ids[0]].keys())
|
||||
if "id" in static_keys:
|
||||
raise ValueError("The key 'id' is reserved for the data point id.")
|
||||
else:
|
||||
static_keys.append("id")
|
||||
self.pipeline = DataPipeline(static_keys, dynamic_items)
|
||||
self.set_output_keys(output_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_ids)
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_id = self.data_ids[index]
|
||||
data_point = self.data[data_id]
|
||||
return self.pipeline.compute_outputs({"id": data_id, **data_point})
|
||||
|
||||
def add_dynamic_item(self, func, takes=None, provides=None):
|
||||
"""Makes a new dynamic item available on the dataset.
|
||||
|
||||
Two calling conventions. For DynamicItem objects, just use:
|
||||
add_dynamic_item(dynamic_item).
|
||||
But otherwise, should use:
|
||||
add_dynamic_item(func, takes, provides).
|
||||
|
||||
Arguments
|
||||
---------
|
||||
func : callable, DynamicItem
|
||||
If a DynamicItem is given, adds that directly. Otherwise a
|
||||
DynamicItem is created, and this specifies the callable to use. If
|
||||
a generator function is given, then create a GeneratorDynamicItem.
|
||||
Otherwise creates a normal DynamicItem.
|
||||
takes : list, str
|
||||
List of keys. When func is called, each key is resolved to
|
||||
either an entry in the data or the output of another dynamic_item.
|
||||
The func is then called with these as positional arguments,
|
||||
in the same order as specified here.
|
||||
A single arg can be given directly.
|
||||
provides : str
|
||||
Unique key or keys that this provides.
|
||||
"""
|
||||
self.pipeline.add_dynamic_item(func, takes, provides)
|
||||
|
||||
def set_output_keys(self, keys):
|
||||
"""Use this to change the output keys.
|
||||
|
||||
These are the keys that are actually evaluated when a data point
|
||||
is fetched from the dataset.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
keys : dict, list
|
||||
List of keys (str) to produce in output.
|
||||
|
||||
If a dict is given; it is used to map internal keys to output keys.
|
||||
From the output_keys dict key:value pairs the key appears outside,
|
||||
and value is the internal key.
|
||||
"""
|
||||
self.pipeline.set_output_keys(keys)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def output_keys_as(self, keys):
|
||||
"""Context manager to temporarily set output keys.
|
||||
|
||||
NOTE
|
||||
----
|
||||
Not thread-safe. While in this context manager, the output keys
|
||||
are affected for any call.
|
||||
"""
|
||||
saved_output = self.pipeline.output_mapping
|
||||
self.pipeline.set_output_keys(keys)
|
||||
yield self
|
||||
self.pipeline.set_output_keys(saved_output)
|
||||
|
||||
def filtered_sorted(
|
||||
self,
|
||||
key_min_value={},
|
||||
key_max_value={},
|
||||
key_test={},
|
||||
sort_key=None,
|
||||
reverse=False,
|
||||
select_n=None, ):
|
||||
"""Get a filtered and/or sorted version of this, shares static data.
|
||||
|
||||
The reason to implement these operations in the same method is that
|
||||
computing some dynamic items may be expensive, and this way the
|
||||
filtering and sorting steps don't need to compute the dynamic items
|
||||
twice.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
key_min_value : dict
|
||||
Map from key (in data or in dynamic items) to limit, will only keep
|
||||
data_point if data_point[key] >= limit
|
||||
key_max_value : dict
|
||||
Map from key (in data or in dynamic items) to limit, will only keep
|
||||
data_point if data_point[key] <= limit
|
||||
key_test : dict
|
||||
Map from key (in data or in dynamic items) to func, will only keep
|
||||
data_point if bool(func(data_point[key])) == True
|
||||
sort_key : None, str
|
||||
If not None, sort by data_point[sort_key]. Default is ascending
|
||||
order.
|
||||
reverse : bool
|
||||
If True, sort in descending order.
|
||||
select_n : None, int
|
||||
If not None, only keep (at most) the first n filtered data_points.
|
||||
The possible sorting is applied, but only on the first n data
|
||||
points found. Meant for debugging.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FilteredSortedDynamicItemDataset
|
||||
Shares the static data, but has its own output keys and
|
||||
dynamic items (initially deep copied from this, so they have the
|
||||
same dynamic items available)
|
||||
|
||||
NOTE
|
||||
----
|
||||
Temporarily changes the output keys!
|
||||
"""
|
||||
filtered_sorted_ids = self._filtered_sorted_ids(
|
||||
key_min_value,
|
||||
key_max_value,
|
||||
key_test,
|
||||
sort_key,
|
||||
reverse,
|
||||
select_n, )
|
||||
return FilteredSortedDynamicItemDataset(
|
||||
self, filtered_sorted_ids) # NOTE: defined below
|
||||
|
||||
def _filtered_sorted_ids(
|
||||
self,
|
||||
key_min_value={},
|
||||
key_max_value={},
|
||||
key_test={},
|
||||
sort_key=None,
|
||||
reverse=False,
|
||||
select_n=None, ):
|
||||
"""Returns a list of data ids, fulfilling the sorting and filtering."""
|
||||
|
||||
def combined_filter(computed):
|
||||
"""Applies filter."""
|
||||
for key, limit in key_min_value.items():
|
||||
# NOTE: docstring promises >= so using that.
|
||||
# Mathematically could also use < for nicer syntax, but
|
||||
# maybe with some super special weird edge case some one can
|
||||
# depend on the >= operator
|
||||
if computed[key] >= limit:
|
||||
continue
|
||||
return False
|
||||
for key, limit in key_max_value.items():
|
||||
if computed[key] <= limit:
|
||||
continue
|
||||
return False
|
||||
for key, func in key_test.items():
|
||||
if bool(func(computed[key])):
|
||||
continue
|
||||
return False
|
||||
return True
|
||||
|
||||
temp_keys = (set(key_min_value.keys()) | set(key_max_value.keys()) |
|
||||
set(key_test.keys()) |
|
||||
set([] if sort_key is None else [sort_key]))
|
||||
filtered_ids = []
|
||||
with self.output_keys_as(temp_keys):
|
||||
for i, data_id in enumerate(self.data_ids):
|
||||
if select_n is not None and len(filtered_ids) == select_n:
|
||||
break
|
||||
data_point = self.data[data_id]
|
||||
data_point["id"] = data_id
|
||||
computed = self.pipeline.compute_outputs(data_point)
|
||||
if combined_filter(computed):
|
||||
if sort_key is not None:
|
||||
# Add (main sorting index, current index, data_id)
|
||||
# So that we maintain current sorting and don't compare
|
||||
# data_id values ever.
|
||||
filtered_ids.append((computed[sort_key], i, data_id))
|
||||
else:
|
||||
filtered_ids.append(data_id)
|
||||
if sort_key is not None:
|
||||
filtered_sorted_ids = [
|
||||
tup[2] for tup in sorted(filtered_ids, reverse=reverse)
|
||||
]
|
||||
else:
|
||||
filtered_sorted_ids = filtered_ids
|
||||
return filtered_sorted_ids
|
||||
|
||||
@classmethod
|
||||
def from_json(cls,
|
||||
json_path,
|
||||
replacements={},
|
||||
dynamic_items=[],
|
||||
output_keys=[]):
|
||||
"""Load a data prep JSON file and create a Dataset based on it."""
|
||||
data = load_data_json(json_path, replacements)
|
||||
return cls(data, dynamic_items, output_keys)
|
||||
|
||||
@classmethod
|
||||
def from_csv(cls,
|
||||
csv_path,
|
||||
replacements={},
|
||||
dynamic_items=[],
|
||||
output_keys=[]):
|
||||
"""Load a data prep CSV file and create a Dataset based on it."""
|
||||
data = load_data_csv(csv_path, replacements)
|
||||
return cls(data, dynamic_items, output_keys)
|
||||
|
||||
@classmethod
|
||||
def from_arrow_dataset(cls,
|
||||
dataset,
|
||||
replacements={},
|
||||
dynamic_items=[],
|
||||
output_keys=[]):
|
||||
"""Loading a prepared huggingface dataset"""
|
||||
|
||||
# define an unbound method to generate puesdo keys
|
||||
def keys(self):
|
||||
"Returns the keys."
|
||||
return [i for i in range(dataset.__len__())]
|
||||
|
||||
# bind this method to arrow dataset
|
||||
dataset.keys = MethodType(keys, dataset)
|
||||
return cls(dataset, dynamic_items, output_keys)
|
||||
|
||||
|
||||
class FilteredSortedDynamicItemDataset(DynamicItemDataset):
|
||||
"""Possibly filtered, possibly sorted DynamicItemDataset.
|
||||
|
||||
Shares the static data (reference).
|
||||
Has its own dynamic_items and output_keys (deepcopy).
|
||||
"""
|
||||
|
||||
def __init__(self, from_dataset, data_ids):
|
||||
self.data = from_dataset.data
|
||||
self.data_ids = data_ids
|
||||
self.pipeline = copy.deepcopy(from_dataset.pipeline)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls,
|
||||
json_path,
|
||||
replacements={},
|
||||
dynamic_items=None,
|
||||
output_keys=None):
|
||||
raise TypeError("Cannot create SubsetDynamicItemDataset directly!")
|
||||
|
||||
@classmethod
|
||||
def from_csv(cls,
|
||||
csv_path,
|
||||
replacements={},
|
||||
dynamic_items=None,
|
||||
output_keys=None):
|
||||
raise TypeError("Cannot create SubsetDynamicItemDataset directly!")
|
||||
|
||||
|
||||
def add_dynamic_item(datasets, func, takes=None, provides=None):
|
||||
"""Helper for adding the same item to multiple datasets."""
|
||||
for dataset in datasets:
|
||||
dataset.add_dynamic_item(func, takes, provides)
|
||||
|
||||
|
||||
def set_output_keys(datasets, output_keys):
|
||||
"""Helper for setting the same item to multiple datasets."""
|
||||
for dataset in datasets:
|
||||
dataset.set_output_keys(output_keys)
|
@ -0,0 +1,237 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/utils/depgraph.py)
|
||||
"""A dependency graph for finding evaluation order.
|
||||
|
||||
Authors:
|
||||
* Aku Rouhe 2020
|
||||
"""
|
||||
import collections
|
||||
import uuid
|
||||
|
||||
|
||||
class CircularDependencyError(ValueError):
|
||||
"""
|
||||
An error caused by running into circular dependencies while searching for
|
||||
an evaluation order in a DependencyGraph.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
DGNode = collections.namedtuple("DGNode", ["key", "edges", "data"])
|
||||
|
||||
# A node in DependencyGraph.
|
||||
|
||||
|
||||
class DependencyGraph:
|
||||
"""General-purpose dependency graph.
|
||||
|
||||
Essentially a directed acyclic graph.
|
||||
Usually used to find an evaluation order for e.g. variable substitution
|
||||
The relation that an edge between A and B represents is:
|
||||
"A depends on B, i.e. B should be evaluated before A"
|
||||
|
||||
Nodes can be added explicitly or they can be created implicitly
|
||||
while adding edges.
|
||||
Nodes have keys, which should be some hashable value that identifies
|
||||
the elements the graph represents in your use case. E.G. they can just
|
||||
be the variable name you want to substitute.
|
||||
However, if needed, more generally you can attach any data to a node
|
||||
(e.g. a path in your tree), and if so desired, a unique key can be
|
||||
created for you. You'll only need to know that key while adding edges
|
||||
to/from it.
|
||||
Implicit keys and explicit keys can also be mixed.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.digraph = []
|
||||
self.key2ind = {}
|
||||
# Guard for manual duplicates (but not implicitly added ones)
|
||||
self._manually_added_keys = []
|
||||
|
||||
@staticmethod
|
||||
def get_unique_key():
|
||||
"""Returns a unique hashable identifier."""
|
||||
return uuid.uuid4()
|
||||
|
||||
def add_node(self, key=None, data=None):
|
||||
"""Adds a node explicitly.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
key : hashable, optional
|
||||
If not given, a key is created for you.
|
||||
data : Any, optional
|
||||
Any additional data you wish to attach to this node.
|
||||
|
||||
Returns
|
||||
-------
|
||||
hashable
|
||||
The key that was used (either yours or generated).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If node with the given key has already been added explicitly
|
||||
(with this method, not "add_edge").
|
||||
"""
|
||||
if key is None:
|
||||
key = self.get_unique_key()
|
||||
elif key in self._manually_added_keys:
|
||||
raise ValueError("Adding duplicate node: {key}".format(key=key))
|
||||
else:
|
||||
self._manually_added_keys.append(key)
|
||||
if key in self.key2ind: # Implicitly added already; don't add again.
|
||||
ind = self.key2ind[key]
|
||||
node = self.digraph[ind]
|
||||
# All that this operation can do is add data:
|
||||
self.digraph[ind] = DGNode(node.key, node.edges, data)
|
||||
return key
|
||||
self.key2ind[key] = len(self.digraph)
|
||||
self.digraph.append(DGNode(key, [], data))
|
||||
return key
|
||||
|
||||
def add_edge(self, from_key, to_key):
|
||||
"""Adds an edge, and implicitly also creates nodes for keys which have
|
||||
not been seen before. This will not let you add data to your nodes.
|
||||
The relation encodes: "from_key depends on to_key"
|
||||
(to_key must be evaluated before from_key).
|
||||
|
||||
Arguments
|
||||
---------
|
||||
from_key : hashable
|
||||
The key which depends on.
|
||||
to_key : hashable
|
||||
The key which is depended on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
from_ind = self._get_ind_and_add_if_new(from_key)
|
||||
to_ind = self._get_ind_and_add_if_new(to_key)
|
||||
edges_list = self.digraph[from_ind].edges
|
||||
if to_ind not in edges_list:
|
||||
edges_list.append(to_ind)
|
||||
|
||||
def _get_ind_and_add_if_new(self, key):
|
||||
# Used internally to implicitly add nodes for unseen keys
|
||||
if key not in self.key2ind:
|
||||
self.key2ind[key] = len(self.digraph)
|
||||
self.digraph.append(DGNode(key, [], None))
|
||||
return self.key2ind[key]
|
||||
|
||||
def is_valid(self):
|
||||
"""Checks if an evaluation order can be found.
|
||||
|
||||
A dependency graph is evaluatable if there are no circular
|
||||
dependencies, i.e., the graph is acyclic.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
Indicating if the graph is evaluatable.
|
||||
"""
|
||||
return not self._find_first_cycle()
|
||||
|
||||
def get_evaluation_order(self, selected_keys=None):
|
||||
"""Finds one valid evaluation order.
|
||||
|
||||
There can be many different valid
|
||||
orders.
|
||||
NOTE: Generates output one DGNode at a time. May generate DGNodes
|
||||
before it finds a circular dependency. If you really need to know
|
||||
whether an order can be found, check is_valid() first. However,
|
||||
the algorithm for finding cycles is essentially the same as the one
|
||||
used for finding an evaluation order, so for very large graphs...
|
||||
Ah well, but maybe then you should be using some other solution
|
||||
anyway.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
selected_keys : list, None
|
||||
List of keys. If not None, only the selected keys are guaranteed
|
||||
in the evaluation order (along with the keys they depend on).
|
||||
|
||||
Yields
|
||||
------
|
||||
DGNode
|
||||
The added DGNodes in a valid evaluation order.
|
||||
See the DGNode namedtuple above.
|
||||
|
||||
Raises
|
||||
------
|
||||
CircularDependencyError
|
||||
If a circular dependency is found.
|
||||
"""
|
||||
seen_ever = set()
|
||||
|
||||
def toposort(root_ind, visited):
|
||||
"""Implementation of topsort."""
|
||||
nonlocal seen_ever
|
||||
here = visited + [root_ind]
|
||||
if root_ind in visited:
|
||||
raise CircularDependencyError("{cycle}".format(
|
||||
cycle=" -> ".join(str(self.digraph[i].key) for i in here)))
|
||||
if root_ind in seen_ever:
|
||||
return # Yield nothing
|
||||
seen_ever = seen_ever.union(set([root_ind]))
|
||||
for to_ind in self.digraph[root_ind].edges:
|
||||
for ind in toposort(to_ind, visited=here):
|
||||
yield ind
|
||||
yield root_ind
|
||||
|
||||
if selected_keys is None:
|
||||
start_inds = range(len(self.digraph))
|
||||
else:
|
||||
start_inds = [self.key2ind[key] for key in selected_keys]
|
||||
|
||||
for start_ind in start_inds:
|
||||
for ind in toposort(start_ind, []):
|
||||
yield self.digraph[ind]
|
||||
|
||||
def _find_first_cycle(self):
|
||||
"""Depth-first search based algorithm for finding cycles in the graph."""
|
||||
seen_ever = set()
|
||||
|
||||
def cycle_dfs(root_ind, visited):
|
||||
"""Implementation of cycle_dfs."""
|
||||
nonlocal seen_ever
|
||||
print(root_ind, visited)
|
||||
here = visited + [root_ind]
|
||||
if root_ind in visited:
|
||||
return here
|
||||
if root_ind in seen_ever:
|
||||
return []
|
||||
seen_ever = seen_ever.union(set([root_ind]))
|
||||
for to_ind in self.digraph[root_ind].edges:
|
||||
cycle = cycle_dfs(to_ind, here)
|
||||
if cycle:
|
||||
return cycle
|
||||
return []
|
||||
|
||||
for ind in range(len(self.digraph)):
|
||||
if ind not in seen_ever:
|
||||
cycle = cycle_dfs(ind, [])
|
||||
if cycle:
|
||||
return cycle
|
||||
return []
|
||||
|
||||
def __contains__(self, key):
|
||||
# Allows the syntax:
|
||||
# 'key' in dependency_graph
|
||||
return key in self.key2ind
|
@ -0,0 +1,118 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/core.py)
|
||||
import paddlespeech.s2t.io.speechbrain.dataloader
|
||||
|
||||
|
||||
def _train_loader_specifics(self, dataset, loader_kwargs):
|
||||
sampler = loader_kwargs.get("sampler", None)
|
||||
# Shuffling should really only matter for the train stage. Shuffling
|
||||
# will also lead to more padding in batches if the order was otherwise
|
||||
# sorted by length.
|
||||
shuffle = loader_kwargs.get("shuffle", False)
|
||||
if shuffle and not self.distributed_launch:
|
||||
if sampler is not None:
|
||||
raise ValueError("Cannot specify both shuffle=True"
|
||||
"and a sampler in loader_kwargs")
|
||||
sampler = ReproducibleRandomSampler(dataset)
|
||||
self.train_sampler = sampler
|
||||
loader_kwargs["sampler"] = self.train_sampler
|
||||
# Delete the shuffle flag, since you cannot specify both a sampler and
|
||||
# shuffling:
|
||||
del loader_kwargs["shuffle"]
|
||||
|
||||
# Possibly make a DistributedSampler or a wrapper for some other sampler
|
||||
if self.distributed_launch and not isinstance(dataset, IterableDataset):
|
||||
drop_last = loader_kwargs.get("drop_last", False)
|
||||
# num_replicas arg is equal to world_size
|
||||
# and retrieved automatically within
|
||||
# DistributedSampler obj.
|
||||
if sampler is not None:
|
||||
self.train_sampler = DistributedSamplerWrapper(
|
||||
sampler,
|
||||
rank=self.rank,
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle, )
|
||||
|
||||
# with DistributedSamplerWrapper, one must disable shuffling for dataloader
|
||||
loader_kwargs["shuffle"] = False
|
||||
loader_kwargs["sampler"] = self.train_sampler
|
||||
elif loader_kwargs.get("batch_sampler") is None:
|
||||
# no sampler and batch-sampler
|
||||
self.train_sampler = DistributedSampler(
|
||||
dataset, rank=self.rank, shuffle=True, drop_last=drop_last)
|
||||
|
||||
# with DistributedSamplerWrapper, one must disable shuffling for dataloader
|
||||
loader_kwargs["shuffle"] = False
|
||||
loader_kwargs["sampler"] = self.train_sampler
|
||||
else: # batch_sampler was specified
|
||||
self.train_sampler = DistributedSamplerWrapper(
|
||||
loader_kwargs.get("batch_sampler", None),
|
||||
rank=self.rank,
|
||||
shuffle=True, )
|
||||
loader_kwargs["batch_sampler"] = self.train_sampler
|
||||
elif self.distributed_launch and isinstance(dataset, IterableDataset):
|
||||
logger.warning("Cannot automatically solve distributed sampling "
|
||||
"for IterableDataset.")
|
||||
return loader_kwargs
|
||||
|
||||
|
||||
def make_dataloader(self, dataset, stage, **loader_kwargs):
|
||||
"""Creates DataLoaders for Datasets.
|
||||
|
||||
This is used by ``fit()`` and ``evaluate()`` if they just receive
|
||||
Datasets.
|
||||
|
||||
Alternatively, this can be called from outside the Brain subclass.
|
||||
In that case, the DataLoader should be passed to ``fit()`` in place
|
||||
of the dataset.
|
||||
|
||||
The Stage.TRAIN DataLoader is handled specially. It has extra args for
|
||||
shuffle and drop_last. In DDP a DistributedSampler is created (unless
|
||||
the dataset is an IterableDataset).
|
||||
|
||||
NOTE
|
||||
----
|
||||
Some important DataLoader arguments are passed via **loader_kwargs,
|
||||
e.g., batch_size, num_workers, pin_memory.
|
||||
|
||||
NOTE
|
||||
----
|
||||
By default, ``evaluate()`` specifies ckpt_prefix=None to stop the test
|
||||
DataLoader being added to the checkpointer. If you need to add a
|
||||
recoverable after saving checkpoints (e.g., at test time, after
|
||||
checkpointing the training), and still be able to recover reasonably,
|
||||
you should probably specify ``allow_partial_load=True``.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
dataset : Dataset
|
||||
A set of data to use to create data loader. If the Dataset is a
|
||||
DynamicItemDataset, PaddedBatch is used as the default collate_fn,
|
||||
unless specified in loader_kwargs.
|
||||
stage : Stage
|
||||
The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
|
||||
ckpt_prefix : str, None
|
||||
Prefix to use for SaveableDataLoader Checkpoint name. The Stage
|
||||
name is added to this to create the full key. Set to None to not
|
||||
save the DataLoader.
|
||||
**loader_kwargs : dict
|
||||
Additional keyword arguments to the DataLoader.
|
||||
E.g., batch_size, num_workers, pin_memory.
|
||||
"""
|
||||
|
||||
dataloader_ = dataloader.make_dataloader(dataset, **loader_kwargs)
|
||||
return dataloader_
|
@ -0,0 +1,503 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/sampler.py)
|
||||
"""compatible samplers.
|
||||
|
||||
These determine the order of iteration through a dataset.
|
||||
|
||||
Authors:
|
||||
* Aku Rouhe 2020
|
||||
* Samuele Cornell 2020
|
||||
* Ralf Leibold 2020
|
||||
* Artem Ploujnikov 2021
|
||||
* Andreas Nautsch 2021
|
||||
"""
|
||||
import logging
|
||||
from collections import Counter
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.io import RandomSampler
|
||||
from paddle.io import Sampler
|
||||
from paddle.io import WeightedRandomSampler
|
||||
from scipy.stats import lognorm
|
||||
|
||||
from paddlespeech.s2t.io.speechbrain.dataset import DynamicItemDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReproducibleRandomSampler(RandomSampler):
|
||||
"""A modification of RandomSampler which always returns the same values.
|
||||
|
||||
Also look at `paddle.io.RandomSampler`. This has mostly
|
||||
the same behaviour and arguments, except for adding 'seed' and 'epoch' and
|
||||
not supporting 'generator'.
|
||||
|
||||
Note
|
||||
----
|
||||
Call `set_epoch` before every epoch. Otherwise, the sampler will produce the
|
||||
same sequence of indices every epoch.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data_source : Dataset
|
||||
The data source to sample indices for.
|
||||
seed : int
|
||||
The base seed to use for the random number generator. It is recommended
|
||||
to use a value which has a good mix of 0 and 1 bits.
|
||||
epoch : int
|
||||
The epoch to start at.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data_source, seed=563375142, epoch=0, **kwargs):
|
||||
if "generator" in kwargs:
|
||||
MSG = ("Cannot give a separate generator when using " +
|
||||
"ReproducibleRandomSampler")
|
||||
raise ValueError(MSG)
|
||||
super().__init__(data_source, **kwargs)
|
||||
self.seed = int(seed)
|
||||
self.epoch = epoch
|
||||
self.gen = paddle.seed(1)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
"""
|
||||
You can also just access self.epoch, but we maintain this interface
|
||||
to mirror paddle.io.DistributedBatchSampler
|
||||
"""
|
||||
self.epoch = epoch
|
||||
|
||||
def __iter__(self):
|
||||
self.gen.manual_seed(self.seed + self.epoch)
|
||||
return super().__iter__()
|
||||
|
||||
|
||||
class ReproducibleWeightedRandomSampler(WeightedRandomSampler):
|
||||
"""A reproducible modification of WeightedRandomSampler.
|
||||
|
||||
Also look at `paddle.io.WeightedRandomSampler`. This has the
|
||||
the same behaviour and arguments, except for adding 'seed' and 'epoch' and
|
||||
not supporting 'generator'.
|
||||
|
||||
Note
|
||||
----
|
||||
Call `set_epoch` before every epoch. Otherwise, the sampler will produce the
|
||||
same sequence of indices every epoch.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
weights : sequence of float
|
||||
Weights for each index. Doesn't need to sum to one.
|
||||
num_samples : int
|
||||
Number of samples to draw
|
||||
replacement : bool
|
||||
To draw with replacement or not (within an epoch of num_samples).
|
||||
seed : int
|
||||
The base seed to use for the random number generator. It is recommended
|
||||
to use a value which has a good mix of 0 and 1 bits.
|
||||
epoch : int
|
||||
The epoch to start at.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights,
|
||||
num_samples,
|
||||
replacement,
|
||||
seed=129491412,
|
||||
epoch=0,
|
||||
**kwargs, ):
|
||||
if "generator" in kwargs:
|
||||
MSG = ("Cannot give a separate generator when using " +
|
||||
"ReproducibleRandomSampler")
|
||||
raise ValueError(MSG)
|
||||
super().__init__(weights, num_samples, replacement, **kwargs)
|
||||
self.seed = int(seed)
|
||||
self.epoch = epoch
|
||||
self.gen = paddle.seed(1)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
"""
|
||||
You can also just access self.epoch, but we maintain this interface
|
||||
to mirror paddle.io.DistributedBatchSampler
|
||||
"""
|
||||
self.epoch = epoch
|
||||
|
||||
def __iter__(self):
|
||||
self.gen.manual_seed(self.seed + self.epoch)
|
||||
return super().__iter__()
|
||||
|
||||
|
||||
class DynamicBatchSampler(Sampler):
|
||||
"""This BatchSampler batches examples together by grouping them by their length.
|
||||
|
||||
Every example in the batch have approximately the same length and
|
||||
thus padding is minimized.
|
||||
This enables faster training on datasets
|
||||
where length of examples can vary significantly (e.g Librispeech).
|
||||
Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length
|
||||
|
||||
Dynamic batching is performed by specifying a max_batch_length which is the
|
||||
upper limit for the sum of the length of examples in a batch:
|
||||
e.g., if ex1 has length 4, ex2 length 5 and if max_batch_length is set to 6
|
||||
ex1 and ex2 will be placed, alone, in two distinct batches.
|
||||
|
||||
Length for each example can be obtained in two manners.
|
||||
If the input dataset is a DynamicItemDataset it can be obtained by specifying a
|
||||
length_func. Default assumes a "duration" entry is in the annotation.
|
||||
Length for each example can also be passed to this class upon instantiation
|
||||
by specifying a list containing the length for each example and passing it to
|
||||
lengths_list.
|
||||
|
||||
Examples are grouped together by defining a set of possible discrete intervals
|
||||
(buckets). Examples whose length fall into these intervals can be batched together.
|
||||
|
||||
The number of buckets can be specified by using the arg num_buckets.
|
||||
There is usually an optimal range for the value of this argument.
|
||||
|
||||
If num_buckets == 1, all examples can be batched together. You have maximum randomization
|
||||
but your training speed will be slower due to the fact that a large amount of the values will be padding
|
||||
as long and short examples can be batched together.
|
||||
As the number of buckets grows only examples with similar
|
||||
length can be grouped together.
|
||||
This trades-off speed with randomization.
|
||||
TLDR: Low number -> better randomization, High number -> faster training.
|
||||
NOTE THAT: if set too high the training speed will decrease. If num_buckets -> number of examples in the
|
||||
dataset the batch size will be small impacting training speed and possibly performance.
|
||||
|
||||
The buckets can also be specified by passing a list to the bucket_boundaries
|
||||
argument instead of specifying a left_bucket_length and a bucket_length_multiplier.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
max_batch_length: int,
|
||||
num_buckets: int=None,
|
||||
length_func=lambda x: x["duration"],
|
||||
shuffle: bool=True,
|
||||
batch_ordering: str="random",
|
||||
max_batch_ex: int=None,
|
||||
bucket_boundaries: List[int]=[],
|
||||
lengths_list: List[int]=None,
|
||||
seed: int=42,
|
||||
epoch: int=0,
|
||||
drop_last: bool=False,
|
||||
verbose: bool=False, ):
|
||||
self._dataset = dataset
|
||||
self._ex_lengths = {}
|
||||
ex_ids = self._dataset.data_ids
|
||||
self.verbose = verbose
|
||||
|
||||
# We do not put a default on num_buckets to encourage users to play with this parameter
|
||||
if num_buckets is None and len(bucket_boundaries) == 0:
|
||||
raise RuntimeError(
|
||||
"Please specify either num_buckets or bucket boundaries."
|
||||
"Check the docs, and/or the tutorial !")
|
||||
|
||||
if lengths_list is not None:
|
||||
# take length of examples from this argument and bypass length_key
|
||||
for indx in range(len(lengths_list)):
|
||||
self._ex_lengths[str(indx)] = lengths_list[indx]
|
||||
else:
|
||||
# use length func
|
||||
if not isinstance(dataset, DynamicItemDataset):
|
||||
raise NotImplementedError(
|
||||
"Dataset should be a DynamicItemDataset when using length function"
|
||||
)
|
||||
for indx in range(len(self._dataset)):
|
||||
self._ex_lengths[str(indx)] = length_func(
|
||||
self._dataset.data[ex_ids[indx]])
|
||||
|
||||
if len(bucket_boundaries) > 0:
|
||||
if not all([x >= 0 for x in bucket_boundaries]):
|
||||
raise ValueError(
|
||||
"All elements in bucket boundaries should be non-negative (>= 0)."
|
||||
)
|
||||
if not len(set(bucket_boundaries)) == len(bucket_boundaries):
|
||||
raise ValueError(
|
||||
"Bucket_boundaries should not contain duplicates.")
|
||||
np.testing.assert_array_equal(
|
||||
np.array(bucket_boundaries),
|
||||
np.array(sorted(bucket_boundaries)),
|
||||
err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values values!",
|
||||
)
|
||||
self._bucket_boundaries = np.array(sorted(bucket_boundaries))
|
||||
else:
|
||||
# use num_buckets
|
||||
self._bucket_boundaries = np.array(
|
||||
self._get_boundaries_through_warping(
|
||||
max_batch_length=max_batch_length,
|
||||
num_quantiles=num_buckets, ))
|
||||
|
||||
self._max_batch_length = max_batch_length
|
||||
self._shuffle_ex = shuffle
|
||||
self._batch_ordering = batch_ordering
|
||||
self._seed = seed
|
||||
self._drop_last = drop_last
|
||||
if max_batch_ex is None:
|
||||
max_batch_ex = np.inf
|
||||
self._max_batch_ex = max_batch_ex
|
||||
# Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length?
|
||||
self._bucket_lens = [
|
||||
max(1, int(max_batch_length / self._bucket_boundaries[i]))
|
||||
for i in range(len(self._bucket_boundaries))
|
||||
] + [1]
|
||||
self._epoch = epoch
|
||||
self._generate_batches()
|
||||
|
||||
def get_durations(self, batch):
|
||||
"""Gets durations of the elements in the batch."""
|
||||
return [self._ex_lengths[str(idx)] for idx in batch]
|
||||
|
||||
def _get_boundaries_through_warping(
|
||||
self,
|
||||
max_batch_length: int,
|
||||
num_quantiles: int, ) -> List[int]:
|
||||
|
||||
# NOTE: the following lines do not cover that there is only one example in the dataset
|
||||
# warp frames (duration) distribution of train data
|
||||
logger.info("Batch quantisation in latent space")
|
||||
# linspace set-up
|
||||
num_boundaries = num_quantiles + 1
|
||||
# create latent linearly equal spaced buckets
|
||||
latent_boundaries = np.linspace(
|
||||
1 / num_boundaries,
|
||||
num_quantiles / num_boundaries,
|
||||
num_quantiles, )
|
||||
# get quantiles using lognormal distribution
|
||||
quantiles = lognorm.ppf(latent_boundaries, 1)
|
||||
# scale up to to max_batch_length
|
||||
bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
|
||||
# compute resulting bucket length multipliers
|
||||
length_multipliers = [
|
||||
bucket_boundaries[x + 1] / bucket_boundaries[x]
|
||||
for x in range(num_quantiles - 1)
|
||||
]
|
||||
# logging
|
||||
logger.info(
|
||||
"Latent bucket boundary - buckets: {} - length multipliers: {}".
|
||||
format(
|
||||
list(map("{:.2f}".format, bucket_boundaries)),
|
||||
list(map("{:.2f}".format, length_multipliers)), ))
|
||||
return list(sorted(bucket_boundaries))
|
||||
|
||||
def _permute_batches(self):
|
||||
|
||||
if self._batch_ordering == "random":
|
||||
# deterministically shuffle based on epoch and seed
|
||||
gen = paddle.seed(1)
|
||||
gen.manual_seed(self._seed + self._epoch)
|
||||
sampler = paddle.randperm(
|
||||
len(self._batches)).tolist() # type: ignore
|
||||
tmp = []
|
||||
for idx in sampler:
|
||||
tmp.append(self._batches[idx])
|
||||
self._batches = tmp
|
||||
|
||||
elif self._batch_ordering == "ascending":
|
||||
self._batches = sorted(
|
||||
self._batches,
|
||||
key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), )
|
||||
elif self._batch_ordering == "descending":
|
||||
self._batches = sorted(
|
||||
self._batches,
|
||||
key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
|
||||
reverse=True, )
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _generate_batches(self):
|
||||
logger.info("DynamicBatchSampler: Generating dynamic batches")
|
||||
if self._shuffle_ex:
|
||||
# deterministically shuffle based on epoch and seed
|
||||
gen = paddle.seed(1)
|
||||
gen.manual_seed(self._seed + self._epoch)
|
||||
sampler = paddle.randperm(
|
||||
len(self._dataset)).tolist() # type: ignore
|
||||
else:
|
||||
# take examples as they are: e.g. they have been sorted
|
||||
sampler = range(len(self._dataset)) # type: ignore
|
||||
|
||||
self._batches = []
|
||||
bucket_batches = [[] for i in self._bucket_lens]
|
||||
|
||||
stats_tracker = [{
|
||||
"min": np.inf,
|
||||
"max": -np.inf,
|
||||
"tot": 0,
|
||||
"n_ex": 0
|
||||
} for i in self._bucket_lens]
|
||||
|
||||
for idx in sampler:
|
||||
# length of pre-sampled audio
|
||||
item_len = self._ex_lengths[str(idx)]
|
||||
# bucket to fill up most padding
|
||||
bucket_id = np.searchsorted(self._bucket_boundaries, item_len)
|
||||
# fill audio's duration into that bucket
|
||||
bucket_batches[bucket_id].append(idx)
|
||||
|
||||
stats_tracker[bucket_id]["min"] = min(
|
||||
stats_tracker[bucket_id]["min"], item_len)
|
||||
stats_tracker[bucket_id]["max"] = max(
|
||||
stats_tracker[bucket_id]["max"], item_len)
|
||||
stats_tracker[bucket_id]["tot"] += item_len
|
||||
stats_tracker[bucket_id]["n_ex"] += 1
|
||||
# track #samples - why not duration/#frames; rounded up?
|
||||
# keep track of durations, if necessary
|
||||
|
||||
if (len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
|
||||
or len(bucket_batches[bucket_id]) >= self._max_batch_ex):
|
||||
self._batches.append(bucket_batches[bucket_id])
|
||||
bucket_batches[bucket_id] = []
|
||||
# keep track of durations
|
||||
|
||||
# Dump remaining batches
|
||||
if not self._drop_last:
|
||||
for batch in bucket_batches:
|
||||
if batch:
|
||||
self._batches.append(batch)
|
||||
|
||||
self._permute_batches() # possibly reorder batches
|
||||
|
||||
if self._epoch == 0: # only log at first epoch
|
||||
# frames per batch & their padding remaining
|
||||
boundaries = [0] + self._bucket_boundaries.tolist()
|
||||
|
||||
for bucket_indx in range(len(self._bucket_boundaries)):
|
||||
try:
|
||||
num_batches = stats_tracker[bucket_indx]["tot"] // (
|
||||
self._max_batch_length)
|
||||
pad_factor = (stats_tracker[bucket_indx]["max"] -
|
||||
stats_tracker[bucket_indx]["min"]) / (
|
||||
stats_tracker[bucket_indx]["tot"] /
|
||||
stats_tracker[bucket_indx]["n_ex"])
|
||||
except ZeroDivisionError:
|
||||
num_batches = 0
|
||||
pad_factor = 0
|
||||
|
||||
logger.info((
|
||||
"DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and "
|
||||
+
|
||||
"batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}."
|
||||
).format(
|
||||
bucket_indx,
|
||||
boundaries[bucket_indx],
|
||||
boundaries[bucket_indx + 1],
|
||||
self._bucket_lens[bucket_indx],
|
||||
stats_tracker[bucket_indx]["n_ex"],
|
||||
num_batches,
|
||||
pad_factor * 100, ))
|
||||
|
||||
if self.verbose:
|
||||
batch_stats = {
|
||||
"tot_frames": [],
|
||||
"tot_pad_frames": [],
|
||||
"pad_%": [],
|
||||
}
|
||||
for batch in self._batches:
|
||||
tot_frames = sum(
|
||||
[self._ex_lengths[str(idx)] for idx in batch])
|
||||
batch_stats["tot_frames"].append(tot_frames)
|
||||
max_frames = max(
|
||||
[self._ex_lengths[str(idx)] for idx in batch])
|
||||
tot_pad = sum([
|
||||
max_frames - self._ex_lengths[str(idx)] for idx in batch
|
||||
])
|
||||
batch_stats["tot_pad_frames"].append(tot_pad)
|
||||
batch_stats["pad_%"].append(tot_pad / tot_frames * 100)
|
||||
|
||||
padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total."
|
||||
padding_details = "DynamicBatchSampler: " + padding_details
|
||||
for i in range(len(self._batches)):
|
||||
logger.info(
|
||||
padding_details.format(
|
||||
i,
|
||||
batch_stats["tot_frames"][i],
|
||||
len(self._batches[i]),
|
||||
batch_stats["tot_pad_frames"][i],
|
||||
batch_stats["pad_%"][i], ))
|
||||
|
||||
def __iter__(self):
|
||||
for batch in self._batches:
|
||||
yield batch
|
||||
if self._shuffle_ex: # re-generate examples if ex_ordering == "random"
|
||||
self._generate_batches()
|
||||
if self._batch_ordering == "random":
|
||||
# we randomly permute the batches only --> faster
|
||||
self._permute_batches()
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
"""
|
||||
You can also just access self.epoch, but we maintain this interface
|
||||
to mirror paddle.io.DistributedBatchSampler
|
||||
"""
|
||||
self._epoch = epoch
|
||||
self._generate_batches()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._batches)
|
||||
|
||||
|
||||
class BalancingDataSampler(ReproducibleWeightedRandomSampler):
|
||||
"""A data sampler that takes a single key from the dataset and
|
||||
ensures an approximately equal distribution by that key
|
||||
|
||||
Arguments
|
||||
---------
|
||||
dataset: DynamicItemDataset
|
||||
the dataset form which samples will be drawn
|
||||
key: str
|
||||
the key from which samples will be taken
|
||||
num_samples : int
|
||||
Number of samples to draw
|
||||
replacement : bool
|
||||
To draw with replacement or not (within an epoch of num_samples).
|
||||
seed : int
|
||||
The base seed to use for the random number generator. It is recommended
|
||||
to use a value which has a good mix of 0 and 1 bits.
|
||||
epoch : int
|
||||
The epoch to start at.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
key,
|
||||
num_samples=None,
|
||||
replacement=True,
|
||||
seed=563375142,
|
||||
epoch=0,
|
||||
**kwargs, ):
|
||||
self.dataset = dataset
|
||||
self.key = key
|
||||
if not num_samples:
|
||||
num_samples = len(dataset)
|
||||
weights = self._compute_weights()
|
||||
super().__init__(weights, num_samples, replacement, seed, epoch,
|
||||
**kwargs)
|
||||
|
||||
def _compute_weights(self):
|
||||
with self.dataset.output_keys_as([self.key]):
|
||||
class_ids = [item[self.key] for item in self.dataset]
|
||||
class_counter = Counter(class_ids)
|
||||
weights = 1 / paddle.to_tensor(
|
||||
[class_counter[class_id] for class_id in class_ids])
|
||||
return weights
|
@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2023 speechbrain Authors. All Rights Reserved.
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py)
|
||||
import data_pipeline
|
||||
import dataio
|
||||
import numpy
|
||||
import paddle
|
||||
import tqdm
|
||||
import transformers
|
||||
from dataloader import make_dataloader
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
|
||||
import dataset
|
||||
|
||||
|
||||
def dataio_prepare(hparams):
|
||||
"""This function prepares the datasets to be used in the brain class.
|
||||
It also defines the data processing pipeline through user-defined functions."""
|
||||
data_folder = hparams["data_folder"]
|
||||
|
||||
train_data = dataset.DynamicItemDataset.from_csv(
|
||||
csv_path=hparams["train_data"],
|
||||
replacements={"data_root": data_folder}, )
|
||||
|
||||
if hparams["sorting"] == "ascending":
|
||||
# we sort training data to speed up training and get better results.
|
||||
train_data = train_data.filtered_sorted(sort_key="duration")
|
||||
# when sorting do not shuffle in dataloader ! otherwise is pointless
|
||||
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||
|
||||
elif hparams["sorting"] == "descending":
|
||||
train_data = train_data.filtered_sorted(
|
||||
sort_key="duration", reverse=True)
|
||||
# when sorting do not shuffle in dataloader ! otherwise is pointless
|
||||
hparams["train_dataloader_opts"]["shuffle"] = False
|
||||
|
||||
elif hparams["sorting"] == "random":
|
||||
pass
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"sorting must be random, ascending or descending")
|
||||
|
||||
valid_data = dataset.DynamicItemDataset.from_csv(
|
||||
csv_path=hparams["valid_data"],
|
||||
replacements={"data_root": data_folder}, )
|
||||
valid_data = valid_data.filtered_sorted(sort_key="duration")
|
||||
|
||||
test_data = dataset.DynamicItemDataset.from_csv(
|
||||
csv_path=hparams["test_data"],
|
||||
replacements={"data_root": data_folder}, )
|
||||
test_data = test_data.filtered_sorted(sort_key="duration")
|
||||
|
||||
datasets = [train_data, valid_data, test_data]
|
||||
|
||||
# Defining tokenizer and loading it
|
||||
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-chinese')
|
||||
|
||||
# 2. Define audio pipeline:
|
||||
@data_pipeline.takes("wav")
|
||||
@data_pipeline.provides("sig")
|
||||
def audio_pipeline(wav):
|
||||
sig = dataio.read_audio(wav)
|
||||
return sig
|
||||
|
||||
dataset.add_dynamic_item(datasets, audio_pipeline)
|
||||
|
||||
# 3. Define text pipeline:
|
||||
@data_pipeline.takes("transcript")
|
||||
@data_pipeline.provides("wrd", "tokens_list", "tokens")
|
||||
def text_pipeline(wrd):
|
||||
wrd = "".join(wrd.split(" "))
|
||||
yield wrd
|
||||
tokens_list = tokenizer(wrd)["input_ids"]
|
||||
yield tokens_list
|
||||
tokens = numpy.array(tokens_list, dtype="int64")
|
||||
yield tokens
|
||||
|
||||
dataset.add_dynamic_item(datasets, text_pipeline)
|
||||
|
||||
# 4. Set output:
|
||||
dataset.set_output_keys(
|
||||
datasets,
|
||||
["id", "sig", "wrd", "tokens"], )
|
||||
|
||||
# 5. If Dynamic Batching is used, we instantiate the needed samplers.
|
||||
train_batch_sampler = None
|
||||
valid_batch_sampler = None
|
||||
if hparams["dynamic_batching"]:
|
||||
from sampler import DynamicBatchSampler # noqa
|
||||
|
||||
dynamic_hparams = hparams["dynamic_batch_sampler"]
|
||||
num_buckets = dynamic_hparams["num_buckets"]
|
||||
|
||||
train_batch_sampler = DynamicBatchSampler(
|
||||
train_data,
|
||||
dynamic_hparams["max_batch_len"],
|
||||
num_buckets=num_buckets,
|
||||
length_func=lambda x: x["duration"],
|
||||
shuffle=dynamic_hparams["shuffle_ex"],
|
||||
batch_ordering=dynamic_hparams["batch_ordering"], )
|
||||
|
||||
valid_batch_sampler = DynamicBatchSampler(
|
||||
valid_data,
|
||||
dynamic_hparams["max_batch_len"],
|
||||
num_buckets=num_buckets,
|
||||
length_func=lambda x: x["duration"],
|
||||
shuffle=dynamic_hparams["shuffle_ex"],
|
||||
batch_ordering=dynamic_hparams["batch_ordering"], )
|
||||
|
||||
return (train_data, valid_data, test_data, tokenizer, train_batch_sampler,
|
||||
valid_batch_sampler, )
|
||||
|
||||
|
||||
hparams_file = 'train_with_wav2vec.yaml'
|
||||
with open(hparams_file) as fin:
|
||||
hparams = load_hyperpyyaml(fin, None)
|
||||
|
||||
(train_data, valid_data, test_data, tokenizer, train_bsampler,
|
||||
valid_bsampler, ) = dataio_prepare(hparams)
|
||||
|
||||
train_dataloader_opts = hparams["train_dataloader_opts"]
|
||||
valid_dataloader_opts = hparams["valid_dataloader_opts"]
|
||||
|
||||
if train_bsampler is not None:
|
||||
train_dataloader_opts = {
|
||||
"batch_sampler": train_bsampler,
|
||||
"num_workers": hparams["num_workers"],
|
||||
}
|
||||
|
||||
if valid_bsampler is not None:
|
||||
valid_dataloader_opts = {"batch_sampler": valid_bsampler}
|
||||
|
||||
train_set = make_dataloader(train_data, stage='train', **train_dataloader_opts)
|
||||
|
||||
valid_set = make_dataloader(
|
||||
valid_data,
|
||||
stage='train',
|
||||
**valid_dataloader_opts, )
|
||||
|
||||
for batch in valid_set:
|
||||
print(batch)
|
||||
print('done') # exit()
|
Loading…
Reference in new issue