From e5b99965b84acc7bcccc6db5ac9ce4b2e2a47ada Mon Sep 17 00:00:00 2001 From: zhangtianhao Date: Wed, 23 Nov 2022 07:18:26 +0000 Subject: [PATCH] sb_pipeline --- examples/aishell/asr2/conf/tuning/decode.yaml | 11 - examples/aishell/{asr2 => asr3}/cmd.sh | 0 .../{asr2 => asr3}/conf/conformer.yaml | 0 .../{asr2 => asr3}/conf/preprocess.yaml | 1 - .../conf/tuning/chunk_decode.yaml | 0 examples/aishell/asr3/conf/tuning/decode.yaml | 4 + .../{asr2 => asr3}/conf/wav2vec2ASR.yaml | 67 +- examples/aishell/{asr2 => asr3}/local/data.sh | 0 examples/aishell/{asr2 => asr3}/local/test.sh | 8 +- .../aishell/{asr2 => asr3}/local/test_wav.sh | 0 .../aishell/{asr2 => asr3}/local/train.sh | 13 +- examples/aishell/{asr2 => asr3}/path.sh | 0 examples/aishell/{asr2 => asr3}/run.sh | 13 +- examples/aishell/{asr2 => asr3}/utils | 0 paddlespeech/s2t/exps/wav2vec2/model.py | 433 ++++--- paddlespeech/s2t/io/wav2vec2/batch.py | 184 +++ paddlespeech/s2t/io/wav2vec2/data_pipeline.py | 518 +++++++++ paddlespeech/s2t/io/wav2vec2/data_utils.py | 167 +++ paddlespeech/s2t/io/wav2vec2/dataio.py | 1024 +++++++++++++++++ paddlespeech/s2t/io/wav2vec2/dataloader.py | 215 ++++ paddlespeech/s2t/io/wav2vec2/dataset.py | 409 +++++++ paddlespeech/s2t/io/wav2vec2/depgraph.py | 276 +++++ .../s2t/io/wav2vec2/make_dataloader.py | 115 ++ paddlespeech/s2t/io/wav2vec2/sampler.py | 695 +++++++++++ paddlespeech/s2t/io/wav2vec2/sb_pipeline.py | 162 +++ .../s2t/io/wav2vec2/train_with_wav2vec.yaml | 86 ++ .../s2t/models/wav2vec2/wav2vec2_ASR.py | 86 +- 27 files changed, 4263 insertions(+), 224 deletions(-) delete mode 100755 examples/aishell/asr2/conf/tuning/decode.yaml rename examples/aishell/{asr2 => asr3}/cmd.sh (100%) rename examples/aishell/{asr2 => asr3}/conf/conformer.yaml (100%) rename examples/aishell/{asr2 => asr3}/conf/preprocess.yaml (76%) rename examples/aishell/{asr2 => asr3}/conf/tuning/chunk_decode.yaml (100%) create mode 100755 examples/aishell/asr3/conf/tuning/decode.yaml rename examples/aishell/{asr2 => asr3}/conf/wav2vec2ASR.yaml (75%) rename examples/aishell/{asr2 => asr3}/local/data.sh (100%) rename examples/aishell/{asr2 => asr3}/local/test.sh (90%) rename examples/aishell/{asr2 => asr3}/local/test_wav.sh (100%) rename examples/aishell/{asr2 => asr3}/local/train.sh (80%) rename examples/aishell/{asr2 => asr3}/path.sh (100%) rename examples/aishell/{asr2 => asr3}/run.sh (80%) rename examples/aishell/{asr2 => asr3}/utils (100%) create mode 100755 paddlespeech/s2t/io/wav2vec2/batch.py create mode 100755 paddlespeech/s2t/io/wav2vec2/data_pipeline.py create mode 100755 paddlespeech/s2t/io/wav2vec2/data_utils.py create mode 100755 paddlespeech/s2t/io/wav2vec2/dataio.py create mode 100755 paddlespeech/s2t/io/wav2vec2/dataloader.py create mode 100755 paddlespeech/s2t/io/wav2vec2/dataset.py create mode 100755 paddlespeech/s2t/io/wav2vec2/depgraph.py create mode 100755 paddlespeech/s2t/io/wav2vec2/make_dataloader.py create mode 100755 paddlespeech/s2t/io/wav2vec2/sampler.py create mode 100755 paddlespeech/s2t/io/wav2vec2/sb_pipeline.py create mode 100755 paddlespeech/s2t/io/wav2vec2/train_with_wav2vec.yaml diff --git a/examples/aishell/asr2/conf/tuning/decode.yaml b/examples/aishell/asr2/conf/tuning/decode.yaml deleted file mode 100755 index 72ede9272..000000000 --- a/examples/aishell/asr2/conf/tuning/decode.yaml +++ /dev/null @@ -1,11 +0,0 @@ -beam_size: 10 -decode_batch_size: 128 -error_rate_type: cer -decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' -ctc_weight: 0.5 # 
ctc weight for attention rescoring decode mode. -decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. -num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. -simulate_streaming: False # simulate streaming inference. Defaults to False. diff --git a/examples/aishell/asr2/cmd.sh b/examples/aishell/asr3/cmd.sh similarity index 100% rename from examples/aishell/asr2/cmd.sh rename to examples/aishell/asr3/cmd.sh diff --git a/examples/aishell/asr2/conf/conformer.yaml b/examples/aishell/asr3/conf/conformer.yaml similarity index 100% rename from examples/aishell/asr2/conf/conformer.yaml rename to examples/aishell/asr3/conf/conformer.yaml diff --git a/examples/aishell/asr2/conf/preprocess.yaml b/examples/aishell/asr3/conf/preprocess.yaml similarity index 76% rename from examples/aishell/asr2/conf/preprocess.yaml rename to examples/aishell/asr3/conf/preprocess.yaml index 4a908a83b..724782ed6 100755 --- a/examples/aishell/asr2/conf/preprocess.yaml +++ b/examples/aishell/asr3/conf/preprocess.yaml @@ -1,4 +1,3 @@ process: # use raw audio - type: wav_process - dither: 0.0 diff --git a/examples/aishell/asr2/conf/tuning/chunk_decode.yaml b/examples/aishell/asr3/conf/tuning/chunk_decode.yaml similarity index 100% rename from examples/aishell/asr2/conf/tuning/chunk_decode.yaml rename to examples/aishell/asr3/conf/tuning/chunk_decode.yaml diff --git a/examples/aishell/asr3/conf/tuning/decode.yaml b/examples/aishell/asr3/conf/tuning/decode.yaml new file mode 100755 index 000000000..69d0a4551 --- /dev/null +++ b/examples/aishell/asr3/conf/tuning/decode.yaml @@ -0,0 +1,4 @@ +decode_batch_size: 1 +error_rate_type: cer +decoding_method: ctc_greedy_search # 'ctc_greedy_search', 'ctc_prefix_beam_search' +beam_size: 10 diff --git a/examples/aishell/asr2/conf/wav2vec2ASR.yaml b/examples/aishell/asr3/conf/wav2vec2ASR.yaml similarity index 75% rename from examples/aishell/asr2/conf/wav2vec2ASR.yaml rename to examples/aishell/asr3/conf/wav2vec2ASR.yaml index 270b91b42..b5cadffd3 100755 --- a/examples/aishell/asr2/conf/wav2vec2ASR.yaml +++ b/examples/aishell/asr3/conf/wav2vec2ASR.yaml @@ -4,16 +4,39 @@ freeze_wav2vec2: False normalize_wav: True output_norm: True -dnn_blocks: 2 -dnn_neurons: 1024 -blank_id: 0 -ctc_dropout_rate: 0.0 +init_type: 'kaiming_uniform' # !Warning: need to convergence +enc: + input_shape: 1024 + dnn_blocks: 3 + dnn_neurons: 1024 + activation: True + normalization: True + dropout_rate: [0.15, 0.15, 0.0] +ctc: + enc_n_units: 1024 + blank_id: 0 + dropout_rate: 0.0 + +audio_augment: + speeds: [90, 100, 110] + +spec_augment: + time_warp: True + time_warp_window: 5 + time_warp_mode: bicubic + freq_mask: True + n_freq_mask: 2 + time_mask: True + n_time_mask: 2 + replace_with_zero: False + freq_mask_width: 30 + time_mask_width: 40 wav2vec2_params_path: exp/wav2vec2/chinese-wav2vec2-large.pdparams ############################################ # Wav2Vec2.0 # ############################################ -vocab_size: 32 +# vocab_size: 1000000 hidden_size: 1024 num_hidden_layers: 24 num_attention_heads: 16 @@ -60,9 +83,9 @@ codevector_dim: 256 proj_codevector_dim: 256 diversity_loss_weight: 0.1 use_weighted_layer_sum: False -pad_token_id: 0 -bos_token_id: 1 -eos_token_id: 2 +# pad_token_id: 0 +# bos_token_id: 1 +# eos_token_id: 2 add_adapter: False adapter_kernel_size: 3 adapter_stride: 2 @@ -72,20 +95,25 @@ output_hidden_size: None 
########################################### # Data # ########################################### +# train_manifest: data/manifest_bert_tokenizer.train +# dev_manifest: data/manifest_bert_tokenizer.dev +# test_manifest: data/manifest_bert_tokenizer.test +# vocab_filepath: vocab.txt + train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - +vocab_filepath: data/lang_char/vocab.txt ########################################### # Dataloader # ########################################### -vocab_filepath: data/lang_char/vocab.txt + unit_type: 'char' mean_std_filepath: preprocess_config: conf/preprocess.yaml -sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for 'other' epochs -batch_size: 8 # Different batch_size may cause large differences in results +sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 4 # Different batch_size may cause large differences in results maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced minibatches: 0 # for debug @@ -94,7 +122,7 @@ batch_bins: 0 batch_frames_in: 0 batch_frames_out: 0 batch_frames_inout: 0 -num_workers: 0 +num_workers: 4 subsampling_factor: 1 num_encs: 1 dist_sampler: True @@ -107,12 +135,18 @@ return_lens_rate: True ########################################### n_epoch: 80 accum_grad: 1 -global_grad_clip: 3.0 +global_grad_clip: 5.0 model_optim: adadelta model_optim_conf: - lr: 0.95 + lr: 1.0 epsilon: 1.0e-8 rho: 0.95 + weight_decay: 0.0 +# model_optim: adam +# model_optim_conf: +# lr: 0.01 +# # epsilon: 1.0e-8 +# # rho: 0.95 wav2vec2_optim: adam wav2vec2_optim_conf: lr: 0.0001 @@ -132,6 +166,3 @@ log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 -augment: True - - diff --git a/examples/aishell/asr2/local/data.sh b/examples/aishell/asr3/local/data.sh similarity index 100% rename from examples/aishell/asr2/local/data.sh rename to examples/aishell/asr3/local/data.sh diff --git a/examples/aishell/asr2/local/test.sh b/examples/aishell/asr3/local/test.sh similarity index 90% rename from examples/aishell/asr2/local/test.sh rename to examples/aishell/asr3/local/test.sh index ccc0d84de..9d4b84291 100755 --- a/examples/aishell/asr2/local/test.sh +++ b/examples/aishell/asr3/local/test.sh @@ -25,13 +25,13 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; #fi python3 utils/format_rsl.py \ - --origin_ref data/manifest.test-clean.raw \ - --trans_ref data/manifest.test-clean.text + --origin_ref data/manifest.test.raw \ + --trans_ref data/manifest.test.text for type in ctc_greedy_search; do echo "decoding ${type}" - batch_size=16 + batch_size=1 python3 -u ${BIN_DIR}/test.py \ --ngpu ${ngpu} \ --config ${config_path} \ @@ -50,7 +50,7 @@ for type in ctc_greedy_search; do --trans_hyp ${ckpt_prefix}.${type}.rsl.text python3 utils/compute-wer.py --char=1 --v=1 \ - data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error + data/manifest.test.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error echo "decoding ${type} done." 
done diff --git a/examples/aishell/asr2/local/test_wav.sh b/examples/aishell/asr3/local/test_wav.sh similarity index 100% rename from examples/aishell/asr2/local/test_wav.sh rename to examples/aishell/asr3/local/test_wav.sh diff --git a/examples/aishell/asr2/local/train.sh b/examples/aishell/asr3/local/train.sh similarity index 80% rename from examples/aishell/asr2/local/train.sh rename to examples/aishell/asr3/local/train.sh index a8b8ed590..014249c76 100755 --- a/examples/aishell/asr2/local/train.sh +++ b/examples/aishell/asr3/local/train.sh @@ -10,7 +10,8 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 -ips=$3 +resume=$3 +ips=$4 if [ ! $ips ];then ips_config= @@ -21,7 +22,7 @@ fi mkdir -p exp # seed may break model convergence -seed=1998 +seed=1988 if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -35,13 +36,15 @@ python3 -u ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---seed ${seed} +--seed ${seed} \ +--resume ${resume} else -python3 -m paddle.distributed.launch --log_dir=aa --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --log_dir=${ckpt_name} --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---seed ${seed} +--seed ${seed} \ +--resume ${resume} fi if [ ${seed} != 0 ]; then diff --git a/examples/aishell/asr2/path.sh b/examples/aishell/asr3/path.sh similarity index 100% rename from examples/aishell/asr2/path.sh rename to examples/aishell/asr3/path.sh diff --git a/examples/aishell/asr2/run.sh b/examples/aishell/asr3/run.sh similarity index 80% rename from examples/aishell/asr2/run.sh rename to examples/aishell/asr3/run.sh index 07b2c4049..d3b721858 100755 --- a/examples/aishell/asr2/run.sh +++ b/examples/aishell/asr3/run.sh @@ -4,13 +4,14 @@ set -e . ./path.sh || exit 1; . ./cmd.sh || exit 1; -gpus=6 -stage=1 -stop_stage=1 +gpus=0 +stage=0 +stop_stage=0 conf_path=conf/wav2vec2ASR.yaml ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=1 +resume= # xx e.g. 30 . ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -27,17 +28,17 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh best exp/${ckpt}/checkpoints ${avg_num} + avg.sh last exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # greedy search decoder - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then diff --git a/examples/aishell/asr2/utils b/examples/aishell/asr3/utils similarity index 100% rename from examples/aishell/asr2/utils rename to examples/aishell/asr3/utils diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 1144128fb..127015c84 100755 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import time from collections import defaultdict from collections import OrderedDict from contextlib import nullcontext -import re import jsonlines import numpy as np @@ -43,6 +42,16 @@ from paddlespeech.s2t.utils import mp_tools from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.utility import UpdateConfig +import transformers +from hyperpyyaml import load_hyperpyyaml +from paddlespeech.s2t.io.wav2vec2 import dataset +from paddlespeech.s2t.io.wav2vec2 import data_pipeline +from paddlespeech.s2t.io.wav2vec2.dataloader import make_dataloader +from paddlespeech.s2t.io.wav2vec2 import dataio +import paddle +import tqdm +import numpy + logger = Log(__name__).getlog() @@ -50,7 +59,8 @@ class Wav2Vec2ASRTrainer(Trainer): def __init__(self, config, args): super().__init__(config, args) self.avg_train_loss = 0.0 - + self.flag = False + self.use_sb = True def update_average(self, batch_index, loss): """Update running average of the loss. Arguments @@ -63,7 +73,10 @@ class Wav2Vec2ASRTrainer(Trainer): if math.isfinite(loss): self.avg_train_loss -= self.avg_train_loss / (batch_index + 1) self.avg_train_loss += loss / (batch_index + 1) - + else: + self.flag = True + # exit() + logger.info('loss:{} in Nan or inf, error'.format(loss)) def before_train(self): from_scratch = self.resume_or_scratch() if from_scratch: @@ -72,7 +85,6 @@ class Wav2Vec2ASRTrainer(Trainer): else: # resume: train next_epoch and next_iteration self.epoch += 1 - self.iteration += 1 logger.info( f"Resume train: epoch {self.epoch }, step {self.iteration}!") @@ -83,10 +95,30 @@ class Wav2Vec2ASRTrainer(Trainer): start = time.time() # forward - utt, wav, wavs_lens, target, target_lens = batch - wavs_lens_rate = wavs_lens / wav.shape[1] + ## sb data pipeline + if self.use_sb: + wav, wavs_lens_rate = batch['sig'] + target, target_lens_rate = batch['tokens'] + target_lens = (target_lens_rate * + target.shape[1]).round().astype(paddle.int64) + else: + utt, wav, wavs_lens, target, target_lens = batch + wavs_lens_rate = wavs_lens / wav.shape[1] + wav = wav[:, :, 0] + + + # 加载输入和gt + # self.model.eval() ## 用来测试,设置为eval + # import numpy as np + # wav = paddle.to_tensor(np.load('/home/zhangtianhao/workspace/PaddleSpeech/examples/aishell/asr2/duiqi/inputs.npz.npy')) + # wavs_lens_rate = paddle.to_tensor(np.load('/home/zhangtianhao/workspace/PaddleSpeech/examples/aishell/asr2/duiqi/inputs_length.npz.npy')) + # target = paddle.to_tensor(np.load('/home/zhangtianhao/workspace/PaddleSpeech/examples/aishell/asr2/duiqi/tokens.npy')) + # target_lens = paddle.to_tensor(np.load('/home/zhangtianhao/workspace/PaddleSpeech/examples/aishell/asr2/duiqi/tokens_length.npz.npy')) + # print(wav, wavs_lens_rate) + # exit() + # target_lens_rate = target_lens / target.shape[1] - wav = wav[:, :, 0] + if hasattr(train_conf, 'audio_augment'): wav = self.speech_augmentation(wav, wavs_lens_rate) @@ -110,6 +142,8 @@ class Wav2Vec2ASRTrainer(Trainer): context = nullcontext with context(): loss.backward() + # print(loss) + layer_tools.print_grads(self.model, print_func=None) # optimizer step old @@ -121,10 +155,15 @@ class Wav2Vec2ASRTrainer(Trainer): self.wav2vec2_optimizer.clear_grad() if self.config.model_scheduler != 'newbobscheduler': self.model_lr_scheduler.step() - self.model_lr_scheduler.clear_grad() + if 
self.config.wav2vec2_scheduler != 'newbobscheduler': if not train_conf.freeze_wav2vec2: self.wav2vec2_lr_scheduler.step() self.iteration += 1 + # import numpy as np + # xx = self.model.ctc.ctc_lo.weight + # np.save('/home/zhangtianhao/workspace/PaddleSpeech/examples/aishell/asr2/duiqi/paddle_data', xx.cpu().numpy()) + # print(xx) + # exit() losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad} iteration_time = time.time() - start for k, v in losses_np.items(): @@ -154,16 +193,27 @@ class Wav2Vec2ASRTrainer(Trainer): num_seen_utts = 1 total_loss = 0.0 for i, batch in enumerate(self.valid_loader): - utt, wav, wavs_lens, target, target_lens = batch - wavs_lens_rate = wavs_lens / wav.shape[1] - wav = wav[:, :, 0] + if self.use_sb: + wav, wavs_lens_rate = batch['sig'] + target, target_lens_rate = batch['tokens'] + target_lens = (target_lens_rate * + target.shape[1]).round().astype(paddle.int64) + else: + utt, wav, wavs_lens, target, target_lens = batch + wavs_lens_rate = wavs_lens / wav.shape[1] + # target_lens_rate = target_lens / target.shape[1] + wav = wav[:, :, 0] + loss = self.model(wav, wavs_lens_rate, target, target_lens) if math.isfinite(float(loss)): - num_utts = batch[1].shape[0] + # num_utts = batch[1].shape[0] + num_utts = wav.shape[0] num_seen_utts += num_utts total_loss += float(loss) * num_utts valid_losses['val_loss'].append(float(loss)) + else: + logger.info('loss:{} in Nan or inf, error'.format(float(loss))) if (i + 1) % self.config.log_interval == 0: valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} @@ -183,105 +233,6 @@ class Wav2Vec2ASRTrainer(Trainer): logger.info('Rank {} Val info val_loss {}'.format( dist.get_rank(), total_loss / num_seen_utts)) return total_loss, num_seen_utts - - - @mp_tools.rank_zero_only - def save(self, tag=None, infos: dict=None): - """Save checkpoint (model parameters and optimizer states). - - Args: - tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None. - infos (dict, optional): meta data to save. Defaults to None. 
- """ - - infos = infos if infos else dict() - infos.update({ - "step": self.iteration, - "epoch": self.epoch, - "model_lr": self.model_optimizer.get_lr(), - "wav2vec2_lr": self.wav2vec2_optimizer.get_lr() - }) - - checkpoint_path = os.path.join(self.checkpoint_dir, - "{}".format(self.iteration - if tag is None else tag)) - - model_dict = self.model.state_dict() - params_path = checkpoint_path + ".pdparams" - paddle.save(model_dict, params_path) - logger.info("Saved model to {}".format(params_path)) - - model_opt_dict = self.model_optimizer.state_dict() - wav2vec2_opt_dict = self.wav2vec2_optimizer.state_dict() - - opt_dict = { - 'model': model_opt_dict, - 'wav2vec2': wav2vec2_opt_dict} - - optimizer_path = checkpoint_path + ".pdopt" - paddle.save(opt_dict, optimizer_path) - logger.info("Saved optimzier state to {}".format(optimizer_path)) - - scheduler_dict = {} - - if self.config.model_scheduler == 'newbobscheduler': - scheduler_dict['model'] = self.model_lr_scheduler.save() - if self.config.wav2vec2_scheduler =='newbobscheduler': - scheduler_dict['wav2vec2'] = self.wav2vec2_lr_scheduler.save() - if scheduler_dict: - scheduler_path = checkpoint_path + ".pdlrs" - paddle.save(scheduler_dict, scheduler_path) - logger.info("Saved scheduler state to {}".format(scheduler_path)) - info_path = re.sub('.pdparams$', '.json', params_path) - infos = {} if infos is None else infos - with open(info_path, 'w') as fout: - data = json.dumps(infos) - fout.write(data) - - def resume_or_scratch(self): - """Resume from latest checkpoint at checkpoints in the output - directory or load a specified checkpoint. - - If ``args.checkpoint_path`` is not None, load the checkpoint, else - resume training. - """ - scratch = None - infos = self.checkpoint.load_latest_parameters( - self.model, - checkpoint_dir=self.checkpoint_dir, - checkpoint_path=self.args.checkpoint_path) - if infos: - # just restore ckpt - # lr will resotre from optimizer ckpt - self.iteration = infos["step"] - self.epoch = infos["epoch"] - - # resotre optimizer from *.pdopt - optimizer_path = os.path.join(self.checkpoint_dir, - "{}".format(epoch)) + '.pdopt' - optimizer_dict = paddle.load(optimizer_path) - optimizer.set_state_dict(optimizer_dict) - self.model_optimizer.set_state_dict(optimizer_dict['model']) - self.wav2vec2_optimizer.set_state_dict(optimizer_dict['wav2vec2']) - - # resotre lr_scheduler from *.pdlrs - scheduler_path = os.path.join(self.checkpoint_dir, - "{}".format(epoch)) + '.pdlrs' - if os.path.isfile(os.path.join(scheduler_path)): - scheduler_dict = paddle.load(scheduler_path) - if self.config.model_scheduler is 'newbobscheduler': - self.model_lr_scheduler.load(scheduler_dict['model']) - if self.config.wav2vec2_scheduler is 'newbobscheduler': - self.wav2vec2_lr_scheduler.load(scheduler_dict['wav2vec2']) - scratch = False - logger.info( - f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") - else: - self.iteration = 0 - self.epoch = 0 - scratch = True - logger.info("Init from scratch!") - return scratch @mp_tools.rank_zero_only def save(self, tag=None, infos: dict=None): @@ -462,43 +413,209 @@ class Wav2Vec2ASRTrainer(Trainer): tag='eval/wav2vec2_lr', value=self.wav2vec2_lr_scheduler(), step=self.epoch) + if self.config.model_scheduler == 'newbobscheduler': self.model_lr_scheduler.step(cv_loss) if self.config.wav2vec2_scheduler == 'newbobscheduler': if not self.config.freeze_wav2vec2: self.wav2vec2_lr_scheduler.step(cv_loss) self.save(tag=self.epoch, infos={'val_loss': cv_loss}) + with 
open(os.path.join(self.checkpoint_dir, 'log'), 'a') as f: + f.write( + 'epoch: {}, lr_model: {}, lr_wav2vec: {} - train loss: {} - valid loss: {}\n'. + format(self.epoch, + self.model_lr_scheduler(), + self.wav2vec2_lr_scheduler(), self.avg_train_loss, + cv_loss)) + self.avg_train_loss = 0.0 + self.step = 0 self.new_epoch() + def dataio_prepare(self, hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_data"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_data"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_data"], replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # Defining tokenizer and loading it + tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-chinese') + self.tokenizer = tokenizer + # 2. Define audio pipeline: + @data_pipeline.takes("wav") + @data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = dataio.read_audio(wav) + return sig + + dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @data_pipeline.takes("transcript") + @data_pipeline.provides("wrd", "tokens_list", "tokens") + def text_pipeline(wrd): + wrd = "".join(wrd.split(" ")) + yield wrd + tokens_list = tokenizer(wrd)["input_ids"] + yield tokens_list + tokens = numpy.array(tokens_list, dtype="int64") + # tokens = paddle.to_tensor(tokens_list, dtype="int64") + yield tokens + + dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + dataset.set_output_keys( + datasets, ["id", "sig", "wrd", "tokens"], + ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. 
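+        #    DynamicBatchSampler groups utterances into duration buckets and packs each
+        #    batch up to max_batch_len, so the total audio per batch stays bounded while
+        #    the number of utterances per batch varies.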
+ train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from sampler import DynamicBatchSampler # noqa + + dynamic_hparams = hparams["dynamic_batch_sampler"] + num_buckets = dynamic_hparams["num_buckets"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + valid_batch_sampler = DynamicBatchSampler( + valid_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + return ( + train_data, + valid_data, + test_data, + tokenizer, + train_batch_sampler, + valid_batch_sampler, + ) + def setup_dataloader(self): config = self.config.clone() self.use_streamdata = config.get("use_stream_data", False) - if self.train: - self.train_loader = DataLoaderFactory.get_dataloader( - 'train', config, self.args) - self.valid_loader = DataLoaderFactory.get_dataloader( - 'valid', config, self.args) - logger.info("Setup train/valid Dataloader!") + if self.use_sb: + hparams_file = '/home/zhangtianhao/workspace/PaddleSpeech/paddlespeech/s2t/io/wav2vec2/train_with_wav2vec.yaml' + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, None) + + ( + train_data, + valid_data, + test_data, + tokenizer, + train_bsampler, + valid_bsampler, + ) = self.dataio_prepare(hparams) + + train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + + if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + + if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + if self.train: + self.train_loader = make_dataloader( + train_data, stage='train', **train_dataloader_opts + ) + self.valid_loader = make_dataloader( + valid_data, + stage='val', + **valid_dataloader_opts, + ) + logger.info("Setup train/valid Dataloader!") + else: + self.test_loader = make_dataloader( + test_data, stage='test', **hparams["test_dataloader_opts"] + ) else: - decode_batch_size = config.get('decode', dict()).get( - 'decode_batch_size', 1) - self.test_loader = DataLoaderFactory.get_dataloader('test', config, - self.args) - self.align_loader = DataLoaderFactory.get_dataloader( - 'align', config, self.args) - logger.info("Setup test/align Dataloader!") + if self.train: + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) + logger.info("Setup train/valid Dataloader!") + else: + decode_batch_size = config.get('decode', dict()).get( + 'decode_batch_size', 1) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, + self.args) + self.align_loader = DataLoaderFactory.get_dataloader( + 'align', config, self.args) + logger.info("Setup test/align Dataloader!") + + def setup_model(self): config = self.config model_conf = config with UpdateConfig(model_conf): - if self.train: - model_conf.input_dim = self.train_loader.feat_dim - model_conf.output_dim = self.train_loader.vocab_size + if self.use_sb: + model_conf.output_dim = self.tokenizer.vocab_size else: - model_conf.input_dim = self.test_loader.feat_dim - model_conf.output_dim = self.test_loader.vocab_size + if 
self.train: + model_conf.input_dim = self.train_loader.feat_dim + model_conf.output_dim = self.train_loader.vocab_size + else: + model_conf.input_dim = self.test_loader.feat_dim + model_conf.output_dim = self.test_loader.vocab_size + model = Wav2vec2ASR.from_config(model_conf) model_dict = paddle.load(config.wav2vec2_params_path) @@ -575,6 +692,8 @@ class Wav2Vec2ASRTrainer(Trainer): 'params': model.ctc.parameters() }], model_lr_scheduler) + # model_optimizer_args = optimizer_args(config, model_optim_type, model_optim_conf, + # [*model._layers.enc.parameters(), *model._layers.ctc.parameters()] if self.parallel else [*model.enc.parameters(), *model.ctc.parameters()], model_lr_scheduler) wav2vec2_optimizer_args = optimizer_args( config, wav2vec2_optim_type, wav2vec2_optim_conf, model._layers.wav2vec2.parameters() if self.parallel else @@ -656,6 +775,55 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer): num_frames=audio_len.sum().numpy().item(), decode_time=decode_time) + + def sb_compute_metrics(self, + id, + sig, + wrd, + tokens, + fout=None): + decode_cfg = self.config.decode + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer + start_time = time.time() + target_transcripts = wrd + result_transcripts, result_tokenids = self.model.decode( + sig[0], + text_feature=self.tokenizer, + decoding_method=decode_cfg.decoding_method, + beam_size=decode_cfg.beam_size, + sb_pipeline=True) + decode_time = time.time() - start_time + + for utt, target, result, rec_tids in zip( + utts, target_transcripts, result_transcripts, result_tokenids): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write({ + "utt": utt, + "refs": [target], + "hyps": [result], + "hyps_tokenid": [rec_tids], + }) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") + logger.info("One example error rate [%s] = %f" % ( + decode_cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, # num examples + error_rate=errors_sum / len_refs, + error_rate_type=decode_cfg.error_rate_type, + num_frames=audio_len.sum().numpy().item(), + decode_time=decode_time) + @mp_tools.rank_zero_only @paddle.no_grad() def test(self): @@ -673,7 +841,10 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer): with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): - metrics = self.compute_metrics(*batch, fout=fout) + if self.use_sb: + metrics = self.sb_compute_metrics(**batch, fout=fout) + else: + metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] num_time += metrics["decode_time"] errors_sum += metrics['errors_sum'] diff --git a/paddlespeech/s2t/io/wav2vec2/batch.py b/paddlespeech/s2t/io/wav2vec2/batch.py new file mode 100755 index 000000000..b69e55718 --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/batch.py @@ -0,0 +1,184 @@ +"""Batch collation + +Authors + * Aku Rouhe 2020 +""" +import collections +import torch +from paddlespeech.s2t.io.wav2vec2.data_utils import mod_default_collate +# from speechbrain.utils.data_utils import recursive_to +from paddlespeech.s2t.io.wav2vec2.data_utils import batch_pad_right +from torch.utils.data._utils.collate import default_convert +# from torch.utils.data._utils.pin_memory 
import ( +# pin_memory as recursive_pin_memory, +# ) +import paddle + +PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"]) + + +class PaddedBatch: + """Collate_fn when examples are dicts and have variable-length sequences. + + Different elements in the examples get matched by key. + All numpy tensors get converted to Torch (PyTorch default_convert) + Then, by default, all torch.Tensor valued elements get padded and support + collective pin_memory() and to() calls. + Regular Python data types are just collected in a list. + + Arguments + --------- + examples : list + List of example dicts, as produced by Dataloader. + padded_keys : list, None + (Optional) List of keys to pad on. If None, pad all torch.Tensors + device_prep_keys : list, None + (Optional) Only these keys participate in collective memory pinning and moving with + to(). + If None, defaults to all items with torch.Tensor values. + padding_func : callable, optional + Called with a list of tensors to be padded together. Needs to return + two tensors: the padded data, and another tensor for the data lengths. + padding_kwargs : dict + (Optional) Extra kwargs to pass to padding_func. E.G. mode, value + apply_default_convert : bool + Whether to apply PyTorch default_convert (numpy to torch recursively, + etc.) on all data. Default:True, usually does the right thing. + nonpadded_stack : bool + Whether to apply PyTorch-default_collate-like stacking on values that + didn't get padded. This stacks if it can, but doesn't error out if it + cannot. Default:True, usually does the right thing. + + Example + ------- + >>> batch = PaddedBatch([ + ... {"id": "ex1", "foo": torch.Tensor([1.])}, + ... {"id": "ex2", "foo": torch.Tensor([2., 1.])}]) + >>> # Attribute or key-based access: + >>> batch.id + ['ex1', 'ex2'] + >>> batch["id"] + ['ex1', 'ex2'] + >>> # torch.Tensors get padded + >>> type(batch.foo) + + >>> batch.foo.data + tensor([[1., 0.], + [2., 1.]]) + >>> batch.foo.lengths + tensor([0.5000, 1.0000]) + >>> # Batch supports collective operations: + >>> _ = batch.to(dtype=torch.half) + >>> batch.foo.data + tensor([[1., 0.], + [2., 1.]], dtype=torch.float16) + >>> batch.foo.lengths + tensor([0.5000, 1.0000], dtype=torch.float16) + >>> # Numpy tensors get converted to torch and padded as well: + >>> import numpy as np + >>> batch = PaddedBatch([ + ... {"wav": np.asarray([1,2,3,4])}, + ... {"wav": np.asarray([1,2,3])}]) + >>> batch.wav # +ELLIPSIS + PaddedData(data=tensor([[1, 2,... + >>> # Basic stacking collation deals with non padded data: + >>> batch = PaddedBatch([ + ... {"spk_id": torch.tensor([1]), "wav": torch.tensor([.1,.0,.3])}, + ... {"spk_id": torch.tensor([2]), "wav": torch.tensor([.2,.3,-.1])}], + ... padded_keys=["wav"]) + >>> batch.spk_id + tensor([[1], + [2]]) + >>> # And some data is left alone: + >>> batch = PaddedBatch([ + ... {"text": ["Hello"]}, + ... {"text": ["How", "are", "you?"]}]) + >>> batch.text + [['Hello'], ['How', 'are', 'you?']] + + """ + + def __init__( + self, + examples, + padded_keys=None, + device_prep_keys=None, + padding_func=batch_pad_right, + padding_kwargs={}, + nonpadded_stack=True, + ): + self.__length = len(examples) + self.__keys = list(examples[0].keys()) + self.__padded_keys = [] + self.__device_prep_keys = [] + for key in self.__keys: + values = [example[key] for example in examples] + # Default convert usually does the right thing (numpy2torch etc.) 
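+            # default_convert (imported from torch above) recursively turns numpy
+            # arrays in `values` into tensors before the padding / collation below.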
+ values = default_convert(values) + + if (padded_keys is not None and key in padded_keys) or ( + padded_keys is None and isinstance(values[0], paddle.Tensor) + ): + # Padding and PaddedData + self.__padded_keys.append(key) + padded = PaddedData(*padding_func(values, **padding_kwargs)) + setattr(self, key, padded) + else: + # Default PyTorch collate usually does the right thing + # (convert lists of equal sized tensors to batch tensors, etc.) + if nonpadded_stack: + values = mod_default_collate(values) + setattr(self, key, values) + if (device_prep_keys is not None and key in device_prep_keys) or ( + device_prep_keys is None and isinstance(values[0], paddle.Tensor) + ): + self.__device_prep_keys.append(key) + + def __len__(self): + return self.__length + + def __getitem__(self, key): + if key in self.__keys: + return getattr(self, key) + else: + raise KeyError(f"Batch doesn't have key: {key}") + + def __iter__(self): + """Iterates over the different elements of the batch. + + Example + ------- + >>> batch = PaddedBatch([ + ... {"id": "ex1", "val": torch.Tensor([1.])}, + ... {"id": "ex2", "val": torch.Tensor([2., 1.])}]) + >>> ids, vals = batch + >>> ids + ['ex1', 'ex2'] + """ + return iter((getattr(self, key) for key in self.__keys)) + + # def pin_memory(self): + # """In-place, moves relevant elements to pinned memory.""" + # for key in self.__device_prep_keys: + # value = getattr(self, key) + # pinned = value + # setattr(self, key, pinned) + # return self + + # def to(self, *args, **kwargs): + # """In-place move/cast relevant elements. + + # Passes all arguments to torch.Tensor.to, see its documentation. + # """ + # for key in self.__device_prep_keys: + # value = getattr(self, key) + # moved = recursive_to(value, *args, **kwargs) + # setattr(self, key, moved) + # return self + + # def at_position(self, pos): + # """Gets the position.""" + # key = self.__keys[pos] + # return getattr(self, key) + + diff --git a/paddlespeech/s2t/io/wav2vec2/data_pipeline.py b/paddlespeech/s2t/io/wav2vec2/data_pipeline.py new file mode 100755 index 000000000..16fcf39ea --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/data_pipeline.py @@ -0,0 +1,518 @@ +"""A pipeline for data transformations. + +Example +------- +>>> from hyperpyyaml import load_hyperpyyaml +>>> yamlstring = ''' +... pipeline: !new:speechbrain.utils.data_pipeline.DataPipeline +... static_data_keys: [a, b] +... dynamic_items: +... - func: !name:operator.add +... takes: ["a", "b"] +... provides: foo +... - func: !name:operator.sub +... takes: ["foo", "b"] +... provides: bar +... output_keys: ["foo", "bar"] +... ''' +>>> hparams = load_hyperpyyaml(yamlstring) +>>> hparams["pipeline"]({"a":1, "b":2}) +{'foo': 3, 'bar': 1} + +Author: + * Aku Rouhe +""" + +import inspect +from dataclasses import dataclass +from paddlespeech.s2t.io.wav2vec2.depgraph import DependencyGraph + +@dataclass +class StaticItem: + """Data class that represents a static item. + + Static items are in-memory items so they don't need to be computed + dynamically. + """ + + key: str + + +class DynamicItem: + """Essentially represents a data transformation function. + + A DynamicItem takes some arguments and computes its value dynamically when + called. A straight-forward use-case is to load something from disk + dynamically; take the path and provide the loaded data. + + Instances of this class are often created implicitly via the + @takes and @provides decorators or otherwise from specifying the taken and + provided arguments and the function. 
+ + A counterpart is the GeneratorDynamicItem, which should be used for + generator functions. + + Arguments + --------- + takes : list + The keys of the items that this needs to compute its output. + func : callable + The function that is used to compute the output. + provides : list + The keys that this provides. + """ + + def __init__(self, takes=[], func=None, provides=[]): + self.takes = takes + self.func = func + self.provides = provides + + def __call__(self, *args): + return self.func(*args) + + # The next methods are more about supporting GeneratorDynamicItems + def next_takes(self): + """The next argkeys to provide to this, when called.""" + # Regular function DynamicItems always just need the same set of args + return self.takes + + def next_provides(self): + """The next keys that this provides, when called.""" + # Regular function DynamicItems always just provide the same set of keys + return self.provides + + def provided_in_order(self): + """Assuming that this may need to be called multiple times; which keys + does it provide at that call. Returns a list, with len equal to the + number of times that this may be called.""" + # Regular function DynamicItems are only called once: + return [self.provides] + + def reset(self): + """Signals that this will not be called any more times on this pipeline + call.""" + # Regular function DynamicItems don't need special resets. + pass + + +class GeneratorDynamicItem(DynamicItem): + """Essentially represents a multi-step data transformation. + + This is the generator function counterpart for DynamicItem (which should be + used for regular functions). + + A GeneratorDynamicItem first takes some arguments and then uses those in + multiple steps to incrementally compute some values when called. + + A typical use-case is a pipeline of transformations on data: e.g. taking in + text as a string, and first a tokenized version, and then on the second + call providing an integer-encoded version. This can be used even though the + integer-encoder needs to be trained on the first outputs. + + The main benefit is to be able to define the pipeline in a clear function, + even if parts of the pipeline depend on others for their initialization. + + Example + ------- + >>> lab2ind = {} + >>> def text_pipeline(text): + ... text = text.lower().strip() + ... text = "".join(c for c in text if c.isalpha() or c == " ") + ... words = text.split() + ... yield words + ... encoded = [lab2ind[word] for word in words] + ... yield encoded + >>> item = GeneratorDynamicItem( + ... func=text_pipeline, + ... takes=["text"], + ... provides=["words", "words_encoded"]) + >>> # First create the integer-encoding: + >>> ind = 1 + >>> for token in item("Is this it? - This is it."): + ... if token not in lab2ind: + ... lab2ind[token] = ind + ... ind += 1 + >>> # Now the integers can be encoded! + >>> item() + [1, 2, 3, 2, 1, 3] + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Doesn't generate electricity, only stores the currently active + # generator: + self.current_generator = None + self.num_provided_items = 0 + + def __call__(self, *args): + if self.num_provided_items == len(self.provides): + raise RuntimeError("DynamicItemPipeline called too many times!") + if not self.current_generator: + self.current_generator = self.func(*args) + # NOTE: Not supporting sending new values to the pipeline. 
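+        # Each call advances the underlying generator exactly one step, so successive
+        # calls produce the keys listed in self.provides, in order.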
+ out = next(self.current_generator) + self.num_provided_items += 1 + return out + + def next_takes(self): + """The next argkeys to provide to this, when called.""" + if not self.current_generator: + return self.takes + else: + return [] + + def next_provides(self): + """The next keys that this provides, when called.""" + keys = self.provides[self.num_provided_items] + # Support multiple yielded values like: + # @yields("wav_read", ["left_ch", "right_ch"]) + if isinstance(keys, str): + return [keys] + else: + return keys + + def provided_in_order(self): + """Assuming that this may need to be called multiple times; which keys + does it provide at that call. Returns a list, with len equal to the + number of times that this may be called.""" + in_order = [] + for keys in self.provides: + # Support multiple yielded values like: + # @provides("wav_read", ["left_ch", "right_ch"]) + if isinstance(keys, str): + in_order.append([keys]) + else: + in_order.append(keys) + return in_order + + def reset(self): + """Signals that this will not be called any more times on this pipeline + call.""" + if self.current_generator is not None: + self.current_generator.close() + self.current_generator = None + self.num_provided_items = 0 + + +def takes(*argkeys): + """Decorator which makes a DynamicItem and specifies its argkeys. + + If the wrapped object is a generator function (has a yield statement), + Creates a GeneratorDynamicItem. If the object is already a DynamicItem, + just specifies the argkeys for that. Otherwise creates a new regular + DynamicItem, with argkeys specified. + + The args are always passed to the function at the start. Generators could + support sending new arguments, but for such use cases, simply create a new + dynamic item. The GeneratorDynamicItem class is meant for pipelines which + take in an input and transform it in multiple ways, where the intermediate + representations may be needed for e.g. fitting a BPE segmenter. + + Example + ------- + >>> @takes("text") + ... def tokenize(text): + ... return text.strip().lower().split() + >>> tokenize.provides = ["tokenized"] + >>> tokenize('\tThis Example gets tokenized') + ['this', 'example', 'gets', 'tokenized'] + """ + + def decorator(obj): + """Decorator definition.""" + if isinstance(obj, DynamicItem): + if obj.takes: + raise ValueError("Can't overwrite DynamicItem.takes") + obj.takes = argkeys + return obj + elif inspect.isgeneratorfunction(obj): + return GeneratorDynamicItem(takes=argkeys, func=obj) + else: + return DynamicItem(takes=argkeys, func=obj) + + return decorator + +takes_decorator = takes # Just for DataPipeline.add_dynamic_item + +def provides(*output_keys): + """Decorator which makes a DynamicItem and specifies what keys it provides. + + If the wrapped object is a generator function (has a yield statement), + Creates a GeneratorDynamicItem. If the object is already a DynamicItem, + just specifies the provided keys for that. Otherwise creates a new regular + DynamicItem, with provided keys specified. + + NOTE + ---- + The behavior is slightly different for generators and regular functions, if + many output keys are specified, e.g. @provides("signal", "mfcc"). Regular + functions should return a tuple with len equal to len(output_keys), while + generators should yield the items one by one. + + >>> @provides("signal", "feat") + ... def read_feat(): + ... wav = [.1,.2,-.1] + ... feat = [s**2 for s in wav] + ... return wav, feat + >>> @provides("signal", "feat") + ... def read_feat(): + ... wav = [.1,.2,-.1] + ... 
yield wav + ... feat = [s**2 for s in wav] + ... yield feat + + If multiple keys are yielded at once, write e.g., + + >>> @provides("wav_read", ["left_channel", "right_channel"]) + ... def read_multi_channel(): + ... wav = [[.1,.2,-.1],[.2,.1,-.1]] + ... yield wav + ... yield wav[0], wav[1] + + """ + + def decorator(obj): + """Decorator definition.""" + if isinstance(obj, DynamicItem): + if obj.provides: + raise ValueError("Can't overwrite DynamicItem provides-list.") + obj.provides = output_keys + return obj + elif inspect.isgeneratorfunction(obj): + return GeneratorDynamicItem(func=obj, provides=output_keys) + else: + return DynamicItem(func=obj, provides=output_keys) + + return decorator + + +provides_decorator = provides # Just for DataPipeline.add_dynamic_item + + +class DataPipeline: + """Organises data transformations into a pipeline. + + Example + ------- + >>> pipeline = DataPipeline( + ... static_data_keys=["text"], + ... dynamic_items=[ + ... {"func": lambda x: x.lower(), "takes": "text", "provides": "foo"}, + ... {"func": lambda x: x[::-1], "takes": "foo", "provides": "bar"}, + ... ], + ... output_keys=["bar"], + ... ) + >>> pipeline({"text": "Test"}) + {'bar': 'tset'} + """ + + def __init__(self, static_data_keys, dynamic_items=[], output_keys=[]): + self.dg = DependencyGraph() + self._exec_order = None + self.key_to_node = {} + self.unaccounted_keys = {} + self.dynamic_items = [] + self.output_mapping = {} + self.add_static_keys(static_data_keys) + self.add_dynamic_items(dynamic_items) + self.set_output_keys(output_keys) + + def add_static_keys(self, static_keys): + """Informs the pipeline about static items. + + Static items are the ones provided to __call__ as data. + """ + for key in static_keys: + node_id = self.dg.add_node(data=StaticItem(key=key)) + self.key_to_node[key] = node_id + + def add_dynamic_items(self, dynamic_items): + """Add multiple dynamic items at once.""" + for item in dynamic_items: + try: + self.add_dynamic_item(**item) + except TypeError: + self.add_dynamic_item(item) + + def add_dynamic_item(self, func, takes=None, provides=None): + """Adds a dynamic item to the Pipeline. + + Two calling conventions. For DynamicItem objects, just use: + add_dynamic_item(dynamic_item) + But otherwise, should use: + add_dynamic_item(func, takes, provides) + + Arguments + --------- + func : callable, DynamicItem + If a DynamicItem is given, adds that directly. Otherwise a + DynamicItem is created, and this specifies the callable to use. If + a generator function is given, then create a GeneratorDynamicItem. + Otherwise creates a normal DynamicItem. + takes : list, str + List of keys. When func is called, each key is resolved to + either an entry in the data or the output of another dynamic_item. + The func is then called with these as positional arguments, + in the same order as specified here. + A single key can be given as a bare string. + provides : str, list + For regular functions, the key or list of keys that it provides. + If you give a generator function, key or list of keys that it + yields, in order. Also see the provides decorator. + A single key can be given as a bare string. 
+ """ + if isinstance(func, DynamicItem): + if takes is not None or provides is not None: + raise ValueError( + "If providing a DynamicItem directly, don't " + "specify takes or provides" + ) + else: + self._add_dynamic_item_object(func) + return + if isinstance(takes, str): + takes = [takes] + if isinstance(provides, str): + provides = [provides] + di = takes_decorator(*takes)(provides_decorator(*provides)(func)) + self._add_dynamic_item_object(di) + + def _add_dynamic_item_object(self, obj): + """Internally adds the object. + + There is a node in the dependency graph for each call of the + DynamicItem. Each call may return multiple keys and depend on multiple + keys. An internal dict maps key to the id of the node that produces it. + """ + if not obj.provides: + raise ValueError( + "Won't add redundant dynamic item which doesn't " + "provide anything." + ) + depended = [] + for key in obj.takes: + # Might not be accounted for, yet: + if key not in self.key_to_node: + dependee_keys = self.unaccounted_keys.setdefault(key, []) + dependee_keys.extend(obj.next_provides()) + else: + depended.append(self.key_to_node[key]) + for provided in obj.provided_in_order(): + node_id = self.dg.add_node(data=obj) + for key in provided: + self.key_to_node[key] = node_id + # This key may also be unaccounted for, so account for it now: + if key in self.unaccounted_keys: + for dependee_key in self.unaccounted_keys[key]: + dependee_node = self.key_to_node[dependee_key] + self.dg.add_edge(dependee_node, node_id) + del self.unaccounted_keys[key] # Now accounted for! + for dep_id in depended: + self.dg.add_edge(node_id, dep_id) + # Next call will depend on this call: + depended = [node_id] + # Keep a reference to the item in this object, as well: + self.dynamic_items.append(obj) + + def set_output_keys(self, keys): + """Use this to change the output keys. + + Also re-evaluates execution order. + So if you request different outputs, some parts of the + data pipeline may be skipped. + + Arguments + --------- + keys : dict, list, None + List of keys (str) to produce in output. + + If a dict is given; it is used to map internal keys to output keys. + From the output_keys dict key:value pairs the key appears outside, + and value is the internal key. + """ + self.output_mapping = self._output_keys_to_mapping(keys) + self._exec_order = None + + @staticmethod + def _output_keys_to_mapping(keys): + # Ensure a mapping (accept a list for convenience, too) + if keys is None: + output_mapping = {} + elif isinstance(keys, dict): + output_mapping = keys + else: + output_mapping = {key: key for key in keys} + return output_mapping + + def compute_outputs(self, data): + """ + Arguments + --------- + data : dict + Dictionary with data entries by key. + + Returns + ------- + dict + With the keys that were set. 
+ """ + if self._exec_order is None: + self._prepare_run(data) + return self._compute(data, self._exec_order, self.output_mapping) + + def compute_specific(self, keys, data): + """Compute output of specific item, without changing output_keys.""" + output_mapping = self._output_keys_to_mapping(keys) + order = self.dg.get_evaluation_order( + selected_keys=self.get_selected_node_ids(keys) + ) + return self._compute(data, order, output_mapping) + + def _compute(self, data, order, output_mapping): + if self.unaccounted_keys: + MSG = "These keys are still unaccounted for in the data pipeline: " + MSG += ", ".join(self.unaccounted_keys) + raise RuntimeError(MSG) + intermediate = {} + for node_id, edges, item in order: + if isinstance(item, StaticItem): + # Static item in data. + # Just check that key is found. + try: + data[item.key] + continue + except KeyError: + raise KeyError(f"Expected key {item.key} in data!") + # A dynamic item, which we should compute: + args = [ + data[argkey] if argkey in data else intermediate[argkey] + for argkey in item.next_takes() + ] + # This needs to be called BEFORE the dynamic item is called. + provided_keys = item.next_provides() + values = item(*args) # Call the DynamicItem to produce output + # If there is just one output value, wrap in a list so that + # it can be zipped as well: + if len(provided_keys) == 1: + values = [values] + intermediate.update(zip(provided_keys, values)) + for dynamic_item in self.dynamic_items: + dynamic_item.reset() + return { + outkey: data[inkey] if inkey in data else intermediate[inkey] + for outkey, inkey in output_mapping.items() + } + + def get_selected_node_ids(self, selected_keys): + """Translates selected keys to dependency graph keys.""" + return [self.key_to_node[key] for key in selected_keys] + + def __call__(self, data): + return self.compute_outputs(data) + + def _prepare_run(self, data): + self._exec_order = list( + self.dg.get_evaluation_order( + self.get_selected_node_ids(self.output_mapping.values()) + ) + ) diff --git a/paddlespeech/s2t/io/wav2vec2/data_utils.py b/paddlespeech/s2t/io/wav2vec2/data_utils.py new file mode 100755 index 000000000..f782d207d --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/data_utils.py @@ -0,0 +1,167 @@ +import os +import re +import csv +import shutil +import urllib.request +import collections.abc +import torch +import tqdm +import pathlib +import paddle +import numpy as np +def batch_pad_right(array: list, mode="constant", value=0): + """Given a list of torch tensors it batches them together by padding to the right + on each dimension in order to get same length for all. + + Parameters + ---------- + tensors : list + List of tensor we wish to pad together. + mode : str + Padding mode see torch.nn.functional.pad documentation. + value : float + Padding value see torch.nn.functional.pad documentation. + + Returns + ------- + tensor : torch.Tensor + Padded tensor. + valid_vals : listf + List containing proportion for each dimension of original, non-padded values. + + """ + + if not len(array): + raise IndexError("Tensors list must not be empty") + + if len(array) == 1: + # if there is only one tensor in the batch we simply unsqueeze it. 
+ return np.expand_dims(array[0], 0), np.array([1.0], dtype="float32") + if not ( + any( + [array[i].ndim == array[0].ndim for i in range(1, len(array))] + ) + ): + raise IndexError("All array must have same number of dimensions") + + # FIXME we limit the support here: we allow padding of only the first dimension + # need to remove this when feat extraction is updated to handle multichannel. + max_shape = [] + for dim in range(array[0].ndim): + if dim != 0: + if not all( + [x.shape[dim] == array[0].shape[dim] for x in array[1:]] + ): + raise EnvironmentError( + "Tensors should have same dimensions except for the first one" + ) + max_shape.append(max([x.shape[dim] for x in array])) + + batched = [] + valid = [] + for t in array: + # for each tensor we apply pad_right_to + padded, valid_percent = pad_right_to( + t, max_shape, mode=mode, value=value + ) + batched.append(padded) + valid.append(valid_percent[0]) + + batched = np.stack(batched) + + return batched, np.array(valid, dtype="float32") + +np_str_obj_array_pattern = re.compile(r"[SaUO]") + +def pad_right_to( + array: np.ndarray, target_shape: (list, tuple), mode="constant", value=0, +): + """ + This function takes a torch tensor of arbitrary shape and pads it to target + shape by appending values on the right. + + Parameters + ---------- + tensor : input torch tensor + Input tensor whose dimension we need to pad. + target_shape : (list, tuple) + Target shape we want for the target tensor its len must be equal to tensor.ndim + mode : str + Pad mode, please refer to torch.nn.functional.pad documentation. + value : float + Pad value, please refer to torch.nn.functional.pad documentation. + + Returns + ------- + tensor : torch.Tensor + Padded tensor. + valid_vals : list + List containing proportion for each dimension of original, non-padded values. + """ + assert len(target_shape) == array.ndim + pads = [] # this contains the abs length of the padding for each dimension. + valid_vals = [] # this contains the relative lengths for each dimension. + i = len(target_shape) - 1 # iterating over target_shape ndims + j = 0 + while i >= 0: + assert ( + target_shape[i] >= array.shape[i] + ), "Target shape must be >= original shape for every dim" + pads.extend([0, target_shape[i] - array.shape[i]]) + valid_vals.append(array.shape[j] / target_shape[j]) + i -= 1 + j += 1 + array = np.pad(array, pads, mode, constant_values=(value, value)) + + return array, valid_vals + +def mod_default_collate(batch): + """Makes a tensor from list of batch values. + + Note that this doesn't need to zip(*) values together + as PaddedBatch connects them already (by key). + + Here the idea is not to error out. 
+
+    This is modified from:
+    https://github.com/pytorch/pytorch/blob/c0deb231db76dbea8a9d326401417f7d1ce96ed5/torch/utils/data/_utils/collate.py#L42
+    """
+    elem = batch[0]
+    elem_type = type(elem)
+    if isinstance(elem, paddle.Tensor):
+        try:
+            # Equally sized tensors can be stacked into one batch tensor;
+            # unequal sizes fall through and the raw list is returned.
+            return paddle.stack(batch, 0)
+        except (RuntimeError, ValueError):  # Unequal size:
+            return batch
+    elif (
+        elem_type.__module__ == "numpy"
+        and elem_type.__name__ != "str_"
+        and elem_type.__name__ != "string_"
+    ):
+        try:
+            if (
+                elem_type.__name__ == "ndarray"
+                or elem_type.__name__ == "memmap"
+            ):
+                # array of string classes and object
+                if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                    return batch
+                return mod_default_collate(
+                    [paddle.to_tensor(b, dtype=b.dtype) for b in batch]
+                )
+            elif elem.shape == ():  # scalars
+                return paddle.to_tensor(batch, dtype=elem.dtype)
+        except (RuntimeError, ValueError):  # Unequal size
+            return batch
+    elif isinstance(elem, float):
+        return paddle.to_tensor(batch, dtype=paddle.float64)
+    elif isinstance(elem, int):
+        return paddle.to_tensor(batch, dtype=paddle.int64)
+    else:
+        return batch
\ No newline at end of file
diff --git a/paddlespeech/s2t/io/wav2vec2/dataio.py b/paddlespeech/s2t/io/wav2vec2/dataio.py
new file mode 100755
index 000000000..6dfa81dc4
--- /dev/null
+++ b/paddlespeech/s2t/io/wav2vec2/dataio.py
@@ -0,0 +1,1024 @@
+"""
+Data reading and writing.
+
+Authors
+ * Mirco Ravanelli 2020
+ * Aku Rouhe 2020
+ * Ju-Chieh Chou 2020
+ * Samuele Cornell 2020
+ * Abdel HEBA 2020
+"""
+
+import os
+import logging
+import numpy as np
+import pickle
+import hashlib
+import csv
+import time
+import json
+import re
+import soundfile
+logger = logging.getLogger(__name__)
+import paddle
+
+def load_data_json(json_path, replacements={}):
+    """Loads JSON and recursively formats string values.
+
+    Arguments
+    ----------
+    json_path : str
+        Path to the JSON file.
+    replacements : dict
+        (Optional dict), e.g., {"data_folder": "/home/speechbrain/data"}.
+        This is used to recursively format all string values in the data.
+
+    Returns
+    -------
+    dict
+        JSON data with replacements applied.
+
+    Example
+    -------
+    >>> json_spec = '''{
+    ...   "ex1": {"files": ["{ROOT}/mic1/ex1.wav", "{ROOT}/mic2/ex1.wav"], "id": 1},
+    ...   "ex2": {"files": [{"spk1": "{ROOT}/ex2.wav"}, {"spk2": "{ROOT}/ex2.wav"}], "id": 2}
+    ... }
+    ... '''
+    >>> tmpfile = getfixture('tmpdir') / "test.json"
+    >>> with open(tmpfile, "w") as fo:
+    ...     _ = fo.write(json_spec)
+    >>> data = load_data_json(tmpfile, {"ROOT": "/home"})
+    >>> data["ex1"]["files"][0]
+    '/home/mic1/ex1.wav'
+    >>> data["ex2"]["files"][1]["spk2"]
+    '/home/ex2.wav'
+
+    """
+    with open(json_path, "r") as f:
+        out_json = json.load(f)
+    _recursive_format(out_json, replacements)
+    return out_json
+
+
+def _recursive_format(data, replacements):
+    # Data: dict or list, replacements : dict
+    # Replaces string keys in replacements by their values
+    # at all levels of data (in str values)
+    # Works in-place.
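+    # For example, replacements={"ROOT": "/data"} turns "{ROOT}/wav/a.wav"
+    # into "/data/wav/a.wav" wherever it appears in nested dicts and lists.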
+ if isinstance(data, dict): + for key, item in data.items(): + if isinstance(item, dict) or isinstance(item, list): + _recursive_format(item, replacements) + elif isinstance(item, str): + data[key] = item.format_map(replacements) + # If not dict, list or str, do nothing + if isinstance(data, list): + for i, item in enumerate(data): + if isinstance(item, dict) or isinstance(item, list): + _recursive_format(item, replacements) + elif isinstance(item, str): + data[i] = item.format_map(replacements) + # If not dict, list or str, do nothing + + +def load_data_csv(csv_path, replacements={}): + """Loads CSV and formats string values. + + Uses the SpeechBrain legacy CSV data format, where the CSV must have an + 'ID' field. + If there is a field called duration, it is interpreted as a float. + The rest of the fields are left as they are (legacy _format and _opts fields + are not used to load the data in any special way). + + Bash-like string replacements with $to_replace are supported. + + Arguments + ---------- + csv_path : str + Path to CSV file. + replacements : dict + (Optional dict), e.g., {"data_folder": "/home/speechbrain/data"} + This is used to recursively format all string values in the data. + + Returns + ------- + dict + CSV data with replacements applied. + + Example + ------- + >>> csv_spec = '''ID,duration,wav_path + ... utt1,1.45,$data_folder/utt1.wav + ... utt2,2.0,$data_folder/utt2.wav + ... ''' + >>> tmpfile = getfixture("tmpdir") / "test.csv" + >>> with open(tmpfile, "w") as fo: + ... _ = fo.write(csv_spec) + >>> data = load_data_csv(tmpfile, {"data_folder": "/home"}) + >>> data["utt1"]["wav_path"] + '/home/utt1.wav' + """ + + with open(csv_path, newline="") as csvfile: + result = {} + reader = csv.DictReader(csvfile, skipinitialspace=True) + variable_finder = re.compile(r"\$([\w.]+)") + for row in reader: + # ID: + try: + data_id = row["ID"] + del row["ID"] # This is used as a key in result, instead. + except KeyError: + raise KeyError( + "CSV has to have an 'ID' field, with unique ids" + " for all data points" + ) + if data_id in result: + raise ValueError(f"Duplicate id: {data_id}") + # Replacements: + for key, value in row.items(): + try: + row[key] = variable_finder.sub( + lambda match: str(replacements[match[1]]), value + ) + except KeyError: + raise KeyError( + f"The item {value} requires replacements " + "which were not supplied." + ) + # Duration: + if "duration" in row: + row["duration"] = float(row["duration"]) + result[data_id] = row + return result + + +def read_audio(waveforms_obj): + """General audio loading, based on a custom notation. + + Expected use case is in conjunction with Datasets + specified by JSON. + + The custom notation: + + The annotation can be just a path to a file: + "/path/to/wav1.wav" + + Or can specify more options in a dict: + {"file": "/path/to/wav2.wav", + "start": 8000, + "stop": 16000 + } + + Arguments + ---------- + waveforms_obj : str, dict + Audio reading annotation, see above for format. + + Returns + ------- + torch.Tensor + Audio tensor with shape: (samples, ). 
+
+    Example
+    -------
+    >>> dummywav = torch.rand(16000)
+    >>> import os
+    >>> tmpfile = str(getfixture('tmpdir') / "wave.wav")
+    >>> write_audio(tmpfile, dummywav, 16000)
+    >>> asr_example = { "wav": tmpfile, "spk_id": "foo", "words": "foo bar"}
+    >>> loaded = read_audio(asr_example["wav"])
+    >>> loaded.allclose(dummywav.squeeze(0),atol=1e-4) # replace with eq with sox_io backend
+    True
+    """
+    if isinstance(waveforms_obj, str):
+        audio, _ = soundfile.read(waveforms_obj, dtype="float32")
+        # audio = paddle.to_tensor(audio)
+        return audio
+
+    path = waveforms_obj["file"]
+    start = waveforms_obj.get("start", 0)
+    # Default stop to None -> if not specified, read until the end of the file.
+    stop = waveforms_obj.get("stop", None)
+    audio, fs = soundfile.read(path, start=start, stop=stop, dtype="float32")
+    # audio = paddle.to_tensor(audio)
+    return audio
+
+
+def read_audio_multichannel(waveforms_obj):
+    """General audio loading, based on a custom notation.
+
+    Expected use case is in conjunction with Datasets
+    specified by JSON.
+
+    The custom notation:
+
+    The annotation can be just a path to a file:
+    "/path/to/wav1.wav"
+
+    Multiple (possibly multi-channel) files can be specified, as long as they
+    have the same length:
+    {"files": [
+        "/path/to/wav1.wav",
+        "/path/to/wav2.wav"
+        ]
+    }
+
+    Or you can specify a single file more succinctly:
+    {"files": "/path/to/wav2.wav"}
+
+    Offset number samples and stop number samples also can be specified to read
+    only a segment within the files.
+    {"files": [
+        "/path/to/wav1.wav",
+        "/path/to/wav2.wav"
+        ]
+    "start": 8000
+    "stop": 16000
+    }
+
+    Arguments
+    ----------
+    waveforms_obj : str, dict
+        Audio reading annotation, see above for format.
+
+    Returns
+    -------
+    paddle.Tensor
+        Audio tensor with shape: (samples, ).
+
+    Example
+    -------
+    >>> dummywav = torch.rand(16000, 2)
+    >>> import os
+    >>> tmpfile = str(getfixture('tmpdir') / "wave.wav")
+    >>> write_audio(tmpfile, dummywav, 16000)
+    >>> asr_example = { "wav": tmpfile, "spk_id": "foo", "words": "foo bar"}
+    >>> loaded = read_audio(asr_example["wav"])
+    >>> loaded.allclose(dummywav.squeeze(0),atol=1e-4) # replace with eq with sox_io backend
+    True
+    """
+    if isinstance(waveforms_obj, str):
+        audio, _ = soundfile.read(waveforms_obj, dtype="float32")
+        audio = paddle.to_tensor(audio)
+        return audio
+
+    files = waveforms_obj["files"]
+    if not isinstance(files, list):
+        files = [files]
+
+    waveforms = []
+    start = waveforms_obj.get("start", 0)
+    # Default stop to None -> if not specified, read each file until its end.
+    stop = waveforms_obj.get("stop", None)
+    for f in files:
+        audio, fs = soundfile.read(f, start=start, stop=stop, dtype="float32")
+        audio = paddle.to_tensor(audio)
+        waveforms.append(audio)
+
+    out = paddle.concat(waveforms, 0)
+    return out
+
+
+def write_audio(filepath, audio, samplerate):
+    """Write audio on disk. It is basically a wrapper to support saving
+    audio signals in the speechbrain format (audio, channels).
+
+    Arguments
+    ---------
+    filepath: path
+        Path where to save the audio file.
+    audio : paddle.Tensor
+        Audio file in the expected speechbrain format (signal, channels).
+    samplerate: int
+        Sample rate (e.g., 16000).
+ + + Example + ------- + >>> import os + >>> tmpfile = str(getfixture('tmpdir') / "wave.wav") + >>> dummywav = torch.rand(16000, 2) + >>> write_audio(tmpfile, dummywav, 16000) + >>> loaded = read_audio(tmpfile) + >>> loaded.allclose(dummywav,atol=1e-4) # replace with eq with sox_io backend + True + """ + if len(audio.shape) == 2: + audio = audio.transpose([1, 0]) + elif len(audio.shape) == 1: + audio = audio.unsqueeze(0) + + soundfile.write(filepath, audio, samplerate) + + +def load_pickle(pickle_path): + """Utility function for loading .pkl pickle files. + + Arguments + --------- + pickle_path : str + Path to pickle file. + + Returns + ------- + out : object + Python object loaded from pickle. + """ + with open(pickle_path, "rb") as f: + out = pickle.load(f) + return out + + +def to_floatTensor(x: (list, tuple, np.ndarray)): + """ + Arguments + --------- + x : (list, tuple, np.ndarray) + Input data to be converted to torch float. + + Returns + ------- + tensor : torch.tensor + Data now in torch.tensor float datatype. + """ + return paddle.to_tensor(x, dtype='float32') + + +def to_doubleTensor(x: (list, tuple, np.ndarray)): + """ + Arguments + --------- + x : (list, tuple, np.ndarray) + Input data to be converted to torch double. + + Returns + ------- + tensor : torch.tensor + Data now in torch.tensor double datatype. + """ + return paddle.to_tensor(x, dtype='float64') + + +def to_longTensor(x: (list, tuple, np.ndarray)): + """ + Arguments + --------- + x : (list, tuple, np.ndarray) + Input data to be converted to torch long. + + Returns + ------- + tensor : torch.tensor + Data now in torch.tensor long datatype. + """ + return paddle.to_tensor(x, dtype='int64') + + +def convert_index_to_lab(batch, ind2lab): + """Convert a batch of integer IDs to string labels. + + Arguments + --------- + batch : list + List of lists, a batch of sequences. + ind2lab : dict + Mapping from integer IDs to labels. + + Returns + ------- + list + List of lists, same size as batch, with labels from ind2lab. + + Example + ------- + >>> ind2lab = {1: "h", 2: "e", 3: "l", 4: "o"} + >>> out = convert_index_to_lab([[4,1], [1,2,3,3,4]], ind2lab) + >>> for seq in out: + ... print("".join(seq)) + oh + hello + """ + return [[ind2lab[int(index)] for index in seq] for seq in batch] + + +def relative_time_to_absolute(batch, relative_lens, rate): + """Converts SpeechBrain style relative length to the absolute duration. + + Operates on batch level. + + Arguments + --------- + batch : torch.tensor + Sequences to determine the duration for. + relative_lens : torch.tensor + The relative length of each sequence in batch. The longest sequence in + the batch needs to have relative length 1.0. + rate : float + The rate at which sequence elements occur in real-world time. Sample + rate, if batch is raw wavs (recommended) or 1/frame_shift if batch is + features. This has to have 1/s as the unit. + + Returns + ------: + torch.tensor + Duration of each sequence in seconds. + + Example + ------- + >>> batch = torch.ones(2, 16000) + >>> relative_lens = torch.tensor([3./4., 1.0]) + >>> rate = 16000 + >>> print(relative_time_to_absolute(batch, relative_lens, rate)) + tensor([0.7500, 1.0000]) + """ + max_len = batch.shape[1] + durations = paddle.round(relative_lens * max_len) / rate + return durations + + +class IterativeCSVWriter: + """Write CSV files a line at a time. + + Arguments + --------- + outstream : file-object + A writeable stream + data_fields : list + List of the optional keys to write. 
Each key will be expanded to the + SpeechBrain format, producing three fields: key, key_format, key_opts. + + Example + ------- + >>> import io + >>> f = io.StringIO() + >>> writer = IterativeCSVWriter(f, ["phn"]) + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + >>> writer.write("UTT1",2.5,"sil hh ee ll ll oo sil","string","") + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + UTT1,2.5,sil hh ee ll ll oo sil,string, + >>> writer.write(ID="UTT2",phn="sil ww oo rr ll dd sil",phn_format="string") + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + UTT1,2.5,sil hh ee ll ll oo sil,string, + UTT2,,sil ww oo rr ll dd sil,string, + >>> writer.set_default('phn_format', 'string') + >>> writer.write_batch(ID=["UTT3","UTT4"],phn=["ff oo oo", "bb aa rr"]) + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + UTT1,2.5,sil hh ee ll ll oo sil,string, + UTT2,,sil ww oo rr ll dd sil,string, + UTT3,,ff oo oo,string, + UTT4,,bb aa rr,string, + """ + + def __init__(self, outstream, data_fields, defaults={}): + self._outstream = outstream + self.fields = ["ID", "duration"] + self._expand_data_fields(data_fields) + self.defaults = defaults + self._outstream.write(",".join(self.fields)) + + def set_default(self, field, value): + """Sets a default value for the given CSV field. + + Arguments + --------- + field : str + A field in the CSV. + value + The default value. + """ + if field not in self.fields: + raise ValueError(f"{field} is not a field in this CSV!") + self.defaults[field] = value + + def write(self, *args, **kwargs): + """Writes one data line into the CSV. + + Arguments + --------- + *args + Supply every field with a value in positional form OR. + **kwargs + Supply certain fields by key. The ID field is mandatory for all + lines, but others can be left empty. + """ + if args and kwargs: + raise ValueError( + "Use either positional fields or named fields, but not both." + ) + if args: + if len(args) != len(self.fields): + raise ValueError("Need consistent fields") + to_write = [str(arg) for arg in args] + if kwargs: + if "ID" not in kwargs: + raise ValueError("I'll need to see some ID") + full_vals = self.defaults.copy() + full_vals.update(kwargs) + to_write = [str(full_vals.get(field, "")) for field in self.fields] + self._outstream.write("\n") + self._outstream.write(",".join(to_write)) + + def write_batch(self, *args, **kwargs): + """Writes a batch of lines into the CSV. + + Here each argument should be a list with the same length. + + Arguments + --------- + *args + Supply every field with a value in positional form OR. + **kwargs + Supply certain fields by key. The ID field is mandatory for all + lines, but others can be left empty. + """ + if args and kwargs: + raise ValueError( + "Use either positional fields or named fields, but not both." + ) + if args: + if len(args) != len(self.fields): + raise ValueError("Need consistent fields") + for arg_row in zip(*args): + self.write(*arg_row) + if kwargs: + if "ID" not in kwargs: + raise ValueError("I'll need to see some ID") + keys = kwargs.keys() + for value_row in zip(*kwargs.values()): + kwarg_row = dict(zip(keys, value_row)) + self.write(**kwarg_row) + + @staticmethod + def _expand_data_fields(data_fields): + expanded = [] + for data_field in data_fields: + expanded.append(data_field) + expanded.append(data_field + "_format") + expanded.append(data_field + "_opts") + return expanded + + +def write_txt_file(data, filename, sampling_rate=None): + """Write data in text format. 
+ + Arguments + --------- + data : str, list, torch.tensor, numpy.ndarray + The data to write in the text file. + filename : str + Path to file where to write the data. + sampling_rate : None + Not used, just here for interface compatibility. + + Returns + ------- + None + + Example + ------- + >>> tmpdir = getfixture('tmpdir') + >>> signal=torch.tensor([1,2,3,4]) + >>> write_txt_file(signal, tmpdir / 'example.txt') + """ + del sampling_rate # Not used. + # Check if the path of filename exists + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "w") as fout: + if isinstance(data, torch.Tensor): + data = data.tolist() + if isinstance(data, np.ndarray): + data = data.tolist() + if isinstance(data, list): + for line in data: + print(line, file=fout) + if isinstance(data, str): + print(data, file=fout) + + +def write_stdout(data, filename=None, sampling_rate=None): + """Write data to standard output. + + Arguments + --------- + data : str, list, torch.tensor, numpy.ndarray + The data to write in the text file. + filename : None + Not used, just here for compatibility. + sampling_rate : None + Not used, just here for compatibility. + + Returns + ------- + None + + Example + ------- + >>> tmpdir = getfixture('tmpdir') + >>> signal = torch.tensor([[1,2,3,4]]) + >>> write_stdout(signal, tmpdir / 'example.txt') + [1, 2, 3, 4] + """ + # Managing Torch.Tensor + if isinstance(data, torch.Tensor): + data = data.tolist() + # Managing np.ndarray + if isinstance(data, np.ndarray): + data = data.tolist() + if isinstance(data, list): + for line in data: + print(line) + if isinstance(data, str): + print(data) + + +def length_to_mask(length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + + Example + ------- + >>> length=torch.Tensor([1,2,3]) + >>> mask=length_to_mask(length) + >>> mask + tensor([[1., 0., 0.], + [1., 1., 0.], + [1., 1., 1.]]) + """ + assert len(length.shape) == 1 + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = paddle.arange( + max_len, dtype=length.dtype + ).expand([len(length), max_len]) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + if device is None: + device = length.device + + mask = paddle.to_tensor(mask, dtype=dtype) + return mask + + +def read_kaldi_lab(kaldi_ali, kaldi_lab_opts): + """Read labels in kaldi format. + + Uses kaldi IO. + + Arguments + --------- + kaldi_ali : str + Path to directory where kaldi alignments are stored. + kaldi_lab_opts : str + A string that contains the options for reading the kaldi alignments. + + Returns + ------- + lab : dict + A dictionary containing the labels. + + Note + ---- + This depends on kaldi-io-for-python. Install it separately. + See: https://github.com/vesis84/kaldi-io-for-python + + Example + ------- + This example requires kaldi files. 
+ ``` + lab_folder = '/home/kaldi/egs/TIMIT/s5/exp/dnn4_pretrain-dbn_dnn_ali' + read_kaldi_lab(lab_folder, 'ali-to-pdf') + ``` + """ + # EXTRA TOOLS + try: + import kaldi_io + except ImportError: + raise ImportError("Could not import kaldi_io. Install it to use this.") + # Reading the Kaldi labels + lab = { + k: v + for k, v in kaldi_io.read_vec_int_ark( + "gunzip -c " + + kaldi_ali + + "/ali*.gz | " + + kaldi_lab_opts + + " " + + kaldi_ali + + "/final.mdl ark:- ark:-|" + ) + } + return lab + + +def get_md5(file): + """Get the md5 checksum of an input file. + + Arguments + --------- + file : str + Path to file for which compute the checksum. + + Returns + ------- + md5 + Checksum for the given filepath. + + Example + ------- + >>> get_md5('tests/samples/single-mic/example1.wav') + 'c482d0081ca35302d30d12f1136c34e5' + """ + # Lets read stuff in 64kb chunks! + BUF_SIZE = 65536 + md5 = hashlib.md5() + # Computing md5 + with open(file, "rb") as f: + while True: + data = f.read(BUF_SIZE) + if not data: + break + md5.update(data) + return md5.hexdigest() + + +def save_md5(files, out_file): + """Saves the md5 of a list of input files as a pickled dict into a file. + + Arguments + --------- + files : list + List of input files from which we will compute the md5. + outfile : str + The path where to store the output pkl file. + + Returns + ------- + None + + Example: + >>> files = ['tests/samples/single-mic/example1.wav'] + >>> tmpdir = getfixture('tmpdir') + >>> save_md5(files, tmpdir / "md5.pkl") + """ + # Initialization of the dictionary + md5_dict = {} + # Computing md5 for all the files in the list + for file in files: + md5_dict[file] = get_md5(file) + # Saving dictionary in pkl format + save_pkl(md5_dict, out_file) + + +def save_pkl(obj, file): + """Save an object in pkl format. + + Arguments + --------- + obj : object + Object to save in pkl format + file : str + Path to the output file + sampling_rate : int + Sampling rate of the audio file, TODO: this is not used? + + Example + ------- + >>> tmpfile = getfixture('tmpdir') / "example.pkl" + >>> save_pkl([1, 2, 3, 4, 5], tmpfile) + >>> load_pkl(tmpfile) + [1, 2, 3, 4, 5] + """ + with open(file, "wb") as f: + pickle.dump(obj, f) + + +def load_pkl(file): + """Loads a pkl file. + + For an example, see `save_pkl`. + + Arguments + --------- + file : str + Path to the input pkl file. + + Returns + ------- + The loaded object. + """ + + # Deals with the situation where two processes are trying + # to access the same label dictionary by creating a lock + count = 100 + while count > 0: + if os.path.isfile(file + ".lock"): + time.sleep(1) + count -= 1 + else: + break + + try: + open(file + ".lock", "w").close() + with open(file, "rb") as f: + return pickle.load(f) + finally: + if os.path.isfile(file + ".lock"): + os.remove(file + ".lock") + + +def prepend_bos_token(label, bos_index): + """Create labels with token at the beginning. + + Arguments + --------- + label : torch.IntTensor + Containing the original labels. Must be of size: [batch_size, max_length]. + bos_index : int + The index for token. + + Returns + ------- + new_label : tensor + The new label with at the beginning. 
+ + Example + ------- + >>> label=torch.LongTensor([[1,0,0], [2,3,0], [4,5,6]]) + >>> new_label=prepend_bos_token(label, bos_index=7) + >>> new_label + tensor([[7, 1, 0, 0], + [7, 2, 3, 0], + [7, 4, 5, 6]]) + """ + new_label = label.long().clone() + batch_size = label.shape[0] + + bos = new_label.new_zeros(batch_size, 1).fill_(bos_index) + new_label = torch.cat([bos, new_label], dim=1) + return new_label + + +def append_eos_token(label, length, eos_index): + """Create labels with token appended. + + Arguments + --------- + label : torch.IntTensor + Containing the original labels. Must be of size: [batch_size, max_length] + length : torch.LongTensor + Containing the original length of each label sequences. Must be 1D. + eos_index : int + The index for token. + + Returns + ------- + new_label : tensor + The new label with appended. + + Example + ------- + >>> label=torch.IntTensor([[1,0,0], [2,3,0], [4,5,6]]) + >>> length=torch.LongTensor([1,2,3]) + >>> new_label=append_eos_token(label, length, eos_index=7) + >>> new_label + tensor([[1, 7, 0, 0], + [2, 3, 7, 0], + [4, 5, 6, 7]], dtype=torch.int32) + """ + new_label = paddle.to_tensor(label, dtype="int32").clone() + batch_size = label.shape[0] + + pad = paddle.zeros([batch_size, 1], dtype=new_label.dtype) + + new_label = paddle.concat([new_label, pad], dim=1) + new_label[paddle.arange(batch_size), paddle.to_tensor(length, dtype="int64")] = eos_index + return new_label + + +def merge_char(sequences, space="_"): + """Merge characters sequences into word sequences. + + Arguments + --------- + sequences : list + Each item contains a list, and this list contains a character sequence. + space : string + The token represents space. Default: _ + + Returns + ------- + The list contains word sequences for each sentence. + + Example + ------- + >>> sequences = [["a", "b", "_", "c", "_", "d", "e"], ["e", "f", "g", "_", "h", "i"]] + >>> results = merge_char(sequences) + >>> results + [['ab', 'c', 'de'], ['efg', 'hi']] + """ + results = [] + for seq in sequences: + words = "".join(seq).split(space) + results.append(words) + return results + + +def merge_csvs(data_folder, csv_lst, merged_csv): + """Merging several csv files into one file. + + Arguments + --------- + data_folder : string + The folder to store csv files to be merged and after merging. + csv_lst : list + Filenames of csv file to be merged. + merged_csv : string + The filename to write the merged csv file. + + Example + ------- + >>> tmpdir = getfixture('tmpdir') + >>> os.symlink(os.path.realpath("tests/samples/annotation/speech.csv"), tmpdir / "speech.csv") + >>> merge_csvs(tmpdir, + ... ["speech.csv", "speech.csv"], + ... "test_csv_merge.csv") + """ + write_path = os.path.join(data_folder, merged_csv) + if os.path.isfile(write_path): + logger.info("Skipping merging. Completed in previous run.") + with open(os.path.join(data_folder, csv_lst[0])) as f: + header = f.readline() + lines = [] + for csv_file in csv_lst: + with open(os.path.join(data_folder, csv_file)) as f: + for i, line in enumerate(f): + if i == 0: + # Checking header + if line != header: + raise ValueError( + "Different header for " f"{csv_lst[0]} and {csv}." + ) + continue + lines.append(line) + with open(write_path, "w") as f: + f.write(header) + for line in lines: + f.write(line) + logger.info(f"{write_path} is created.") + + +def split_word(sequences, space="_"): + """Split word sequences into character sequences. 
+ + Arguments + --------- + sequences : list + Each item contains a list, and this list contains a words sequence. + space : string + The token represents space. Default: _ + + Returns + ------- + The list contains word sequences for each sentence. + + Example + ------- + >>> sequences = [['ab', 'c', 'de'], ['efg', 'hi']] + >>> results = split_word(sequences) + >>> results + [['a', 'b', '_', 'c', '_', 'd', 'e'], ['e', 'f', 'g', '_', 'h', 'i']] + """ + results = [] + for seq in sequences: + chars = list(space.join(seq)) + results.append(chars) + return results diff --git a/paddlespeech/s2t/io/wav2vec2/dataloader.py b/paddlespeech/s2t/io/wav2vec2/dataloader.py new file mode 100755 index 000000000..0a6549cbf --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/dataloader.py @@ -0,0 +1,215 @@ +"""PyTorch compatible DataLoaders + +Essentially we extend PyTorch DataLoader by adding the ability to save the +data loading state, so that a checkpoint may be saved in the middle of an +epoch. + +Example +------- +>>> import torch +>>> from speechbrain.utils.checkpoints import Checkpointer +>>> # An example "dataset" and its loader +>>> dataset = torch.randn(10, 1) +>>> dataloader = SaveableDataLoader(dataset, num_workers = 3) +>>> # Setup the checkpointer: +>>> tmpdir = getfixture('tmpdir') +>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader}) +>>> # Iterate: +>>> for i, data_point in enumerate(dataloader): +... # Here you would process the data: +... rainfall_amount_prediction = data_point * 4. +... # Now, imagine the experiment gets killed on the fifth batch: +... if i == 4: +... break +... # Luckily, you had just saved a checkpoint: +... if i == 3: +... _ = checkpointer.save_checkpoint(end_of_epoch = False) +>>> # So when you restart the experiment: +>>> new_dataloader = SaveableDataLoader(dataset, num_workers = 3) +>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader}) +>>> _ = new_checkpointer.recover_if_possible() +>>> # The dataloader fast-forwards to the position where we left off: +>>> assert next(iter(new_dataloader)) == dataset[4] + +Authors: + * Aku Rouhe 2020 +""" +import collections +import torch +from paddlespeech.s2t.io.wav2vec2.data_utils import mod_default_collate +# from speechbrain.utils.data_utils import recursive_to +from paddlespeech.s2t.io.wav2vec2.data_utils import batch_pad_right +from paddle.io import DataLoader +import logging +import warnings +import functools +# from batch import PaddedBatch +from paddlespeech.s2t.io.wav2vec2.dataset import DynamicItemDataset +from paddlespeech.s2t.io.wav2vec2.sampler import ReproducibleRandomSampler +import paddle +PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"]) +import numpy + +class Wav2vec2DataLoader(DataLoader): + def __init__(self, + dataset, + batch_size=1, + shuffle=False, + sampler=None, + batch_sampler=None, + num_workers=0, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + multiprocessing_context=None, + generator=None): + if isinstance(dataset[0], (tuple, list)): + return_list = True + else: + return_list = False + + super().__init__( + dataset, + feed_list=None, + places=None, + return_list=return_list, + batch_sampler=batch_sampler, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn, + num_workers=num_workers, + use_buffer_reader=True, + use_shared_memory=False, + timeout=timeout, + worker_init_fn=worker_init_fn) + if sampler is not None: + self.batch_sampler.sampler = sampler + # 
self.dataloader = DataLoader( + # dataset=dataset, + # batch_sampler=batch_sampler, + # collate_fn=collate_fn, + # num_workers=num_workers,) + + # def __len__(self): + # return len(self.dataloader) + + # def __iter__(self): + # return self.dataloader.__iter__() + + # def __call__(self): + # return self.__iter__() + + +def PaddedBatch( + examples, + padded_keys=None, + device_prep_keys=None, + padding_func=batch_pad_right, + padding_kwargs={}, + nonpadded_stack=True, + ): + __length = len(examples) + __keys = list(examples[0].keys()) + __padded_keys = [] + __device_prep_keys = [] + res = {} + for key in __keys: + values = [example[key] for example in examples] + # Default convert usually does the right thing (numpy2torch etc.) + # values = default_convert(values) + if (padded_keys is not None and key in padded_keys) or ( + padded_keys is None and isinstance(values[0], numpy.ndarray) + ): + # Padding and PaddedData + __padded_keys.append(key) + + padded = PaddedData(*padding_func(values, **padding_kwargs)) + res[key] = padded + else: + # Default PyTorch collate usually does the right thing + # (convert lists of equal sized tensors to batch tensors, etc.) + if nonpadded_stack: + values = mod_default_collate(values) + res[key] = values + if (device_prep_keys is not None and key in device_prep_keys) or ( + device_prep_keys is None and isinstance(values[0], paddle.Tensor) + ): + __device_prep_keys.append(key) + return res + +def make_dataloader(dataset, stage, **loader_kwargs): + """Makes a basic DataLoader with SpeechBrain defaults. + + For DynamicItemDatasets (which return dicts), use + PaddedBatch as the default collate_fn. + + Shuffling gets implemented by ReproducibleRandomSampler. + + If the Dataset is not an IterableDataset, the DataLoader + is a SaveableDataLoader. + + If the Dataset is a webdataset.dataset.Composable, set default + batch_size = None. + + Can also loop over the underlying dataloader continuously, + and stop iterations at nominal epoch lengths. + + Arguments + --------- + dataset : Dataset + The dataset to make a DataLoader for. + looped_nominal_epoch : None, int + If an integer is given, loop the underlying DataLoader infinitely and + set a nominal epoch length in batches (or whatever the DataLoader + yields). + **loader_kwargs : dict + Keyword args to DataLoader, see PyTorch DataLoader for + options. + + Returns + ------- + DataLoader + If looped_nominal_epoch is None + LoopedLoader + If looped_nominal_epoch is not None + """ + # PaddedBatch as default collation for DynamicItemDataset + if "collate_fn" not in loader_kwargs and isinstance( + dataset, DynamicItemDataset + ): + loader_kwargs["collate_fn"] = PaddedBatch + # Reproducible random sampling + if loader_kwargs.get("shuffle", False): + if loader_kwargs.get("sampler") is not None: + raise ValueError( + "Cannot specify both shuffle=True and a " + "sampler in loader_kwargs" + ) + sampler = ReproducibleRandomSampler(dataset) + loader_kwargs["sampler"] = sampler + # Should delete shuffle because you can't set both Sampler and + # shuffle + # NOTE: the dict of loader options may get used elsewhere! + # However, this del doesn't touch those because loader_kwargs comes + # from a **kwargs dict. 
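+    # e.g. make_dataloader(train_data, "train", batch_size=8, shuffle=True)
+    # ends up with sampler=ReproducibleRandomSampler(...) and no "shuffle" key.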
+ del loader_kwargs["shuffle"] + # Create the loader + dataloader = Wav2vec2DataLoader(dataset, **loader_kwargs) + return dataloader + + +# import collections +# import torch +# from data_utils import mod_default_collate +# # from speechbrain.utils.data_utils import recursive_to +# from data_utils import batch_pad_right +# from torch.utils.data._utils.collate import default_convert +# # from torch.utils.data._utils.pin_memory import ( +# # pin_memory as recursive_pin_memory, +# # ) +# import paddle + +# PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"]) diff --git a/paddlespeech/s2t/io/wav2vec2/dataset.py b/paddlespeech/s2t/io/wav2vec2/dataset.py new file mode 100755 index 000000000..3ed336719 --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/dataset.py @@ -0,0 +1,409 @@ +import copy +import contextlib +from types import MethodType +from paddle.io import Dataset +from paddlespeech.s2t.io.wav2vec2.data_pipeline import DataPipeline +from paddlespeech.s2t.io.wav2vec2.dataio import load_data_json, load_data_csv +import logging + +logger = logging.getLogger(__name__) + + +class DynamicItemDataset(Dataset): + """Dataset that reads, wrangles, and produces dicts. + + Each data point dict provides some items (by key), for example, a path to a + wavefile with the key "wav_file". When a data point is fetched from this + Dataset, more items are produced dynamically, based on pre-existing items + and other dynamic created items. For example, a dynamic item could take the + wavfile path and load the audio from the disk. + + The dynamic items can depend on other dynamic items: a suitable evaluation + order is used automatically, as long as there are no circular dependencies. + + A specified list of keys is collected in the output dict. These can be items + in the original data or dynamic items. If some dynamic items are not + requested, nor depended on by other requested items, they won't be computed. + So for example if a user simply wants to iterate over the text, the + time-consuming audio loading can be skipped. + + About the format: + Takes a dict of dicts as the collection of data points to read/wrangle. + The top level keys are data point IDs. + Each data point (example) dict should have the same keys, corresponding to + different items in that data point. + + Altogether the data collection could look like this: + + >>> data = { + ... "spk1utt1": { + ... "wav_file": "/path/to/spk1utt1.wav", + ... "text": "hello world", + ... "speaker": "spk1", + ... }, + ... "spk1utt2": { + ... "wav_file": "/path/to/spk1utt2.wav", + ... "text": "how are you world", + ... "speaker": "spk1", + ... } + ... } + + NOTE + ---- + The top-level key, the data point id, is implicitly added as an item + in the data point, with the key "id" + + Each dynamic item is configured by three things: a key, a func, and a list + of argkeys. The key should be unique among all the items (dynamic or not) in + each data point. The func is any callable, and it returns the dynamic item's + value. The callable is called with the values of other items as specified + by the argkeys list (as positional args, passed in the order specified by + argkeys). + + The dynamic_items configuration could look like this: + + >>> import torch + >>> dynamic_items = [ + ... {"func": lambda l: torch.Tensor(l), + ... "takes": ["wav_loaded"], + ... "provides": "wav"}, + ... {"func": lambda path: [ord(c)/100 for c in path], # Fake "loading" + ... "takes": ["wav_file"], + ... "provides": "wav_loaded"}, + ... {"func": lambda t: t.split(), + ... 
"takes": ["text"], + ... "provides": "words"}] + + With these, different views of the data can be loaded: + + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> from speechbrain.dataio.batch import PaddedBatch + >>> dataset = DynamicItemDataset(data, dynamic_items) + >>> dataloader = SaveableDataLoader(dataset, collate_fn=PaddedBatch, + ... batch_size=2) + >>> # First, create encoding for words: + >>> dataset.set_output_keys(["words"]) + >>> encoding = {} + >>> next_id = 1 + >>> for batch in dataloader: + ... for sent in batch.words: + ... for word in sent: + ... if word not in encoding: + ... encoding[word] = next_id + ... next_id += 1 + >>> # Next, add an encoded words_tensor dynamic item: + >>> dataset.add_dynamic_item( + ... func = lambda ws: torch.tensor([encoding[w] for w in ws], + ... dtype=torch.long), + ... takes = ["words"], + ... provides = "words_encoded") + >>> # Now we can get word and audio tensors: + >>> dataset.set_output_keys(["id", "wav", "words_encoded"]) + >>> batch = next(iter(dataloader)) + >>> batch.id + ['spk1utt1', 'spk1utt2'] + >>> batch.wav # +ELLIPSIS + PaddedData(data=tensor([[0.4700, 1.1200, ... + >>> batch.words_encoded + PaddedData(data=tensor([[1, 2, 0, 0], + [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000])) + + Output keys can also be a map: + + >>> dataset.set_output_keys({"id":"id", "signal": "wav", "words": "words_encoded"}) + >>> batch = next(iter(dataloader)) + >>> batch.words + PaddedData(data=tensor([[1, 2, 0, 0], + [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000])) + + + Arguments + --------- + data : dict + Dictionary containing single data points (e.g. utterances). + dynamic_items : list, optional + Configuration for the dynamic items produced when fetching an example. + List of DynamicItems or dicts with the format:: + func: # To be called + takes: # key or list of keys of args this takes + provides: key # key or list of keys that this provides + output_keys : dict, list, optional + List of keys (either directly available in data or dynamic items) + to include in the output dict when data points are fetched. + + If a dict is given; it is used to map internal keys to output keys. + From the output_keys dict key:value pairs the key appears outside, + and value is the internal key. + """ + + def __init__( + self, data, dynamic_items=[], output_keys=[], + ): + self.data = data + self.data_ids = list(self.data.keys()) + static_keys = list(self.data[self.data_ids[0]].keys()) + if "id" in static_keys: + raise ValueError("The key 'id' is reserved for the data point id.") + else: + static_keys.append("id") + self.pipeline = DataPipeline(static_keys, dynamic_items) + self.set_output_keys(output_keys) + + def __len__(self): + return len(self.data_ids) + + def __getitem__(self, index): + data_id = self.data_ids[index] + data_point = self.data[data_id] + return self.pipeline.compute_outputs({"id": data_id, **data_point}) + + def add_dynamic_item(self, func, takes=None, provides=None): + """Makes a new dynamic item available on the dataset. + + Two calling conventions. For DynamicItem objects, just use: + add_dynamic_item(dynamic_item). + But otherwise, should use: + add_dynamic_item(func, takes, provides). + + See `speechbrain.utils.data_pipeline`. + + Arguments + --------- + func : callable, DynamicItem + If a DynamicItem is given, adds that directly. Otherwise a + DynamicItem is created, and this specifies the callable to use. If + a generator function is given, then create a GeneratorDynamicItem. 
+ Otherwise creates a normal DynamicItem. + takes : list, str + List of keys. When func is called, each key is resolved to + either an entry in the data or the output of another dynamic_item. + The func is then called with these as positional arguments, + in the same order as specified here. + A single arg can be given directly. + provides : str + Unique key or keys that this provides. + """ + self.pipeline.add_dynamic_item(func, takes, provides) + + def set_output_keys(self, keys): + """Use this to change the output keys. + + These are the keys that are actually evaluated when a data point + is fetched from the dataset. + + Arguments + --------- + keys : dict, list + List of keys (str) to produce in output. + + If a dict is given; it is used to map internal keys to output keys. + From the output_keys dict key:value pairs the key appears outside, + and value is the internal key. + """ + self.pipeline.set_output_keys(keys) + + @contextlib.contextmanager + def output_keys_as(self, keys): + """Context manager to temporarily set output keys. + + Example + ------- + >>> dataset = DynamicItemDataset({"a":{"x":1,"y":2},"b":{"x":3,"y":4}}, + ... output_keys = ["x"]) + >>> with dataset.output_keys_as(["y"]): + ... print(dataset[0]) + {'y': 2} + >>> print(dataset[0]) + {'x': 1} + + NOTE + ---- + Not thread-safe. While in this context manager, the output keys + are affected for any call. + """ + saved_output = self.pipeline.output_mapping + self.pipeline.set_output_keys(keys) + yield self + self.pipeline.set_output_keys(saved_output) + + def filtered_sorted( + self, + key_min_value={}, + key_max_value={}, + key_test={}, + sort_key=None, + reverse=False, + select_n=None, + ): + """Get a filtered and/or sorted version of this, shares static data. + + The reason to implement these operations in the same method is that + computing some dynamic items may be expensive, and this way the + filtering and sorting steps don't need to compute the dynamic items + twice. + + Arguments + --------- + key_min_value : dict + Map from key (in data or in dynamic items) to limit, will only keep + data_point if data_point[key] >= limit + key_max_value : dict + Map from key (in data or in dynamic items) to limit, will only keep + data_point if data_point[key] <= limit + key_test : dict + Map from key (in data or in dynamic items) to func, will only keep + data_point if bool(func(data_point[key])) == True + sort_key : None, str + If not None, sort by data_point[sort_key]. Default is ascending + order. + reverse : bool + If True, sort in descending order. + select_n : None, int + If not None, only keep (at most) the first n filtered data_points. + The possible sorting is applied, but only on the first n data + points found. Meant for debugging. + + Returns + ------- + FilteredSortedDynamicItemDataset + Shares the static data, but has its own output keys and + dynamic items (initially deep copied from this, so they have the + same dynamic items available) + + NOTE + ---- + Temporarily changes the output keys! 
+ """ + filtered_sorted_ids = self._filtered_sorted_ids( + key_min_value, key_max_value, key_test, sort_key, reverse, select_n, + ) + return FilteredSortedDynamicItemDataset( + self, filtered_sorted_ids + ) # NOTE: defined below + + def _filtered_sorted_ids( + self, + key_min_value={}, + key_max_value={}, + key_test={}, + sort_key=None, + reverse=False, + select_n=None, + ): + """Returns a list of data ids, fulfilling the sorting and filtering.""" + + def combined_filter(computed): + """Applies filter.""" + for key, limit in key_min_value.items(): + # NOTE: docstring promises >= so using that. + # Mathematically could also use < for nicer syntax, but + # maybe with some super special weird edge case some one can + # depend on the >= operator + if computed[key] >= limit: + continue + return False + for key, limit in key_max_value.items(): + if computed[key] <= limit: + continue + return False + for key, func in key_test.items(): + if bool(func(computed[key])): + continue + return False + return True + + temp_keys = ( + set(key_min_value.keys()) + | set(key_max_value.keys()) + | set(key_test.keys()) + | set([] if sort_key is None else [sort_key]) + ) + filtered_ids = [] + with self.output_keys_as(temp_keys): + for i, data_id in enumerate(self.data_ids): + if select_n is not None and len(filtered_ids) == select_n: + break + data_point = self.data[data_id] + data_point["id"] = data_id + computed = self.pipeline.compute_outputs(data_point) + if combined_filter(computed): + if sort_key is not None: + # Add (main sorting index, current index, data_id) + # So that we maintain current sorting and don't compare + # data_id values ever. + filtered_ids.append((computed[sort_key], i, data_id)) + else: + filtered_ids.append(data_id) + if sort_key is not None: + filtered_sorted_ids = [ + tup[2] for tup in sorted(filtered_ids, reverse=reverse) + ] + else: + filtered_sorted_ids = filtered_ids + return filtered_sorted_ids + + @classmethod + def from_json( + cls, json_path, replacements={}, dynamic_items=[], output_keys=[] + ): + """Load a data prep JSON file and create a Dataset based on it.""" + data = load_data_json(json_path, replacements) + return cls(data, dynamic_items, output_keys) + + @classmethod + def from_csv( + cls, csv_path, replacements={}, dynamic_items=[], output_keys=[] + ): + """Load a data prep CSV file and create a Dataset based on it.""" + data = load_data_csv(csv_path, replacements) + return cls(data, dynamic_items, output_keys) + + @classmethod + def from_arrow_dataset( + cls, dataset, replacements={}, dynamic_items=[], output_keys=[] + ): + """Loading a prepared huggingface dataset""" + # define an unbound method to generate puesdo keys + def keys(self): + "Returns the keys." + return [i for i in range(dataset.__len__())] + + # bind this method to arrow dataset + dataset.keys = MethodType(keys, dataset) + return cls(dataset, dynamic_items, output_keys) + +class FilteredSortedDynamicItemDataset(DynamicItemDataset): + """Possibly filtered, possibly sorted DynamicItemDataset. + + Shares the static data (reference). + Has its own dynamic_items and output_keys (deepcopy). 
+ """ + + def __init__(self, from_dataset, data_ids): + self.data = from_dataset.data + self.data_ids = data_ids + self.pipeline = copy.deepcopy(from_dataset.pipeline) + + @classmethod + def from_json( + cls, json_path, replacements={}, dynamic_items=None, output_keys=None + ): + raise TypeError("Cannot create SubsetDynamicItemDataset directly!") + + @classmethod + def from_csv( + cls, csv_path, replacements={}, dynamic_items=None, output_keys=None + ): + raise TypeError("Cannot create SubsetDynamicItemDataset directly!") + + +def add_dynamic_item(datasets, func, takes=None, provides=None): + """Helper for adding the same item to multiple datasets.""" + for dataset in datasets: + dataset.add_dynamic_item(func, takes, provides) + + +def set_output_keys(datasets, output_keys): + """Helper for setting the same item to multiple datasets.""" + for dataset in datasets: + dataset.set_output_keys(output_keys) diff --git a/paddlespeech/s2t/io/wav2vec2/depgraph.py b/paddlespeech/s2t/io/wav2vec2/depgraph.py new file mode 100755 index 000000000..02bbe7a04 --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/depgraph.py @@ -0,0 +1,276 @@ +"""A dependency graph for finding evaluation order. + +Example +------- +>>> # The basic use case is that you have a bunch of keys +>>> # and some of them depend on each other: +>>> database = [] +>>> functions = {'read': {'func': lambda: (0,1,2), +... 'needs': []}, +... 'process': {'func': lambda X: [x**2 for x in X], +... 'needs': ['read']}, +... 'save': {'func': lambda x: database.append(x), +... 'needs': ['process']}, +... 'print': {'func': lambda x,y: print(x, "became", y), +... 'needs': ['read', 'process']}, +... 'auxiliary': {'func': lambda: (1,2,3), +... 'needs': []}} +>>> # If this is user supplied info, so you can't just hardcode the order, +>>> # a dependency graph may be needed. +>>> dg = DependencyGraph() +>>> # In simple cases, you can just encode the dependencies directly: +>>> for key, conf in functions.items(): +... for needed in conf["needs"]: +... dg.add_edge(key, needed) +>>> # Now we can evaluate: +>>> outputs = {} +>>> for node in dg.get_evaluation_order(): +... f = functions[node.key]['func'] +... args = [outputs[needed] for needed in functions[node.key]['needs']] +... outputs[node.key] = f(*args) +(0, 1, 2) became [0, 1, 4] +>>> # This added nodes implicitly. +>>> # However, since 'auxiliary' didn't depend on anything, +>>> # it didn't get added! +>>> assert 'auxiliary' not in outputs +>>> # So to be careful, we should also manually add nodes for any thing that +>>> # is not an intermediate step. +>>> _ = dg.add_node('auxiliary') +>>> assert 'auxiliary' in (node.key for node in dg.get_evaluation_order()) +>>> # Arbitrary data can be added to nodes: +>>> dg2 = DependencyGraph() +>>> for key, conf in functions.items(): +... _ = dg2.add_node(key, conf) +... for needed in conf["needs"]: +... dg2.add_edge(key, needed) +>>> # Now we get access to the data in evaluation: +>>> outputs2 = {} +>>> for key, _, conf in dg2.get_evaluation_order(): +... f = conf['func'] +... args = [outputs[needed] for needed in conf['needs']] +... outputs[key] = f(*args) +(0, 1, 2) became [0, 1, 4] + +Authors: + * Aku Rouhe 2020 +""" +import collections +import uuid + + +class CircularDependencyError(ValueError): + """ + An error caused by running into circular dependencies while searching for + an evaluation order in a DependencyGraph. + """ + + pass + + +DGNode = collections.namedtuple("DGNode", ["key", "edges", "data"]) +# A node in DependencyGraph. 
+ + +class DependencyGraph: + """General-purpose dependency graph. + + Essentially a directed acyclic graph. + Usually used to find an evaluation order for e.g. variable substitution + The relation that an edge between A and B represents is: + "A depends on B, i.e. B should be evaluated before A" + + Nodes can be added explicitly or they can be created implicitly + while adding edges. + Nodes have keys, which should be some hashable value that identifies + the elements the graph represents in your use case. E.G. they can just + be the variable name you want to substitute. + However, if needed, more generally you can attach any data to a node + (e.g. a path in your tree), and if so desired, a unique key can be + created for you. You'll only need to know that key while adding edges + to/from it. + Implicit keys and explicit keys can also be mixed. + """ + + def __init__(self): + self.digraph = [] + self.key2ind = {} + # Guard for manual duplicates (but not implicitly added ones) + self._manually_added_keys = [] + + @staticmethod + def get_unique_key(): + """Returns a unique hashable identifier.""" + return uuid.uuid4() + + def add_node(self, key=None, data=None): + """Adds a node explicitly. + + Arguments + --------- + key : hashable, optional + If not given, a key is created for you. + data : Any, optional + Any additional data you wish to attach to this node. + + Returns + ------- + hashable + The key that was used (either yours or generated). + + Raises + ------ + ValueError + If node with the given key has already been added explicitly + (with this method, not "add_edge"). + """ + if key is None: + key = self.get_unique_key() + elif key in self._manually_added_keys: + raise ValueError("Adding duplicate node: {key}".format(key=key)) + else: + self._manually_added_keys.append(key) + if key in self.key2ind: # Implicitly added already; don't add again. + ind = self.key2ind[key] + node = self.digraph[ind] + # All that this operation can do is add data: + self.digraph[ind] = DGNode(node.key, node.edges, data) + return key + self.key2ind[key] = len(self.digraph) + self.digraph.append(DGNode(key, [], data)) + return key + + def add_edge(self, from_key, to_key): + """Adds an edge, and implicitly also creates nodes for keys which have + not been seen before. This will not let you add data to your nodes. + The relation encodes: "from_key depends on to_key" + (to_key must be evaluated before from_key). + + Arguments + --------- + from_key : hashable + The key which depends on. + to_key : hashable + The key which is depended on. + + Returns + ------- + None + """ + from_ind = self._get_ind_and_add_if_new(from_key) + to_ind = self._get_ind_and_add_if_new(to_key) + edges_list = self.digraph[from_ind].edges + if to_ind not in edges_list: + edges_list.append(to_ind) + + def _get_ind_and_add_if_new(self, key): + # Used internally to implicitly add nodes for unseen keys + if key not in self.key2ind: + self.key2ind[key] = len(self.digraph) + self.digraph.append(DGNode(key, [], None)) + return self.key2ind[key] + + def is_valid(self): + """Checks if an evaluation order can be found. + + A dependency graph is evaluatable if there are no circular + dependencies, i.e., the graph is acyclic. + + Returns + ------- + bool + Indicating if the graph is evaluatable. + """ + return not self._find_first_cycle() + + def get_evaluation_order(self, selected_keys=None): + """Finds one valid evaluation order. + + There can be many different valid + orders. + NOTE: Generates output one DGNode at a time. 
May generate DGNodes + before it finds a circular dependency. If you really need to know + whether an order can be found, check is_valid() first. However, + the algorithm for finding cycles is essentially the same as the one + used for finding an evaluation order, so for very large graphs... + Ah well, but maybe then you should be using some other solution + anyway. + + Arguments + --------- + selected_keys : list, None + List of keys. If not None, only the selected keys are guaranteed + in the evaluation order (along with the keys they depend on). + + Yields + ------ + DGNode + The added DGNodes in a valid evaluation order. + See the DGNode namedtuple above. + + Raises + ------ + CircularDependencyError + If a circular dependency is found. + """ + seen_ever = set() + + def toposort(root_ind, visited): + """Implementation of topsort.""" + nonlocal seen_ever + here = visited + [root_ind] + if root_ind in visited: + raise CircularDependencyError( + "{cycle}".format( + cycle=" -> ".join( + str(self.digraph[i].key) for i in here + ) + ) + ) + if root_ind in seen_ever: + return # Yield nothing + seen_ever = seen_ever.union(set([root_ind])) + for to_ind in self.digraph[root_ind].edges: + for ind in toposort(to_ind, visited=here): + yield ind + yield root_ind + + if selected_keys is None: + start_inds = range(len(self.digraph)) + else: + start_inds = [self.key2ind[key] for key in selected_keys] + + for start_ind in start_inds: + for ind in toposort(start_ind, []): + yield self.digraph[ind] + + def _find_first_cycle(self): + """Depth-first search based algorithm for finding cycles in the graph.""" + seen_ever = set() + + def cycle_dfs(root_ind, visited): + """Implementation of cycle_dfs.""" + nonlocal seen_ever + print(root_ind, visited) + here = visited + [root_ind] + if root_ind in visited: + return here + if root_ind in seen_ever: + return [] + seen_ever = seen_ever.union(set([root_ind])) + for to_ind in self.digraph[root_ind].edges: + cycle = cycle_dfs(to_ind, here) + if cycle: + return cycle + return [] + + for ind in range(len(self.digraph)): + if ind not in seen_ever: + cycle = cycle_dfs(ind, []) + if cycle: + return cycle + return [] + + def __contains__(self, key): + # Allows the syntax: + # 'key' in dependency_graph + return key in self.key2ind diff --git a/paddlespeech/s2t/io/wav2vec2/make_dataloader.py b/paddlespeech/s2t/io/wav2vec2/make_dataloader.py new file mode 100755 index 000000000..1f3157b23 --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/make_dataloader.py @@ -0,0 +1,115 @@ +import paddlespeech.s2t.io.wav2vec2.dataloader + + +def _train_loader_specifics(self, dataset, loader_kwargs): + sampler = loader_kwargs.get("sampler", None) + # Shuffling should really only matter for the train stage. Shuffling + # will also lead to more padding in batches if the order was otherwise + # sorted by length. 
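+    # Only the train loader is expected to request shuffling; validation and
+    # test loaders normally keep their (typically length-sorted) order.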
+ shuffle = loader_kwargs.get("shuffle", False) + if shuffle and not self.distributed_launch: + if sampler is not None: + raise ValueError( + "Cannot specify both shuffle=True" + "and a sampler in loader_kwargs" + ) + sampler = ReproducibleRandomSampler(dataset) + self.train_sampler = sampler + loader_kwargs["sampler"] = self.train_sampler + # Delete the shuffle flag, since you cannot specify both a sampler and + # shuffling: + del loader_kwargs["shuffle"] + + # Possibly make a DistributedSampler or a wrapper for some other sampler + if self.distributed_launch and not isinstance(dataset, IterableDataset): + drop_last = loader_kwargs.get("drop_last", False) + # num_replicas arg is equal to world_size + # and retrieved automatically within + # DistributedSampler obj. + if sampler is not None: + self.train_sampler = DistributedSamplerWrapper( + sampler, + rank=self.rank, + drop_last=drop_last, + shuffle=shuffle, + ) + + # with DistributedSamplerWrapper, one must disable shuffling for dataloader + loader_kwargs["shuffle"] = False + loader_kwargs["sampler"] = self.train_sampler + elif loader_kwargs.get("batch_sampler") is None: + # no sampler and batch-sampler + self.train_sampler = DistributedSampler( + dataset, rank=self.rank, shuffle=True, drop_last=drop_last + ) + + # with DistributedSamplerWrapper, one must disable shuffling for dataloader + loader_kwargs["shuffle"] = False + loader_kwargs["sampler"] = self.train_sampler + else: # batch_sampler was specified + self.train_sampler = DistributedSamplerWrapper( + loader_kwargs.get("batch_sampler", None), + rank=self.rank, + shuffle=True, + ) + loader_kwargs["batch_sampler"] = self.train_sampler + elif self.distributed_launch and isinstance(dataset, IterableDataset): + logger.warning( + "Cannot automatically solve distributed sampling " + "for IterableDataset." + ) + return loader_kwargs + + +def make_dataloader( + self, dataset, stage, **loader_kwargs + ): + """Creates DataLoaders for Datasets. + + This is used by ``fit()`` and ``evaluate()`` if they just receive + Datasets. + + Alternatively, this can be called from outside the Brain subclass. + In that case, the DataLoader should be passed to ``fit()`` in place + of the dataset. + + The Stage.TRAIN DataLoader is handled specially. It has extra args for + shuffle and drop_last. In DDP a DistributedSampler is created (unless + the dataset is an IterableDataset). + + NOTE + ---- + Some important DataLoader arguments are passed via **loader_kwargs, + e.g., batch_size, num_workers, pin_memory. + + NOTE + ---- + By default, ``evaluate()`` specifies ckpt_prefix=None to stop the test + DataLoader being added to the checkpointer. If you need to add a + recoverable after saving checkpoints (e.g., at test time, after + checkpointing the training), and still be able to recover reasonably, + you should probably specify ``allow_partial_load=True``. + + Arguments + --------- + dataset : Dataset + A set of data to use to create data loader. If the Dataset is a + DynamicItemDataset, PaddedBatch is used as the default collate_fn, + unless specified in loader_kwargs. + stage : Stage + The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST + ckpt_prefix : str, None + Prefix to use for SaveableDataLoader Checkpoint name. The Stage + name is added to this to create the full key. Set to None to not + save the DataLoader. + **loader_kwargs : dict + Additional keyword arguments to the DataLoader. + E.g., batch_size, num_workers, pin_memory. + """ + # TRAIN stage is handled specially. 
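+    # The stage-specific branch below is currently commented out, so every
+    # stage is routed straight to the plain make_dataloader call from
+    # paddlespeech.s2t.io.wav2vec2.dataloader.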
+ # if stage == train: + # loader_kwargs = _train_loader_specifics(dataset, loader_kwargs) + dataloader_ = dataloader.make_dataloader( + dataset, **loader_kwargs + ) + return dataloader_ \ No newline at end of file diff --git a/paddlespeech/s2t/io/wav2vec2/sampler.py b/paddlespeech/s2t/io/wav2vec2/sampler.py new file mode 100755 index 000000000..f4fa38379 --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/sampler.py @@ -0,0 +1,695 @@ +"""PyTorch compatible samplers. + +These determine the order of iteration through a dataset. + +Authors: + * Aku Rouhe 2020 + * Samuele Cornell 2020 + * Ralf Leibold 2020 + * Artem Ploujnikov 2021 + * Andreas Nautsch 2021 +""" +import torch +import logging +from operator import itemgetter +from paddle.io import ( + RandomSampler, + WeightedRandomSampler, + Sampler, +) + +import numpy as np +from typing import List +from paddlespeech.s2t.io.wav2vec2.dataset import DynamicItemDataset +from collections import Counter +from scipy.stats import lognorm + +logger = logging.getLogger(__name__) + +class ReproducibleRandomSampler(RandomSampler): + """A modification of RandomSampler which always returns the same values. + + Also look at `torch.utils.data.RandomSampler`. This has mostly + the same behaviour and arguments, except for adding 'seed' and 'epoch' and + not supporting 'generator'. + + Note + ---- + Call `set_epoch` before every epoch. Otherwise, the sampler will produce the + same sequence of indices every epoch. + + Arguments + --------- + data_source : Dataset + The data source to sample indices for. + seed : int + The base seed to use for the random number generator. It is recommended + to use a value which has a good mix of 0 and 1 bits. + epoch : int + The epoch to start at. + + Example + ------- + >>> import torch + >>> from speechbrain.utils.checkpoints import Checkpointer + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> # An example "dataset" + >>> dataset = torch.arange(10).unsqueeze(1) + >>> # Create the random sampler: + >>> sampler = ReproducibleRandomSampler(dataset) + >>> dataloader = SaveableDataLoader(dataset, sampler = sampler, + ... num_workers = 3) + >>> # Setup the checkpointer. + >>> # Note that the sampler doesn't need to be saved itself. + >>> tmpdir = getfixture('tmpdir') + >>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader}) + >>> # Iterate: + >>> subset = [] + >>> for i, data_point in enumerate(dataloader): + ... # Say you save a checkpoint on the fourth batch: + ... if i == 3: + ... _ = checkpointer.save_checkpoint(end_of_epoch = False) + ... # So let's save the numbers you would get if you continue + ... if i >= 4: + ... subset.append(data_point.item()) + >>> # What if instead you had to restart the experiment? + >>> new_sampler = ReproducibleRandomSampler(dataset) + >>> new_dataloader = SaveableDataLoader(dataset, sampler = new_sampler, + ... 
num_workers = 3) + >>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader}) + >>> _ = new_checkpointer.recover_if_possible() + >>> # You'll get the same random order again: + >>> new_subset = [data_point.item() for data_point in new_dataloader] + >>> assert subset == new_subset + + """ + + def __init__(self, data_source, seed=563375142, epoch=0, **kwargs): + if "generator" in kwargs: + MSG = ( + "Cannot give a separate generator when using " + + "ReproducibleRandomSampler" + ) + raise ValueError(MSG) + super().__init__(data_source, **kwargs) + self.seed = int(seed) + self.epoch = epoch + self.gen = paddle.seed(1) + + def set_epoch(self, epoch): + """ + You can also just access self.epoch, but we maintain this interface + to mirror torch.utils.data.distributed.DistributedSampler + """ + self.epoch = epoch + + def __iter__(self): + self.gen.manual_seed(self.seed + self.epoch) + return super().__iter__() + + +class ReproducibleWeightedRandomSampler(WeightedRandomSampler): + """A reproducible modification of WeightedRandomSampler. + + Also look at `torch.utils.data.WeightedRandomSampler`. This has the + the same behaviour and arguments, except for adding 'seed' and 'epoch' and + not supporting 'generator'. + + Note + ---- + Call `set_epoch` before every epoch. Otherwise, the sampler will produce the + same sequence of indices every epoch. + + Arguments + --------- + weights : sequence of float + Weights for each index. Doesn't need to sum to one. + num_samples : int + Number of samples to draw + replacement : bool + To draw with replacement or not (within an epoch of num_samples). + seed : int + The base seed to use for the random number generator. It is recommended + to use a value which has a good mix of 0 and 1 bits. + epoch : int + The epoch to start at. + + Example + ------- + >>> a = ReproducibleWeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True) + >>> b = ReproducibleWeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True) + >>> list(a) + [3, 1, 4, 4, 4] + >>> list(b) + [3, 1, 4, 4, 4] + >>> a.set_epoch(1) + >>> list(a) + [4, 5, 4, 4, 3] + >>> b.set_epoch(1) + >>> list(b) + [4, 5, 4, 4, 3] + + + """ + + def __init__( + self, + weights, + num_samples, + replacement, + seed=129491412, + epoch=0, + **kwargs, + ): + if "generator" in kwargs: + MSG = ( + "Cannot give a separate generator when using " + + "ReproducibleRandomSampler" + ) + raise ValueError(MSG) + super().__init__(weights, num_samples, replacement, **kwargs) + self.seed = int(seed) + self.epoch = epoch + self.gen = paddle.seed(1) + + def set_epoch(self, epoch): + """ + You can also just access self.epoch, but we maintain this interface + to mirror torch.utils.data.distributed.DistributedSampler + """ + self.epoch = epoch + + def __iter__(self): + self.gen.manual_seed(self.seed + self.epoch) + return super().__iter__() + +class DynamicBatchSampler(Sampler): + """This BatchSampler batches examples together by grouping them by their length. + + Every example in the batch have approximately the same length and + thus padding is minimized. + This enables faster training on datasets + where length of examples can vary significantly (e.g Librispeech). 
+ Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length + + Dynamic batching is performed by specifying a max_batch_length which is the + upper limit for the sum of the length of examples in a batch: + e.g., if ex1 has length 4, ex2 length 5 and if max_batch_length is set to 6 + ex1 and ex2 will be placed, alone, in two distinct batches. + + Length for each example can be obtained in two manners. + If the input dataset is a DynamicItemDataset it can be obtained by specifying a + length_func. Default assumes a "duration" entry is in the annotation. + Length for each example can also be passed to this class upon instantiation + by specifying a list containing the length for each example and passing it to + lengths_list. + + Examples are grouped together by defining a set of possible discrete intervals + (buckets). Examples whose length fall into these intervals can be batched together. + + The number of buckets can be specified by using the arg num_buckets. + There is usually an optimal range for the value of this argument. + + If num_buckets == 1, all examples can be batched together. You have maximum randomization + but your training speed will be slower due to the fact that a large amount of the values will be padding + as long and short examples can be batched together. + As the number of buckets grows only examples with similar + length can be grouped together. + This trades-off speed with randomization. + TLDR: Low number -> better randomization, High number -> faster training. + NOTE THAT: if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size + will be small impacting training speed and possibly performance. + + The buckets can also be specified by passing a list to the bucket_boundaries + argument instead of specifying a left_bucket_length and a bucket_length_multiplier. + + Example + ------- + >>> import torch + >>> import speechbrain as sb + >>> from speechbrain.dataio.sampler import DynamicBatchSampler + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> from speechbrain.dataio.batch import PaddedBatch + >>> import numpy as np + >>> item_lengths = sorted([np.random.randint(10, 100) for x in range(20)]) + >>> dataset = {"ex_{}".format(x) : {"wav" :torch.randn(x)} for x in item_lengths} + >>> dataset = DynamicItemDataset(dataset) + >>> dataset.set_output_keys(["wav"]) + >>> length_func = lambda x : len(x) # trivial in this example + >>> bsampler = DynamicBatchSampler(dataset, 20, 4, length_func, shuffle=False, batch_ordering='descending') + >>> dataloader = SaveableDataLoader(dataset, batch_sampler=bsampler, collate_fn=PaddedBatch) + >>> for i, b in enumerate(dataloader): + ... data, length = b["wav"] + >>> assert data.shape[-1] == max(item_lengths) + + Arguments + --------- + dataset : torch.utils.data.Dataset + Pytorch Dataset from which elements will be sampled. + max_batch_length : int + Upper limit for the sum of the length of examples in a batch. + Should be chosen based on your GPU memory. + num_buckets : int + Number of discrete buckets used to group examples together. + If num_buckets == 1, all examples can be batched together. As the number of buckets grows only examples with similar + length can be grouped together. This trades-off speed with randomization. + Low number -> better randomization, High number -> faster training. 
+ However if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size + will be small impacting training speed and possibly performance. + NOTE: you have either to specify manually the bucket_boundaries or the number of buckets. + length_func : callable + Function used to get length of each example from the dataset. + This argument can be used only when the dataset is a Speechbrain DynamicItemDataset object. + Can be anything: e.g. lambda x: x["duration"]*16000 returns number of samples + if duration key in the annotation is in seconds and the file has 16kHz sampling freq. + shuffle : bool + Whether or not shuffle examples between each epoch. + batch_ordering : string + If ``random``, batches are randomly permuted; otherwise ``ascending`` or ``descending`` sorted by length. + max_batch_ex: int + If set, it limits the maximum number of examples that can be in a batch superseeding max_batch_length + in instances where the amount of examples will exceeed the value specified here. + E.g. you have a lot of short examples and the batch size for those will be too high, you can use this argument + to limit the batch size for these short examples. + bucket_boundaries : list + Overrides bucket_length_multiplier and left_bucket_length by specifying manually + the buckets right boundaries. + lengths_list: list + Overrides length_func by passing a list containing the length of each example + in the dataset. This argument must be set when the dataset is a plain + Pytorch Dataset object and not a DynamicItemDataset object as length_func + cannot be used on Pytorch Datasets. + epoch : int + The epoch to start at. + drop_last : bool + If ``True``, the sampler will drop the last examples which + have not been grouped. + verbose: bool + If ``True``, log also the stats for each batch at the first epoch. + """ + + def __init__( + self, + dataset, + max_batch_length: int, + num_buckets: int = None, + length_func=lambda x: x["duration"], + shuffle: bool = True, + batch_ordering: str = "random", + max_batch_ex: int = None, + bucket_boundaries: List[int] = [], + lengths_list: List[int] = None, + seed: int = 42, + epoch: int = 0, + drop_last: bool = False, + verbose: bool = False, + ): + self._dataset = dataset + self._ex_lengths = {} + ex_ids = self._dataset.data_ids + self.verbose = verbose + + # We do not put a default on num_buckets to encourage users to play with this parameter + if num_buckets is None and len(bucket_boundaries) == 0: + raise RuntimeError( + "Please specify either num_buckets or bucket boundaries." + "Check the docs, and/or the tutorial !" + ) + + if lengths_list is not None: + # take length of examples from this argument and bypass length_key + for indx in range(len(lengths_list)): + self._ex_lengths[str(indx)] = lengths_list[indx] + else: + # use length func + if not isinstance(dataset, DynamicItemDataset): + raise NotImplementedError( + "Dataset should be a Speechbrain DynamicItemDataset when using length function" + ) + for indx in range(len(self._dataset)): + self._ex_lengths[str(indx)] = length_func( + self._dataset.data[ex_ids[indx]] + ) + + if len(bucket_boundaries) > 0: + if not all([x >= 0 for x in bucket_boundaries]): + raise ValueError( + "All elements in bucket boundaries should be non-negative (>= 0)." + ) + if not len(set(bucket_boundaries)) == len(bucket_boundaries): + raise ValueError( + "Bucket_boundaries should not contain duplicates." 
+ ) + np.testing.assert_array_equal( + np.array(bucket_boundaries), + np.array(sorted(bucket_boundaries)), + err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values values!", + ) + self._bucket_boundaries = np.array(sorted(bucket_boundaries)) + else: + # use num_buckets + self._bucket_boundaries = np.array( + self._get_boundaries_through_warping( + max_batch_length=max_batch_length, + num_quantiles=num_buckets, + ) + ) + + self._max_batch_length = max_batch_length + self._shuffle_ex = shuffle + self._batch_ordering = batch_ordering + self._seed = seed + self._drop_last = drop_last + if max_batch_ex is None: + max_batch_ex = np.inf + self._max_batch_ex = max_batch_ex + # Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length? + self._bucket_lens = [ + max(1, int(max_batch_length / self._bucket_boundaries[i])) + for i in range(len(self._bucket_boundaries)) + ] + [1] + self._epoch = epoch + self._generate_batches() + + def get_durations(self, batch): + """Gets durations of the elements in the batch.""" + return [self._ex_lengths[str(idx)] for idx in batch] + + def _get_boundaries_through_warping( + self, max_batch_length: int, num_quantiles: int, + ) -> List[int]: + + # NOTE: the following lines do not cover that there is only one example in the dataset + # warp frames (duration) distribution of train data + logger.info("Batch quantisation in latent space") + # linspace set-up + num_boundaries = num_quantiles + 1 + # create latent linearly equal spaced buckets + latent_boundaries = np.linspace( + 1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles, + ) + # get quantiles using lognormal distribution + quantiles = lognorm.ppf(latent_boundaries, 1) + # scale up to to max_batch_length + bucket_boundaries = quantiles * max_batch_length / quantiles[-1] + # compute resulting bucket length multipliers + length_multipliers = [ + bucket_boundaries[x + 1] / bucket_boundaries[x] + for x in range(num_quantiles - 1) + ] + # logging + logger.info( + "Latent bucket boundary - buckets: {} - length multipliers: {}".format( + list(map("{:.2f}".format, bucket_boundaries)), + list(map("{:.2f}".format, length_multipliers)), + ) + ) + return list(sorted(bucket_boundaries)) + + def _permute_batches(self): + + if self._batch_ordering == "random": + # deterministically shuffle based on epoch and seed + gen = paddle.seed(1) + gen.manual_seed(self._seed + self._epoch) + sampler = torch.randperm( + len(self._batches) + ).tolist() # type: ignore + tmp = [] + for idx in sampler: + tmp.append(self._batches[idx]) + self._batches = tmp + + elif self._batch_ordering == "ascending": + self._batches = sorted( + self._batches, + key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), + ) + elif self._batch_ordering == "descending": + self._batches = sorted( + self._batches, + key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), + reverse=True, + ) + else: + raise NotImplementedError + + def _generate_batches(self): + logger.info("DynamicBatchSampler: Generating dynamic batches") + if self._shuffle_ex: + # deterministically shuffle based on epoch and seed + gen = paddle.seed(1) + gen.manual_seed(self._seed + self._epoch) + sampler = paddle.randperm(len(self._dataset)).tolist() # type: ignore + else: + # take examples as they are: e.g. 
they have been sorted + sampler = range(len(self._dataset)) # type: ignore + + self._batches = [] + bucket_batches = [[] for i in self._bucket_lens] + + stats_tracker = [ + {"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0} + for i in self._bucket_lens + ] + + for idx in sampler: + # length of pre-sampled audio + item_len = self._ex_lengths[str(idx)] + # bucket to fill up most padding + bucket_id = np.searchsorted(self._bucket_boundaries, item_len) + # fill audio's duration into that bucket + bucket_batches[bucket_id].append(idx) + + stats_tracker[bucket_id]["min"] = min( + stats_tracker[bucket_id]["min"], item_len + ) + stats_tracker[bucket_id]["max"] = max( + stats_tracker[bucket_id]["max"], item_len + ) + stats_tracker[bucket_id]["tot"] += item_len + stats_tracker[bucket_id]["n_ex"] += 1 + # track #samples - why not duration/#frames; rounded up? + # keep track of durations, if necessary + + if ( + len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id] + or len(bucket_batches[bucket_id]) >= self._max_batch_ex + ): + self._batches.append(bucket_batches[bucket_id]) + bucket_batches[bucket_id] = [] + # keep track of durations + + # Dump remaining batches + if not self._drop_last: + for batch in bucket_batches: + if batch: + self._batches.append(batch) + + self._permute_batches() # possibly reorder batches + + if self._epoch == 0: # only log at first epoch + # frames per batch & their padding remaining + boundaries = [0] + self._bucket_boundaries.tolist() + + for bucket_indx in range(len(self._bucket_boundaries)): + try: + num_batches = stats_tracker[bucket_indx]["tot"] // ( + self._max_batch_length + ) + pad_factor = ( + stats_tracker[bucket_indx]["max"] + - stats_tracker[bucket_indx]["min"] + ) / ( + stats_tracker[bucket_indx]["tot"] + / stats_tracker[bucket_indx]["n_ex"] + ) + except ZeroDivisionError: + num_batches = 0 + pad_factor = 0 + + logger.info( + ( + "DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and " + + "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}." + ).format( + bucket_indx, + boundaries[bucket_indx], + boundaries[bucket_indx + 1], + self._bucket_lens[bucket_indx], + stats_tracker[bucket_indx]["n_ex"], + num_batches, + pad_factor * 100, + ) + ) + + if self.verbose: + batch_stats = { + "tot_frames": [], + "tot_pad_frames": [], + "pad_%": [], + } + for batch in self._batches: + tot_frames = sum( + [self._ex_lengths[str(idx)] for idx in batch] + ) + batch_stats["tot_frames"].append(tot_frames) + max_frames = max( + [self._ex_lengths[str(idx)] for idx in batch] + ) + tot_pad = sum( + [ + max_frames - self._ex_lengths[str(idx)] + for idx in batch + ] + ) + batch_stats["tot_pad_frames"].append(tot_pad) + batch_stats["pad_%"].append(tot_pad / tot_frames * 100) + + padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total." 
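+                # Hypothetical log line produced from this template (numbers
+                # for illustration only): "DynamicBatchSampler: Batch 0 with
+                # 1200.0 frames with 8 files - 96.0 padding, 8.00 (%) of total."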
+ padding_details = "DynamicBatchSampler: " + padding_details + for i in range(len(self._batches)): + logger.info( + padding_details.format( + i, + batch_stats["tot_frames"][i], + len(self._batches[i]), + batch_stats["tot_pad_frames"][i], + batch_stats["pad_%"][i], + ) + ) + + def __iter__(self): + for batch in self._batches: + yield batch + if self._shuffle_ex: # re-generate examples if ex_ordering == "random" + self._generate_batches() + if self._batch_ordering == "random": + # we randomly permute the batches only --> faster + self._permute_batches() + + def set_epoch(self, epoch): + """ + You can also just access self.epoch, but we maintain this interface + to mirror torch.utils.data.distributed.DistributedSampler + """ + self._epoch = epoch + self._generate_batches() + + def __len__(self): + return len(self._batches) + + +# Heavily inspired by Catalyst, which is under Apache 2.0 licence. +# https://github.com/catalyst-team/catalyst/blob/51428d7756e62b9b8ee5379f38e9fd576eeb36e5/catalyst/data/sampler.py#L522 +# class DistributedSamplerWrapper(DistributedSampler): +# """This wrapper allows using any sampler (for example batch) with Distributed Data Parallel (DDP) +# correctly. + +# Passing blindly the sampler to each DDP process will cause to have access +# within each process to all the data in the dataset instead of only a subset +# of it which is unique to each process. This wrapper prevents this and +# allows to use only a subset of the original data for each process. + +# NOTE +# ---- +# This is is automatically applied to any sampler in the Brain class when DDP +# training is used. +# """ + +# def __init__(self, sampler, *args, **kwargs): +# # DistributedSampler only calls len() on dataset +# # so a sampler is fine to pass there, as well. +# super().__init__(dataset=sampler, *args, **kwargs) +# self.sampler = sampler + +# def __iter__(self): +# # It is easiest to use a random access interface to the wrapped +# # sampler's indices, so we just fetch all indices from the wrapped +# # sampler +# sampler_indices = list(self.sampler.__iter__()) +# indices_of_indices = super().__iter__() +# # Itemgetter fetches the wrapped sampler indices from the positions +# # pointed to by DistributedSampler +# return iter(itemgetter(*indices_of_indices)(sampler_indices)) + +# def set_epoch(self, epoch): +# """Pass set_epoch() through to DistributedSampler and the wrapper one""" +# super().set_epoch(epoch) +# if hasattr(self.sampler, "set_epoch"): +# self.sampler.set_epoch(epoch) + + +class BalancingDataSampler(ReproducibleWeightedRandomSampler): + """A data sampler that takes a single key from the dataset and + ensures an approximately equal distribution by that key + + Arguments + --------- + dataset: DynamicItemDataset + the dataset form which samples will be drawn + key: str + the key from which samples will be taken + num_samples : int + Number of samples to draw + replacement : bool + To draw with replacement or not (within an epoch of num_samples). + seed : int + The base seed to use for the random number generator. It is recommended + to use a value which has a good mix of 0 and 1 bits. + epoch : int + The epoch to start at. + + Example + ------- + >>> from speechbrain.dataio.sampler import BalancingDataSampler + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> sample_data = { + ... 1: {"category": "A", + ... "text": "This is a test"}, + ... 2: {"category": "A", + ... "text": "This is a second test"}, + ... 3: {"category": "B", + ... "text": "This is a third test"} + ... 
} + >>> dataset = DynamicItemDataset(data=sample_data) + >>> sampler = BalancingDataSampler( + ... dataset=dataset, + ... key="category", + ... num_samples=10 + ... ) + >>> sampler.weights + tensor([0.5000, 0.5000, 1.0000], dtype=torch.float64) + >>> it = iter(sampler) + >>> [next(it) for _ in range(10)] + [2, 2, 1, 2, 2, 0, 1, 1, 1, 2] + """ + + def __init__( + self, + dataset, + key, + num_samples=None, + replacement=True, + seed=563375142, + epoch=0, + **kwargs, + ): + self.dataset = dataset + self.key = key + if not num_samples: + num_samples = len(dataset) + weights = self._compute_weights() + super().__init__( + weights, num_samples, replacement, seed, epoch, **kwargs + ) + + def _compute_weights(self): + with self.dataset.output_keys_as([self.key]): + class_ids = [item[self.key] for item in self.dataset] + class_counter = Counter(class_ids) + weights = 1 / paddle.to_tensor( + [class_counter[class_id] for class_id in class_ids] + ) + return weights diff --git a/paddlespeech/s2t/io/wav2vec2/sb_pipeline.py b/paddlespeech/s2t/io/wav2vec2/sb_pipeline.py new file mode 100755 index 000000000..56128fa0b --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/sb_pipeline.py @@ -0,0 +1,162 @@ +import transformers +from hyperpyyaml import load_hyperpyyaml +import dataset +import data_pipeline +from dataloader import make_dataloader +import dataio +import paddle +import tqdm +import numpy +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_data"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_data"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_data"], replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # Defining tokenizer and loading it + tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-chinese') + + # 2. Define audio pipeline: + @data_pipeline.takes("wav") + @data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = dataio.read_audio(wav) + return sig + + dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. 
Define text pipeline: + @data_pipeline.takes("transcript") + @data_pipeline.provides("wrd", "tokens_list", "tokens") + def text_pipeline(wrd): + wrd = "".join(wrd.split(" ")) + yield wrd + tokens_list = tokenizer(wrd)["input_ids"] + yield tokens_list + tokens = numpy.array(tokens_list, dtype="int64") + # tokens = paddle.to_tensor(tokens_list, dtype="int64") + yield tokens + + dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + dataset.set_output_keys( + datasets, ["id", "sig", "wrd", "tokens"], + ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. + train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from sampler import DynamicBatchSampler # noqa + + dynamic_hparams = hparams["dynamic_batch_sampler"] + num_buckets = dynamic_hparams["num_buckets"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + valid_batch_sampler = DynamicBatchSampler( + valid_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + return ( + train_data, + valid_data, + test_data, + tokenizer, + train_batch_sampler, + valid_batch_sampler, + ) + + + + +hparams_file = 'train_with_wav2vec.yaml' +with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, None) + +( + train_data, + valid_data, + test_data, + tokenizer, + train_bsampler, + valid_bsampler, +) = dataio_prepare(hparams) + +train_dataloader_opts = hparams["train_dataloader_opts"] +valid_dataloader_opts = hparams["valid_dataloader_opts"] + +if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + +if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + +train_set = make_dataloader( + train_data, stage='train', **train_dataloader_opts +) + +valid_set = make_dataloader( + valid_data, + stage='train', + **valid_dataloader_opts, +) + +# print(len(train_set)) + +for batch in valid_set: + print(batch) +print('done') # exit() \ No newline at end of file diff --git a/paddlespeech/s2t/io/wav2vec2/train_with_wav2vec.yaml b/paddlespeech/s2t/io/wav2vec2/train_with_wav2vec.yaml new file mode 100755 index 000000000..a1f1dcfbe --- /dev/null +++ b/paddlespeech/s2t/io/wav2vec2/train_with_wav2vec.yaml @@ -0,0 +1,86 @@ +# ############################################################################ +# Model: CTC-wav2vec2 +# Encoder: wav2vec2 +# Decoder: - +# Tokens: Char +# losses: CTC +# Training: AISHELL-1 +# Authors: Yingzhi WANG 2022 +# ############################################################################ + +seed: 10 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref /home/zhangtianhao/workspace/speechbrain/recipes/AISHELL-1/ASR/CTC/results/ctc_wav2vec/ +cer_file: !ref /cer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: /home/zhangtianhao/workspace/PaddleSpeech/dataset/aishell # e,g./path/to/aishell + +skip_prep: False +ckpt_interval_minutes: 15 # save checkpoint every N min +train_data: !ref /train.csv +valid_data: !ref /dev.csv +test_data: !ref /test.csv + +wav2vec2_hub: TencentGameMate/chinese-wav2vec2-large + +# Training parameters 
+number_of_epochs: 80 +lr: 1.0 +lr_wav2vec: 0.0001 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 8 per GPU to fit 32GB of VRAM +batch_size: 12 +test_batch_size: 8 + +dynamic_batching: False +dynamic_batch_sampler: + feats_hop_size: 0.01 + max_batch_len: 15 # in terms of "duration" in annotations by default, second here + left_bucket_len: 200 # old implementation attributs + multiplier: 1.1 # old implementation attributs + shuffle_ex: False # if true re-creates batches at each epoch shuffling examples. + num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1 + batch_ordering: ascending + +num_workers: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + num_workers: !ref +valid_dataloader_opts: + batch_size: !ref + num_workers: !ref +test_dataloader_opts: + batch_size: !ref + num_workers: !ref + +wav2vec_output_dim: 1024 +dnn_neurons: 1024 +freeze_wav2vec: False +dropout: 0.15 + +tokenizer: !apply:transformers.BertTokenizer.from_pretrained + pretrained_model_name_or_path: bert-base-chinese +# bert-base-chinese tokens length +output_neurons: 21128 + +# Decoding parameters +# Be sure that the bos and eos index match with the BPEs ones +blank_index: 0 + +# AISHELL-1 has spaces between words in the transcripts, +# which Chinese writing normally does not do. +# If remove_spaces, spaces are removed +# from the transcript before computing CER. +# (e.g., 祝 可爱 的 你 —> 祝可爱的你) +remove_spaces: True +split_tokens: !apply:operator.not_ [!ref ] diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py index eda188da5..ff8f279d8 100755 --- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py +++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py @@ -1,16 +1,3 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
from collections import defaultdict from typing import Dict from typing import List @@ -57,18 +44,27 @@ class Wav2vec2ASR(nn.Layer): def forward(self, wav, wavs_lens_rate, target, target_lens): if self.normalize_wav: wav = F.layer_norm(wav, wav.shape) + # Extract wav2vec output out = self.wav2vec2(wav)[0] # We normalize the output if required if self.output_norm: out = F.layer_norm(out, out.shape) - if self.train and hasattr(self.config, 'spec_augment'): + + if self.training and hasattr(self.config, 'spec_augment'): feats = self.spec_augment(out) else: feats = out x = self.enc(feats) + x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64) + + # sb target_lens = rate + # target_lens = (target_lens * + # target.shape[1]).round().astype(paddle.int64) + ctc_loss = self.ctc(x, x_lens, target, target_lens) + # print(target_lens_rate) return ctc_loss @paddle.no_grad() @@ -76,7 +72,8 @@ class Wav2vec2ASR(nn.Layer): feats: paddle.Tensor, text_feature: Dict[str, int], decoding_method: str, - beam_size: int): + beam_size: int, + sb_pipeline=False): batch_size = feats.shape[0] if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1: @@ -87,9 +84,32 @@ class Wav2vec2ASR(nn.Layer): sys.exit(1) if decoding_method == 'ctc_greedy_search': - hyps = self.ctc_greedy_search(feats) - res = [text_feature.defeaturize(hyp) for hyp in hyps] - res_tokenids = [hyp for hyp in hyps] + if not sb_pipeline: + hyps = self.ctc_greedy_search(feats) + res = [text_feature.defeaturize(hyp) for hyp in hyps] + res_tokenids = [hyp for hyp in hyps] + else: + hyps = self.ctc_greedy_search(feats.unsqueeze(-1)) + + predicted_words_list = [] + for sequence in hyps: + # Decode token terms to words + predicted_tokens = text_feature.convert_ids_to_tokens( + sequence + ) + + predicted_words = [] + for c in predicted_tokens: + if c == "[CLS]": + continue + elif c == "[SEP]" or c == "[PAD]": + break + else: + predicted_words.append(c) + print(predicted_words) + exit() + predicted_words_list.append(predicted_words) + # ctc_prefix_beam_search and attention_rescoring only return one # result in List[int], change it to List[List[int]] for compatible # with other batch decoding mode @@ -238,33 +258,3 @@ class Wav2vec2ASR(nn.Layer): """ hyps = self._ctc_prefix_beam_search(wav, beam_size) return hyps[0][0] - - -class Wav2vec2Base(nn.Layer): - """Wav2vec2 model""" - - def __init__(self, config: dict): - super().__init__() - wav2vec2_config = Wav2Vec2ConfigPure(config) - wav2vec2 = Wav2Vec2Model(wav2vec2_config) - self.wav2vec2 = wav2vec2 - - @classmethod - def from_config(cls, configs: dict): - """init model. - - Args: - configs (dict): config dict. - - Raises: - ValueError: raise when using not support encoder type. - - Returns: - nn.Layer: Wav2Vec2Base - """ - model = cls(configs) - return model - - def forward(self, wav): - out = self.wav2vec2(wav) - return out
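In the sb_pipeline branch of decode() above, ctc_greedy_search returns BERT token ids, so the special tokens have to be stripped before the transcript is scored. Below is a minimal sketch of that post-processing, assuming the bert-base-chinese tokenizer used elsewhere in this patch; the helper name ids_to_text and the example sentence are illustrative only, not part of the patch.

from transformers import BertTokenizer

def ids_to_text(token_ids, tokenizer):
    """Map BERT token ids to a transcript: skip [CLS], stop at [SEP]/[PAD]."""
    chars = []
    for tok in tokenizer.convert_ids_to_tokens(token_ids):
        if tok == "[CLS]":
            continue
        if tok in ("[SEP]", "[PAD]"):
            break
        chars.append(tok)
    # remove_spaces is True in the config, so CER is computed on the
    # space-free string; join the characters directly.
    return "".join(chars)

if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    # Hypothetical greedy-search hypothesis: the ids for "[CLS] 你 好 [SEP]".
    hyp = tokenizer("你好")["input_ids"]
    print(ids_to_text(hyp, tokenizer))  # expected output: 你好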