diff --git a/dataset/librispeech/librispeech.py b/dataset/librispeech/librispeech.py index 2d6f1763d..32b8746c2 100644 --- a/dataset/librispeech/librispeech.py +++ b/dataset/librispeech/librispeech.py @@ -133,7 +133,7 @@ def create_manifest(data_dir, manifest_path): def prepare_dataset(url, md5sum, target_dir, manifest_path): """Download, unpack and create summmary manifest file. """ - if not os.path.exists(os.path.join(target_dir, "LibriSpeech")): + if not os.path.exists(os.path.join(target_dir)): # download filepath = download(url, md5sum, target_dir) # unpack diff --git a/examples/librispeech/asr3/conf/hubertASR.yaml b/examples/librispeech/asr3/conf/hubertASR.yaml new file mode 100644 index 000000000..efb549591 --- /dev/null +++ b/examples/librispeech/asr3/conf/hubertASR.yaml @@ -0,0 +1,133 @@ +############################################ +# Network Architecture # +############################################ +freeze_hubert: True +normalize_wav: True +output_norm: True +init_type: kaiming_uniform # !Warning: need to convergence +enc: + input_shape: 1024 + dnn_blocks: 2 + dnn_neurons: 1024 + activation: True +ctc: + enc_n_units: 1024 + blank_id: 0 + dropout_rate: 0.0 +hubert_params_path: "exp/hubert/pd_hubert.pdparams" + + +task_cfg: + sample_rate: 16000 + +model_cfg: + dropout_input: 0.0 + final_dropout: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + activation_dropout: 0.1 + apply_mask: True + mask_length: 10 + mask_prob: 0.5 + mask_selection: static + mask_other: 0.0 + no_mask_overlap: False + mask_channel_length: 64 + mask_channel_prob: 0.25 + mask_channel_selection: static + mask_channel_other: 0.0 + no_mask_channel_overlap: False + freeze_finetune_updates: 10000 + feature_grad_mult: 0.0 + layerdrop: 0.1 + normalize: True + fp16: True + label_rate: 50 + extractor_mode: layer_norm + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + activation_fn: gelu + encoder_layerdrop: 0.1 + dropout_features: 0.0 + final_dim: 768 + untie_final_proj: True + layer_norm_first: True + conv_feature_layers: "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" + conv_bias: False + logit_temp: 0.1 + target_glu: False + mask_min_space: 1 + mask_channel_min_space: 1 + conv_pos: 128 + conv_pos_groups: 16 + latent_temp: [2.0, 0.5, 0.999995] + skip_masked: False + skip_nomask: True + +########################################### +# Data # +########################################### +train_manifest: data/manifest.train +dev_manifest: data/manifest.dev +test_manifest: data/manifest.test-clean + +########################################### +# Dataloader # +########################################### +vocab_filepath: data/lang_char/vocab.txt +unit_type: char +mean_std_filepath: "" +preprocess_config: conf/preprocess.yaml +sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for other epochs +batch_size: 8 # Different batch_size may cause large differences in results +maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced +maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 0 +subsampling_factor: 1 +num_encs: 1 +dist_sampler: True +shortest_first: True +return_lens_rate: True + +############################################ +# Data Augmentation # +############################################ 
+audio_augment: # for raw audio + sample_rate: 16000 + +########################################### +# Training # +########################################### +n_epoch: 1 +accum_grad: 1 +global_grad_clip: 5.0 +model_optim: adadelta +model_optim_conf: + lr: 1.0 + epsilon: 1.0e-6 + rho: 0.95 +model_scheduler: constantlr +model_scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 +hubert_optim: adadelta +hubert_optim_conf: + lr: 0.9 + epsilon: 1.0e-6 + rho: 0.95 +hubert_scheduler: constantlr +hubert_scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 +log_interval: 1 +checkpoint: + kbest_n: 50 + latest_n: 5 \ No newline at end of file diff --git a/examples/librispeech/asr3/local/data.sh b/examples/librispeech/asr3/local/data.sh old mode 100644 new mode 100755 index 8495a4ab6..edea3e19b --- a/examples/librispeech/asr3/local/data.sh +++ b/examples/librispeech/asr3/local/data.sh @@ -1,6 +1,6 @@ #!/bin/bash -stage=-1 +stage=0 stop_stage=100 unit_type=char diff --git a/examples/librispeech/asr3/local/test.sh b/examples/librispeech/asr3/local/test.sh old mode 100644 new mode 100755 diff --git a/examples/librispeech/asr3/local/test_wav.sh b/examples/librispeech/asr3/local/test_wav.sh old mode 100644 new mode 100755 diff --git a/examples/librispeech/asr3/local/train.sh b/examples/librispeech/asr3/local/train.sh old mode 100644 new mode 100755 index 24776fd17..10d254f0b --- a/examples/librispeech/asr3/local/train.sh +++ b/examples/librispeech/asr3/local/train.sh @@ -38,7 +38,7 @@ python3 -u ${BIN_DIR}/train.py \ --seed ${seed} \ --resume ${resume} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} --log_dir=exp/log/${ckpt_name} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ diff --git a/examples/librispeech/asr3/path.sh b/examples/librispeech/asr3/path.sh index f47178382..20df660da 100644 --- a/examples/librispeech/asr3/path.sh +++ b/examples/librispeech/asr3/path.sh @@ -10,6 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ - -MODEL=wav2vec2 +MODEL=$1 export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin diff --git a/examples/librispeech/asr3/run.sh b/examples/librispeech/asr3/run.sh old mode 100644 new mode 100755 index 05ad505c7..53b885e6d --- a/examples/librispeech/asr3/run.sh +++ b/examples/librispeech/asr3/run.sh @@ -1,13 +1,14 @@ #!/bin/bash set -e -. ./path.sh || exit 1; +MODEL=hubert +. ./path.sh ${MODEL} || exit 1; . ./cmd.sh || exit 1; -gpus=0 -stage=0 -stop_stage=0 -conf_path=conf/wav2vec2ASR.yaml +gpus=2 +stage=1 +stop_stage=1 +conf_path=conf/${MODEL}ASR.yaml ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=1 @@ -19,6 +20,7 @@ audio_file=data/demo_002_en.wav avg_ckpt=avg_${avg_num} ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +ckpt=test3 echo "checkpoint name ${ckpt}" if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py index 6c7e75c1f..124649987 100644 --- a/paddlespeech/__init__.py +++ b/paddlespeech/__init__.py @@ -13,3 +13,19 @@ # limitations under the License. 
 import _locale
 _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/paddlespeech/s2t/exps/hubert/__init__.py b/paddlespeech/s2t/exps/hubert/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/s2t/exps/hubert/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/s2t/exps/hubert/bin/__init__.py b/paddlespeech/s2t/exps/hubert/bin/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/s2t/exps/hubert/bin/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/s2t/exps/hubert/bin/test.py b/paddlespeech/s2t/exps/hubert/bin/test.py
new file mode 100644
index 000000000..7a9509afe
--- /dev/null
+++ b/paddlespeech/s2t/exps/hubert/bin/test.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
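The test entry point added below builds its runtime configuration by merging the main YAML with a separate decode config before freezing it. A minimal standalone sketch of that pattern (an editorial illustration, not part of the patch), using the config paths from this example:

from yacs.config import CfgNode

# Merge the model/data config with the decoding options, as bin/test.py does.
config = CfgNode(new_allowed=True)
config.merge_from_file("conf/hubertASR.yaml")             # main config (see above)
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file("conf/tuning/decode.yaml")   # decoding options
config.decode = decode_confs                              # attach as a sub-node
config.freeze()                                           # read-only from here on
print(config.decode)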
+"""Evaluation for hubert model.""" +import cProfile + +from yacs.config import CfgNode + +from paddlespeech.s2t.exps.hubert.model import HubertASRTester as Tester +from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + with exp.eval(): + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + # save asr result to + parser.add_argument( + '--dict-path', type=str, default=None, help='dict path.') + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + if args.decode_cfg: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_cfg) + config.decode = decode_confs + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + # Setting for profiling + pr = cProfile.Profile() + pr.runcall(main, config, args) + pr.dump_stats('test.profile') diff --git a/paddlespeech/s2t/exps/hubert/bin/test_wav.py b/paddlespeech/s2t/exps/hubert/bin/test_wav.py new file mode 100644 index 000000000..4a2b0158d --- /dev/null +++ b/paddlespeech/s2t/exps/hubert/bin/test_wav.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Evaluation for hubert model.""" +import os +import sys +from pathlib import Path + +import paddle +import soundfile +from yacs.config import CfgNode + +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR +from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import UpdateConfig +logger = Log(__name__).getlog() + + +class HubertInfer(): + def __init__(self, config, args): + self.args = args + self.config = config + self.audio_file = args.audio_file + + self.text_feature = TextFeaturizer( + unit_type=config.unit_type, vocab=config.vocab_filepath) + paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + + # model + model_conf = config + with UpdateConfig(model_conf): + model_conf.output_dim = self.text_feature.vocab_size + model = HubertASR.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + params_path = self.args.checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + def run(self): + check(args.audio_file) + + with paddle.no_grad(): + # read + audio, _ = soundfile.read( + self.audio_file, dtype="int16", always_2d=True) + logger.info(f"audio shape: {audio.shape}") + + xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + decode_config = self.config.decode + result_transcripts, result_tokenids = self.model.decode( + xs, + text_feature=self.text_feature, + decoding_method=decode_config.decoding_method, + beam_size=decode_config.beam_size) + rsl = result_transcripts[0] + utt = Path(self.audio_file).name + logger.info(f"hyp: {utt} {rsl}") + return rsl + + +def check(audio_file): + if not os.path.isfile(audio_file): + print("Please input the right audio file path") + sys.exit(-1) + + logger.info("checking the audio file format......") + try: + sig, sample_rate = soundfile.read(audio_file) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the wav file, please check the audio file format") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + assert (sample_rate == 16000) + logger.info("The audio file format is right") + + +def main(config, args): + HubertInfer(config, args).run() + + +if __name__ == "__main__": + parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + parser.add_argument( + "--audio_file", type=str, help="path of the input audio file") + args = parser.parse_args() + + config = CfgNode(new_allowed=True) + + if args.config: + config.merge_from_file(args.config) + if args.decode_cfg: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_cfg) + config.decode = decode_confs + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + main(config, args) diff --git a/paddlespeech/s2t/exps/hubert/bin/train.py b/paddlespeech/s2t/exps/hubert/bin/train.py new file mode 100644 index 000000000..fd66992f0 --- /dev/null +++ b/paddlespeech/s2t/exps/hubert/bin/train.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trainer for hubert model.""" +import cProfile +import os + +from yacs.config import CfgNode + +from paddlespeech.s2t.exps.hubert.model import HubertASRTrainer as Trainer +from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Trainer(config, args) + exp.setup() + exp.run() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument( + '--resume', type=str, default="", nargs="?", help='resume ckpt path.') + args = parser.parse_args() + print_arguments(args, globals()) + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + # Setting for profiling + pr = cProfile.Profile() + pr.runcall(main, config, args) + pr.dump_stats(os.path.join(args.output, 'train.profile')) diff --git a/paddlespeech/s2t/exps/hubert/model.py b/paddlespeech/s2t/exps/hubert/model.py new file mode 100644 index 000000000..1e5496f41 --- /dev/null +++ b/paddlespeech/s2t/exps/hubert/model.py @@ -0,0 +1,918 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
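hubertASR.yaml above defines two separate adadelta optimizer/scheduler pairs, one for the downstream head (enc + ctc) and one for the hubert encoder, and setup_model() further below builds them over the matching parameter groups. A minimal sketch of that split (hypothetical Linear layers stand in for the real modules):

import paddle

# Stand-ins for the two parameter groups the trainer optimizes separately.
head = paddle.nn.Linear(1024, 32)       # downstream enc + ctc head
encoder = paddle.nn.Linear(1024, 1024)  # hubert encoder, frozen when freeze_hubert is True

model_optimizer = paddle.optimizer.Adadelta(
    learning_rate=1.0, epsilon=1.0e-6, rho=0.95, parameters=head.parameters())
hubert_optimizer = paddle.optimizer.Adadelta(
    learning_rate=0.9, epsilon=1.0e-6, rho=0.95, parameters=encoder.parameters())

# train_batch() steps model_optimizer every accum_grad batches and steps
# hubert_optimizer only when freeze_hubert is False.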
+"""Contains hubert model.""" +import json +import math +import os +import re +import time +from collections import OrderedDict +from contextlib import nullcontext + +import jsonlines +import numpy as np +import paddle +from hyperpyyaml import load_hyperpyyaml +from paddle import distributed as dist +from paddlenlp.transformers import AutoTokenizer + +from paddlespeech.s2t.frontend.featurizer import TextFeaturizer +from paddlespeech.s2t.io.dataloader import DataLoaderFactory +from paddlespeech.s2t.io.speechbrain import data_pipeline +from paddlespeech.s2t.io.speechbrain import dataio +from paddlespeech.s2t.io.speechbrain import dataset +from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader +from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment +from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR +from paddlespeech.s2t.training.optimizer import OptimizerFactory +from paddlespeech.s2t.training.reporter import ObsScope +from paddlespeech.s2t.training.reporter import report +from paddlespeech.s2t.training.scheduler import LRSchedulerFactory +from paddlespeech.s2t.training.timer import Timer +from paddlespeech.s2t.training.trainer import Trainer +from paddlespeech.s2t.utils import error_rate +from paddlespeech.s2t.utils import layer_tools +from paddlespeech.s2t.utils import mp_tools +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import UpdateConfig + +logger = Log(__name__).getlog() + + +def clip_grad_norm_( + parameters, + max_norm, + norm_type=2.0, + error_if_nonfinite=False, ): + r"""Clips gradient norm of the iteratable parameters. + + Norms are calculated together on all gradients, just as they are + connected into one vector. The gradient will be modified in place. + + This API can only run in dynamic graph mode, not static graph mode. + + Args: + parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor + that will be normalized gradients + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be `inf` for + infinity norm. + error_if_nonfinite (bool): if True, throw an error if the total + norm of the gradients from :attr:`parameters` is `nan`, + `inf`, or `-inf`. + + Returns: + Total norm of the parameter gradients (treated as a single vector). + Example: + .. 
code-block:: python + import paddle + + x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32') + max_norm = float(5.0) + linear = paddle.nn.Linear(in_features=10, out_features=10) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + + paddle.nn.utils.clip_grad_norm_(linear.parameters(), max_norm) + + sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters()) + sdg.step() + """ + if not paddle.in_dynamic_mode(): + raise RuntimeError('this API can only run in dynamic mode.') + + if isinstance(parameters, paddle.Tensor): + parameters = [parameters] + + support_norm_type = [float("inf"), 0, 1, 2] + if norm_type not in support_norm_type: + raise ValueError(f'norm_type only support {support_norm_type}') + + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return paddle.to_tensor(0.0) + if norm_type == float("inf"): + norms = [g.detach().abs().max() for g in grads] + total_norm = (norms[0] + if len(norms) == 1 else paddle.max(paddle.stack(norms))) + else: + total_norm = paddle.linalg.norm( + paddle.stack( + [paddle.linalg.norm(g.detach(), norm_type) for g in grads]), + norm_type, ) + + if error_if_nonfinite and paddle.logical_or(total_norm.isnan(), + total_norm.isinf()): + raise RuntimeError( + f'The total norm of {norm_type} order of the gradients from ' + '`parameters` is non-finite, so it cannot be clipped. In any case, ' + 'disable this error and scale the gradient by non-finite norm, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this + # avoids the `if clip_coef < 1:` condition. + clip_coef_clamped = paddle.clip(clip_coef, max=1.0) + with paddle.no_grad(): + for _, p in enumerate(parameters): + g = p.grad + if g is not None: + p.grad = paddle.multiply(x=g, y=clip_coef_clamped) + return total_norm + + +class HubertASRTrainer(Trainer): + def __init__(self, config, args): + super().__init__(config, args) + self.avg_train_loss = 0.0 + self.loss_isfinite = True # while flag is 'False', loss in Nan or inf, and can not be avg + self.use_sb = True # whether use speech brain dataloader + + def update_average(self, batch_index, loss): + """Update running average of the loss. + Arguments + --------- + batch_index : int + current batch index + loss : paddle.tensor + detached loss, a single float value. + """ + if math.isfinite(loss): + self.avg_train_loss -= self.avg_train_loss / (batch_index + 1) + self.avg_train_loss += loss / (batch_index + 1) + else: + self.loss_isfinite = False + logger.info('loss:{} in Nan or inf, error'.format(loss)) + + def before_train(self): + from_scratch = self.resume_or_scratch() + if from_scratch: + # scratch: save init model, i.e. 
0 epoch + self.save(tag='init', infos=None) + else: + # resume: train next_epoch and next_iteration + self.epoch += 1 + logger.info( + f"Resume train: epoch {self.epoch }, step {self.iteration}!") + + self.maybe_batch_sampler_step() + + def train_batch(self, batch_index, batch, msg): + train_conf = self.config + start = time.time() + + # forward + ## sb data pipeline + if self.use_sb: + wav, wavs_lens_rate = batch['sig'] + target, target_lens_rate = batch['tokens'] + target_lens = (target_lens_rate * + target.shape[1]).round().astype(paddle.int64) + else: + utt, wav, wavs_lens, target, target_lens = batch + wavs_lens_rate = wavs_lens / wav.shape[1] + wav = wav[:, :, 0] + + # if hasattr(train_conf, 'audio_augment'): + # wav = self.speech_augmentation(wav, wavs_lens_rate) + + loss = self.model(wav, wavs_lens_rate, target, target_lens) + + # loss div by `batch_size * accum_grad` + loss /= train_conf.accum_grad + # update self.avg_train_loss + self.update_average(batch_index, float(loss)) + + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + # When using cpu w/o DDP, model does not have `no_sync` + context = self.model.no_sync if (hasattr(self.model, "no_sync") and + self.parallel) else nullcontext + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + with context(): + loss.backward() + + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step old + if (batch_index + 1) % train_conf.accum_grad == 0: + #do global grad clip + if train_conf.global_grad_clip != 0: + clip_grad_norm_(self.model.parameters(), + train_conf.global_grad_clip) + self.model_optimizer.step() + self.model_optimizer.clear_grad() + if not train_conf.freeze_hubert: + self.hubert_optimizer.step() + self.hubert_optimizer.clear_grad() + if self.config.model_scheduler != 'newbobscheduler': + self.model_lr_scheduler.step() + if self.config.hubert_scheduler != 'newbobscheduler': + if not train_conf.freeze_hubert: + self.hubert_lr_scheduler.step() + self.iteration += 1 + + losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad} + iteration_time = time.time() - start + for k, v in losses_np.items(): + report(k, v) + report("loss_whitoutavg", float(loss)) + report("batch_size", self.config.batch_size) + report("accum", train_conf.accum_grad) + report("step_cost", iteration_time) + + if (batch_index + 1) % train_conf.accum_grad == 0: + if dist.get_rank() == 0 and self.visualizer: + losses_np_v = losses_np.copy() + losses_np_v.update({ + "model_lr": self.model_lr_scheduler(), + "hubert_lr": self.hubert_lr_scheduler() + }) + for key, val in losses_np_v.items(): + self.visualizer.add_scalar( + tag='train/' + key, value=val, step=self.iteration - 1) + + @paddle.no_grad() + def valid(self): + self.model.eval() + if not self.use_streamdata: + logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") + valid_losses = {} + step = 0 + total_loss = 0.0 + num_seen_utts = 1 # use update_average and no need for num_seen_utts here + for i, batch in enumerate(self.valid_loader): + if self.use_sb: + wav, wavs_lens_rate = batch['sig'] + target, target_lens_rate = batch['tokens'] + target_lens = (target_lens_rate * + target.shape[1]).round().astype(paddle.int64) + else: + utt, wav, wavs_lens, target, target_lens = batch + wavs_lens_rate = wavs_lens / 
wav.shape[1] + wav = wav[:, :, 0] + + loss = self.model(wav, wavs_lens_rate, target, target_lens) + # use update_average + total_loss -= total_loss / (step + 1) + total_loss += loss / (step + 1) + + if math.isfinite(float(loss)): + step += 1 + valid_losses['val_loss'] = float(loss) + else: + logger.info('loss:{} in Nan or inf, error'.format(float(loss))) + + if (i + 1) % self.config.log_interval == 0: + valid_losses['val_history_loss'] = float(total_loss) + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, + len(self.valid_loader)) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in valid_losses.items()) + logger.info(msg) + + logger.info( + 'Rank {} Val info val_loss {}'.format(dist.get_rank(), total_loss)) + return total_loss, num_seen_utts + + @mp_tools.rank_zero_only + def save(self, tag=None, infos: dict=None): + """Save checkpoint (model parameters and optimizer states). + + Args: + tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None. + infos (dict, optional): meta data to save. Defaults to None. + """ + + infos = infos if infos else dict() + infos.update({ + "epoch": self.epoch, + "model_lr": self.model_optimizer.get_lr(), + "hubert_lr": self.hubert_optimizer.get_lr() + }) + + checkpoint_path = os.path.join( + self.checkpoint_dir, + "{}".format(self.iteration if tag is None else tag)) + + model_dict = self.model.state_dict() + params_path = checkpoint_path + ".pdparams" + paddle.save(model_dict, params_path) + logger.info("Saved model to {}".format(params_path)) + + model_opt_dict = self.model_optimizer.state_dict() + hubert_opt_dict = self.hubert_optimizer.state_dict() + + opt_dict = {'model': model_opt_dict, 'hubert': hubert_opt_dict} + + optimizer_path = checkpoint_path + ".pdopt" + paddle.save(opt_dict, optimizer_path) + logger.info("Saved optimzier state to {}".format(optimizer_path)) + + scheduler_dict = {} + + if self.config.model_scheduler == 'newbobscheduler': + scheduler_dict['model'] = self.model_lr_scheduler.save() + if self.config.hubert_scheduler == 'newbobscheduler': + scheduler_dict['hubert'] = self.hubert_lr_scheduler.save() + if scheduler_dict: + scheduler_path = checkpoint_path + ".pdlrs" + paddle.save(scheduler_dict, scheduler_path) + logger.info("Saved scheduler state to {}".format(scheduler_path)) + info_path = re.sub('.pdparams$', '.json', params_path) + infos = {} if infos is None else infos + with open(info_path, 'w', encoding='utf8') as fout: + data = json.dumps(infos) + fout.write(data) + + def resume_or_scratch(self): + """Resume from latest checkpoint at checkpoints in the output + directory or load a specified checkpoint. + + If ``args.checkpoint_path`` is not None, load the checkpoint, else + resume training. 
+ """ + scratch = None + if self.args.resume: + # just restore ckpt + # lr will resotre from optimizer ckpt + resume_json_path = os.path.join(self.checkpoint_dir, + self.args.resume + '.json') + with open(resume_json_path, 'r', encoding='utf8') as f: + resume_json = json.load(f) + self.iteration = 0 + self.epoch = resume_json["epoch"] + + # resotre model from *.pdparams + params_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdparams' + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + # resotre optimizer from *.pdopt + optimizer_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdopt' + optimizer_dict = paddle.load(optimizer_path) + self.model_optimizer.set_state_dict(optimizer_dict['model']) + self.hubert_optimizer.set_state_dict(optimizer_dict['hubert']) + + # resotre lr_scheduler from *.pdlrs + scheduler_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdlrs' + if os.path.isfile(os.path.join(scheduler_path)): + scheduler_dict = paddle.load(scheduler_path) + if self.config.model_scheduler == 'newbobscheduler': + self.model_lr_scheduler.load(scheduler_dict['model']) + if self.config.hubert_scheduler == 'newbobscheduler': + self.hubert_lr_scheduler.load(scheduler_dict['hubert']) + logger.info( + f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") + scratch = False + else: + self.iteration = 0 + self.epoch = 0 + scratch = True + logger.info("Init from scratch!") + return scratch + + def do_train(self): + """The training process control by step.""" + # !!!IMPORTANT!!! + # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + # script_model = paddle.jit.to_static(self.model) + # script_model_path = str(self.checkpoint_dir / 'init') + # paddle.jit.save(script_model, script_model_path) + + self.before_train() + if not self.use_streamdata: + logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") + while self.epoch < self.config.n_epoch: + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: + data_start_time = time.time() + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train:" + observation = OrderedDict() + with ObsScope(observation): + report("Rank", dist.get_rank()) + report("epoch", self.epoch) + report('step', self.iteration) + report("model_lr", self.model_optimizer.get_lr()) + report("hubert_lr", + self.hubert_optimizer.get_lr()) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + report('iter', batch_index + 1) + if not self.use_streamdata: + report('total', len(self.train_loader)) + report('reader_cost', dataload_time) + observation['batch_cost'] = observation[ + 'reader_cost'] + observation['step_cost'] + observation['samples'] = observation['batch_size'] + observation['ips,samples/s'] = observation[ + 'batch_size'] / observation['batch_cost'] + for k, v in observation.items(): + msg += f" {k.split(',')[0]}: " + msg += f"{v:>.8f}" if isinstance(v, + float) else f"{v}" + msg += f" {k.split(',')[1]}" if len( + k.split(',')) == 2 else "" + msg += "," + msg = msg[:-1] # remove the last "," + if (batch_index + 1) % self.config.log_interval == 0: + logger.info(msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = 
paddle.to_tensor(num_seen_utts) + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = float(total_loss) + logger.info( + 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) + if self.visualizer: + self.visualizer.add_scalar( + tag='eval/cv_loss', value=cv_loss, step=self.epoch) + self.visualizer.add_scalar( + tag='eval/model_lr', + value=self.model_lr_scheduler(), + step=self.epoch) + self.visualizer.add_scalar( + tag='eval/hubert_lr', + value=self.hubert_lr_scheduler(), + step=self.epoch) + + if self.config.model_scheduler == 'newbobscheduler': + self.model_lr_scheduler.step(cv_loss) + if self.config.hubert_scheduler == 'newbobscheduler': + if not self.config.freeze_hubert: + self.hubert_lr_scheduler.step(cv_loss) + self.save(tag=self.epoch, infos={'val_loss': cv_loss}) + self.avg_train_loss = 0.0 + self.new_epoch() + + def dataio_prepare(self, hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_data"], + replacements={"data_root": data_folder}, ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending") + + valid_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_data"], + replacements={"data_root": data_folder}, ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_data"], + replacements={"data_root": data_folder}, ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # Defining tokenizer and loading it + tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese') + self.tokenizer = tokenizer + # 2. Define audio pipeline: + @data_pipeline.takes("wav") + @data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = dataio.read_audio(wav) + return sig + + dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @data_pipeline.takes("transcript") + @data_pipeline.provides("wrd", "tokens_list", "tokens") + def text_pipeline(wrd): + wrd = "".join(wrd.split(" ")) + yield wrd + tokens_list = tokenizer(wrd)["input_ids"] + yield tokens_list + tokens = np.array(tokens_list, dtype="int64") + # tokens = paddle.to_tensor(tokens_list, dtype="int64") + yield tokens + + dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + dataset.set_output_keys( + datasets, + ["id", "sig", "wrd", "tokens"], ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. 
+ train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from sampler import DynamicBatchSampler # noqa + + dynamic_hparams = hparams["dynamic_batch_sampler"] + num_buckets = dynamic_hparams["num_buckets"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], ) + + valid_batch_sampler = DynamicBatchSampler( + valid_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], ) + + return (train_data, valid_data, test_data, tokenizer, + train_batch_sampler, valid_batch_sampler, ) + + def setup_dataloader(self): + config = self.config.clone() + self.use_streamdata = config.get("use_stream_data", False) + self.use_sb = config.get("use_sb_pipeline", False) + if self.use_sb: + hparams_file = config.sb_pipeline_conf + with open(hparams_file, 'r', encoding='utf8') as fin: + hparams = load_hyperpyyaml(fin, None) + + (train_data, valid_data, test_data, tokenizer, train_bsampler, + valid_bsampler, ) = self.dataio_prepare(hparams) + + train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + + if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + + if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + if self.train: + self.train_loader = make_dataloader( + train_data, stage='train', **train_dataloader_opts) + self.valid_loader = make_dataloader( + valid_data, + stage='val', + **valid_dataloader_opts, ) + logger.info("Setup train/valid Dataloader!") + else: + self.test_loader = make_dataloader( + test_data, stage='test', **hparams["test_dataloader_opts"]) + else: + if self.train: + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) + logger.info("Setup train/valid Dataloader!") + else: + decode_batch_size = config.get('decode', dict()).get( + 'decode_batch_size', 1) + self.test_loader = DataLoaderFactory.get_dataloader( + 'test', config, self.args) + self.align_loader = DataLoaderFactory.get_dataloader( + 'align', config, self.args) + logger.info("Setup test/align Dataloader!") + + def setup_model(self): + config = self.config + model_conf = config + + with UpdateConfig(model_conf): + if self.use_sb: + model_conf.output_dim = self.tokenizer.vocab_size + else: + if self.train: + model_conf.input_dim = self.train_loader.feat_dim + model_conf.output_dim = self.train_loader.vocab_size + else: + model_conf.input_dim = self.test_loader.feat_dim + model_conf.output_dim = self.test_loader.vocab_size + + model = HubertASR.from_config(model_conf) + + model_dict = paddle.load(config.hubert_params_path) + model.set_state_dict(model_dict) + + if self.parallel: + model = paddle.DataParallel(model, find_unused_parameters=True) + + layer_tools.print_params(model, logger.info) + self.model = model + logger.info("Setup model!") + + # setup speech augmentation for hubert + if hasattr(config, 'audio_augment') and self.train: + self.speech_augmentation = TimeDomainSpecAugment( + **config.audio_augment) + + if not self.train: + return + + 
train_config = config + model_optim_type = train_config.model_optim + model_optim_conf = train_config.model_optim_conf + logger.info("optim_model:{},{}", model_optim_type, model_optim_conf) + hubert_optim_type = train_config.hubert_optim + hubert_optim_conf = train_config.hubert_optim_conf + logger.info("optim_model:{},{}", hubert_optim_type, + hubert_optim_conf) + + model_scheduler_type = train_config.model_scheduler + model_scheduler_conf = train_config.model_scheduler_conf + hubert_scheduler_type = train_config.hubert_scheduler + hubert_scheduler_conf = train_config.hubert_scheduler_conf + + model_scheduler_args = dict( + **{"learning_rate": model_optim_conf.lr, + "verbose": False}, **(dict(model_scheduler_conf))) + + hubert_scheduler_args = dict( + **{"learning_rate": hubert_optim_conf.lr, + "verbose": False}, **(dict(hubert_scheduler_conf))) + + model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type, + model_scheduler_args) + hubert_lr_scheduler = LRSchedulerFactory.from_args( + hubert_scheduler_type, hubert_scheduler_args) + + def optimizer_args( + config, + optim_type, + optim_conf, + parameters, + lr_scheduler=None, ): + optim_arg = dict(optim_conf) + optim_arg.update({ + "learning_rate": + lr_scheduler if lr_scheduler else optim_conf.lr, + "parameters": + parameters + }) + return optim_arg + + model_optimizer_args = optimizer_args(config, model_optim_type, + model_optim_conf, [{ + 'params': + model._layers.enc.parameters() + }, { + 'params': + model._layers.ctc.parameters() + }] if self.parallel else [{ + 'params': + model.enc.parameters() + }, { + 'params': + model.ctc.parameters() + }], model_lr_scheduler) + + hubert_optimizer_args = optimizer_args( + config, hubert_optim_type, hubert_optim_conf, + model._layers.hubert.parameters() if self.parallel else + model.hubert.parameters(), hubert_lr_scheduler) + + model_optimizer = OptimizerFactory.from_args(model_optim_type, + model_optimizer_args) + hubert_optimizer = OptimizerFactory.from_args(hubert_optim_type, + hubert_optimizer_args) + + self.model_optimizer = model_optimizer + self.hubert_optimizer = hubert_optimizer + self.model_lr_scheduler = model_lr_scheduler + self.hubert_lr_scheduler = hubert_lr_scheduler + logger.info("Setup optimizer/lr_scheduler!") + + +class HubertASRTester(HubertASRTrainer): + def __init__(self, config, args): + super().__init__(config, args) + self.text_featurizer = TextFeaturizer( + unit_type=config.unit_type, vocab=config.vocab_filepath) + self.vocab_list = self.text_featurizer.vocab_list + + def id2token(self, texts, texts_len): + """ ord() id to chr() chr """ + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist())) + return trans + + def compute_metrics(self, id, audio, audio_len, texts, texts_len, + fout=None): + decode_cfg = self.config.decode + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer + + start_time = time.time() + target_transcripts = self.id2token(texts, texts_len) + result_transcripts, result_tokenids = self.model.decode( + audio, + text_feature=self.text_featurizer, + decoding_method=decode_cfg.decoding_method, + beam_size=decode_cfg.beam_size) + decode_time = time.time() - start_time + + for utt, target, result, rec_tids in zip( + id, target_transcripts, 
result_transcripts, result_tokenids): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write({ + "utt": utt, + "refs": [target], + "hyps": [result], + "hyps_tokenid": [rec_tids], + }) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") + logger.info("One example error rate [%s] = %f" % ( + decode_cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, # num examples + error_rate=errors_sum / len_refs, + error_rate_type=decode_cfg.error_rate_type, + num_frames=audio_len.sum().numpy().item(), + decode_time=decode_time) + + def sb_compute_metrics(self, id, sig, wrd, tokens, fout=None): + decode_cfg = self.config.decode + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer + start_time = time.time() + target_transcripts = wrd + result_transcripts, result_tokenids = self.model.decode( + sig[0], + text_feature=self.tokenizer, + decoding_method=decode_cfg.decoding_method, + beam_size=decode_cfg.beam_size, + sb_pipeline=True) + decode_time = time.time() - start_time + + for utt, target, result, rec_tids in zip( + id, target_transcripts, result_transcripts, result_tokenids): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write({ + "utt": utt, + "refs": [target], + "hyps": [result], + "hyps_tokenid": [rec_tids], + }) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") + logger.info("One example error rate [%s] = %f" % ( + decode_cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, # num examples + error_rate=errors_sum / len_refs, + error_rate_type=decode_cfg.error_rate_type, + num_frames=sig[1].sum().numpy().item(), + decode_time=decode_time) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + self.model.eval() + + error_rate_type = None + errors_sum, len_refs, num_ins = 0.0, 0, 0 + num_frames = 0.0 + num_time = 0.0 + # Initialized the decoder in model + decode_cfg = self.config.decode + vocab_list = self.vocab_list + decode_batch_size = decode_cfg.decode_batch_size + + with jsonlines.open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + if self.use_sb: + metrics = self.sb_compute_metrics(**batch, fout=fout) + else: + metrics = self.compute_metrics(*batch, fout=fout) + num_frames += metrics['num_frames'] + num_time += metrics["decode_time"] + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + rtf = num_time / (num_frames) + logger.info( + "RTF: %f, Error rate [%s] (%d/?) 
= %f" % + (rtf, error_rate_type, num_ins, errors_sum / len_refs)) + + # logging + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "Final error rate [%s] (%d/%d) = %f" % ( + error_rate_type, num_ins, num_ins, errors_sum / len_refs) + logger.info(msg) + + err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err' + err_type_str = "{}".format(error_rate_type) + with open(err_meta_path, 'w', encoding='utf8') as f: + data = json.dumps({ + "epoch": + self.epoch, + "step": + self.iteration, + "rtf": + rtf, + error_rate_type: + errors_sum / len_refs, + "dataset_hour": (num_frames) / 1000.0 / 3600.0, + "process_hour": + num_time / 1000.0 / 3600.0, + "num_examples": + num_ins, + "err_sum": + errors_sum, + "ref_len": + len_refs, + "decode_method": + self.config.decode.decoding_method, + }) + f.write(data + '\n') diff --git a/paddlespeech/s2t/models/hubert/__init__.py b/paddlespeech/s2t/models/hubert/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlespeech/s2t/models/hubert/hubert_ASR.py b/paddlespeech/s2t/models/hubert/hubert_ASR.py new file mode 100644 index 000000000..b31d10c81 --- /dev/null +++ b/paddlespeech/s2t/models/hubert/hubert_ASR.py @@ -0,0 +1,350 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
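HubertASR.ctc_greedy_search() in the model file below takes the per-frame argmax of the CTC posteriors and relies on remove_duplicates_and_blank to collapse that frame path into a token sequence. A standalone sketch of the collapse step (the real helper lives in paddlespeech.s2t.utils.ctc_utils; this only illustrates the behavior assumed here, with blank id 0 as in the config above):

def remove_duplicates_and_blank(path, blank_id=0):
    # Collapse consecutive repeats, then drop CTC blanks.
    out, prev = [], None
    for tok in path:
        if tok != prev and tok != blank_id:
            out.append(tok)
        prev = tok
    return out

print(remove_duplicates_and_blank([0, 3, 3, 0, 5, 5, 5, 0, 3]))  # [3, 5, 3]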
+from collections import defaultdict +from typing import Dict +from typing import List +from typing import Tuple +from typing import Any, Optional +from dataclasses import dataclass, field, is_dataclass +from copy import deepcopy + +from omegaconf import II, MISSING, open_dict + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertConfig, HubertModel, HubertPretrainingConfig + +from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure +from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model +from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN +from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment +from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC +from paddlespeech.s2t.modules.initializer import DefaultInitializerContext +from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank +from paddlespeech.s2t.utils.utility import log_add +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + +class HubertASR(nn.Layer): + def __init__(self, config: dict): + super().__init__() + init_type = config.get("init_type", None) + with DefaultInitializerContext(init_type): + self.config = config + with open(config.vocab_filepath) as f: + dicts = [symbol.strip() for symbol in f.readlines()] + task_cfg = self.merge_with_parent(HubertPretrainingConfig, dict(self.config.task_cfg)) + model_cfg = self.merge_with_parent(HubertConfig, dict(self.config.model_cfg)) + hubert = HubertModel(model_cfg, task_cfg, dicts) + + self.normalize_wav = config.normalize_wav + self.output_norm = config.output_norm + if hasattr(config, 'spec_augment'): + self.spec_augment = SpecAugment(**config.spec_augment) + + if config.freeze_hubert: + hubert.eval() + for parm in hubert.parameters(): + parm.trainable = False + self.hubert = hubert + self.enc = VanillaNN(**config.enc) + self.ctc = CTC(**config.ctc, + odim=config.output_dim, + batch_average=False, + reduction='mean') + + def merge_with_parent(self, dc: dataclass, cfg: dict): + assert is_dataclass(dc) + assert type(cfg) == dict + cfg = deepcopy(cfg) + + def fix_cfg(cfg): + target_keys = set(dc.__dataclass_fields__.keys()) + for k in list(cfg.keys()): + if k not in target_keys: + del cfg[k] + + fix_cfg(cfg) + assert len(cfg) > 0 + return dc(**cfg) + + def forward(self, wav, wavs_lens_rate, target, target_lens): + + if self.normalize_wav: + wav = F.layer_norm(wav, wav.shape) + + self.hubert.eval() + # Extract wav2vec output + out = self.hubert.extract_features(wav)[0] + # We normalize the output if required + if self.output_norm: + out = F.layer_norm(out, out.shape) + + if self.training and hasattr(self.config, 'spec_augment'): + feats = self.spec_augment(out) + else: + feats = out + + x = self.enc(feats) + + x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64) + + ctc_loss = self.ctc(x, x_lens, target, target_lens) + + return ctc_loss + + @paddle.no_grad() + def decode(self, + feats: paddle.Tensor, + text_feature: Dict[str, int], + decoding_method: str, + beam_size: int, + tokenizer: str=None, + sb_pipeline=False): + batch_size = feats.shape[0] + + if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1: + logger.error( + f"decoding mode {decoding_method} must be running with batch_size == 1" + ) + logger.error(f"current batch_size is {batch_size}") + + if decoding_method == 'ctc_greedy_search': + if 
tokenizer is None and sb_pipeline is False: + hyps = self.ctc_greedy_search(feats) + res = [text_feature.defeaturize(hyp) for hyp in hyps] + res_tokenids = [hyp for hyp in hyps] + else: + if sb_pipeline is True: + hyps = self.ctc_greedy_search(feats.unsqueeze(-1)) + else: + hyps = self.ctc_greedy_search(feats) + res = [] + res_tokenids = [] + for sequence in hyps: + # Decode token terms to words + predicted_tokens = text_feature.convert_ids_to_tokens( + sequence) + tmp_res = [] + tmp_res_tokenids = [] + for c in predicted_tokens: + if c == "[CLS]": + continue + elif c == "[SEP]" or c == "[PAD]": + break + else: + tmp_res.append(c) + tmp_res_tokenids.append(text_feature.vocab[c]) + res.append(''.join(tmp_res)) + res_tokenids.append(tmp_res_tokenids) + + # ctc_prefix_beam_search and attention_rescoring only return one + # result in List[int], change it to List[List[int]] for compatible + # with other batch decoding mode + elif decoding_method == 'ctc_prefix_beam_search': + assert feats.shape[0] == 1 + if tokenizer is None and sb_pipeline is False: + hyp = self.ctc_prefix_beam_search(feats, beam_size) + res = [text_feature.defeaturize(hyp)] + res_tokenids = [hyp] + else: + if sb_pipeline is True: + hyp = self.ctc_prefix_beam_search( + feats.unsqueeze(-1), beam_size) + else: + hyp = self.ctc_prefix_beam_search(feats, beam_size) + res = [] + res_tokenids = [] + predicted_tokens = text_feature.convert_ids_to_tokens(hyp) + tmp_res = [] + tmp_res_tokenids = [] + for c in predicted_tokens: + if c == "[CLS]": + continue + elif c == "[SEP]" or c == "[PAD]": + break + else: + tmp_res.append(c) + tmp_res_tokenids.append(text_feature.vocab[c]) + res.append(''.join(tmp_res)) + res_tokenids.append(tmp_res_tokenids) + else: + raise ValueError( + f"wav2vec2 not support decoding method: {decoding_method}") + + return res, res_tokenids + + @classmethod + def from_config(cls, config): + model = cls(config) + return model + + def ctc_greedy_search(self, wav) -> List[List[int]]: + """ Apply CTC greedy search + Args: + speech (paddle.Tensor): (batch, max_len) + speech_length (paddle.Tensor): (batch, ) + Returns: + List[List[int]]: best path result + """ + batch_size = wav.shape[0] + wav = wav[:, :, 0] + if self.normalize_wav: + wav = F.layer_norm(wav, wav.shape[1:]) + # Extract wav2vec output + out = self.wav2vec2(wav)[0] + # We normalize the output if required + if self.output_norm: + out = F.layer_norm(out, out.shape[1:]) + feats = out + x = self.enc(feats) + x_lens = x.shape[1] + ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen) + + hyps = [hyp.tolist() for hyp in topk_index] + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps + + def _ctc_prefix_beam_search( + self, + wav, + beam_size, + blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]: + """ CTC prefix beam search inner implementation + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood) + paddle.Tensor: encoder output, (1, max_len, encoder_dim), + it will be used for rescoring in attention rescoring mode + """ + wav = wav[:, :, 0] + + if self.normalize_wav: + wav = F.layer_norm(wav, wav.shape[1:]) + # Extract wav2vec output + out = self.wav2vec2(wav)[0] + # We normalize the output if required + if self.output_norm: + out = F.layer_norm(out, out.shape[1:]) + feats = out + + x = self.enc(feats) + maxlen = x.shape[1] + ctc_probs = self.ctc.log_softmax(x) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + # 2.1 First beam prune: select topk best + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + cur_hyps = next_hyps[:beam_size] + + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] + return hyps + + def ctc_prefix_beam_search(self, wav, beam_size) -> List[int]: + """ Apply CTC prefix beam search + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[int]: CTC prefix beam search nbest results + """ + hyps = self._ctc_prefix_beam_search(wav, beam_size) + return hyps[0][0] + + +class Wav2vec2Base(nn.Layer): + """Wav2vec2 model""" + + def __init__(self, config: dict): + super().__init__() + wav2vec2_config = Wav2Vec2ConfigPure(config) + wav2vec2 = Wav2Vec2Model(wav2vec2_config) + self.wav2vec2 = wav2vec2 + + @classmethod + def from_config(cls, configs: dict): + """init model. + Args: + configs (dict): config dict. + Raises: + ValueError: raise when using not support encoder type. 
+ Returns: + nn.Layer: Wav2Vec2Base + """ + model = cls(configs) + return model + + def forward(self, wav): + out = self.wav2vec2(wav) + return out diff --git a/paddlespeech/s2t/models/hubert/modules/hubert_model.py b/paddlespeech/s2t/models/hubert/modules/hubert_model.py new file mode 100644 index 000000000..8b7822699 --- /dev/null +++ b/paddlespeech/s2t/models/hubert/modules/hubert_model.py @@ -0,0 +1,612 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# S3PRL Team has no contribution to this file +# The file was copied from fairseq to remove the dependency on the entire fairseq package + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import paddle +import paddle.nn as nn + +from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import ( + EXTRACTOR_MODE_CHOICES, + LAYER_TYPE_CHOICES, + MASKING_DISTRIBUTION_CHOICES, + ChoiceEnum, + ConvFeatureExtractionModel, + GradMultiply, + LayerNorm, + TransformerEncoder, + compute_mask_indices, + get_available_activation_fns, + GLU, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class HubertPretrainingConfig: + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@dataclass +class HubertConfig: + label_rate: float + + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. 
default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a tarnsformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions. set to encoder_embed_dim is <= 0" + }, + ) + untie_final_proj: bool = field( + default=False, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, + 
metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + +class HubertModel(nn.Layer): + def __init__( + self, + cfg: HubertConfig, + task_cfg: HubertPretrainingConfig, + dictionaries: List[Any], + ) -> None: + super().__init__() + logger.info(f"HubertModel Config: {cfg}") + + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + 
self.mask_channel_min_space = cfg.mask_channel_min_space
+
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+        self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+        self.feature_grad_mult = cfg.feature_grad_mult
+        self.logit_temp = cfg.logit_temp
+        self.skip_masked = cfg.skip_masked
+        self.skip_nomask = cfg.skip_nomask
+
+        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
+
+        self.mask_emb = paddle.create_parameter(
+            shape=[cfg.encoder_embed_dim],
+            dtype='float32',
+            default_initializer=paddle.nn.initializer.Uniform(),
+        )
+
+        self.encoder = TransformerEncoder(cfg)
+        self.layer_norm = LayerNorm(self.embed)
+
+        self.target_glu = None
+        if cfg.target_glu:
+            self.target_glu = nn.Sequential(
+                nn.Linear(final_dim, final_dim * 2), GLU()
+            )
+
+        self.untie_final_proj = cfg.untie_final_proj
+        if self.untie_final_proj:
+            self.final_proj = nn.Linear(
+                cfg.encoder_embed_dim, final_dim * len(dictionaries)
+            )
+        else:
+            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+
+        # modules below are not needed during fine-tuning
+        if any([d is None for d in dictionaries]):
+            logger.info("cannot find dictionary. assume will be used for fine-tuning")
+        else:
+            self.num_classes = [len(d) for d in dictionaries]
+            self.label_embs_concat = paddle.create_parameter(
+                shape=[sum(self.num_classes), final_dim],
+                dtype='float32',
+                default_initializer=paddle.nn.initializer.Uniform(),
+            )
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        """Upgrade a (possibly old) state dict for new versions of fairseq."""
+
+        super().upgrade_state_dict_named(state_dict, name)
+        return state_dict
+
+    @classmethod
+    def build_model(cls, cfg: HubertConfig, task):
+        """Build a new model instance."""
+
+        model = HubertModel(cfg, task.cfg, task.dictionaries)
+        return model
+
+    def apply_mask(self, x, padding_mask, target_list):
+        B, T, C = x.shape
+        if self.mask_prob > 0:
+            mask_indices = compute_mask_indices(
+                (B, T),
+                padding_mask,
+                self.mask_prob,
+                self.mask_length,
+                self.mask_selection,
+                self.mask_other,
+                min_masks=2,
+                no_overlap=self.no_mask_overlap,
+                min_space=self.mask_min_space,
+            )
+
+            mask_indices = paddle.to_tensor(mask_indices, dtype='int64', place=x.place)
+            x[mask_indices] = self.mask_emb
+        else:
+            mask_indices = None
+
+        if self.mask_channel_prob > 0:
+            mask_channel_indices = compute_mask_indices(
+                (B, C),
+                None,
+                self.mask_channel_prob,
+                self.mask_channel_length,
+                self.mask_channel_selection,
+                self.mask_channel_other,
+                no_overlap=self.no_mask_channel_overlap,
+                min_space=self.mask_channel_min_space,
+            )
+            mask_channel_indices = (
+                paddle.to_tensor(mask_channel_indices, dtype='int64', place=x.place)
+                .unsqueeze(1)
+                .expand([-1, T, -1])
+            )
+            x[mask_channel_indices] = 0
+
+        return x, mask_indices
+
+    def compute_nce(self, x, pos, negs):
+        # InfoNCE-style logits: cosine similarity of x against the positive and
+        # the negative targets, scaled by the logit temperature.
+        neg_is_pos = (pos == negs).all(-1)
+        pos = pos.unsqueeze(0)
+        targets = paddle.concat([pos, negs], axis=0)
+
+        logits = paddle.nn.functional.cosine_similarity(x.astype('float32'), targets.astype('float32'), axis=-1)
+        logits /= self.logit_temp
+        if paddle.any(neg_is_pos):
+            logits[1:][neg_is_pos] = float("-inf")
+        logits = logits.transpose([1, 0])  # (num_x, num_cls+1)
+        return logits
+
+    def forward_features(self, source: paddle.Tensor) -> paddle.Tensor:
+        if self.feature_grad_mult > 0:
+            features = self.feature_extractor(source)
+            if self.feature_grad_mult != 1.0:
+                features = GradMultiply.apply(features, self.feature_grad_mult)
+        else:
+            with paddle.no_grad():
+                features = self.feature_extractor(source)
+        return features
+
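+    # Note: forward_targets (next) keeps features and frame-level labels
+    # time-aligned. It uses feat2tar_ratio = label_rate * feature_ds_rate /
+    # sample_rate (set in __init__): features are trimmed when they outrun the
+    # shortest label sequence, and each label stream is index-selected at the
+    # corresponding feature frames.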
+    def forward_targets(
+        self,
+        features: paddle.Tensor,
+        target_list: List[paddle.Tensor],
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        # Trim features to ensure labels exist and then get aligned labels
+        feat_tsz = features.shape[2]
+        targ_tsz = min([t.shape[1] for t in target_list])
+        if self.feat2tar_ratio * feat_tsz > targ_tsz:
+            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+            features = features[:, :, :feat_tsz]
+        target_inds = paddle.arange(feat_tsz).astype('float32') * self.feat2tar_ratio
+        target_list = [t[:, target_inds.astype('int64')] for t in target_list]
+        return features, target_list
+
+    def forward_padding_mask(
+        self,
+        features: paddle.Tensor,
+        padding_mask: paddle.Tensor,
+    ) -> paddle.Tensor:
+        extra = padding_mask.shape[1] % features.shape[1]
+        if extra > 0:
+            padding_mask = padding_mask[:, :-extra]
+        padding_mask = paddle.reshape(padding_mask, [padding_mask.shape[0], features.shape[1], -1])
+        padding_mask = paddle.all(padding_mask, axis=-1)
+        return padding_mask
+
+    def forward(
+        self,
+        source: paddle.Tensor,
+        target_list: Optional[List[paddle.Tensor]] = None,
+        padding_mask: Optional[paddle.Tensor] = None,
+        mask: bool = True,
+        features_only: bool = False,
+        output_layer: Optional[int] = None,
+    ) -> Dict[str, paddle.Tensor]:
+        """output layer is 1-based"""
+        features = self.forward_features(source)
+        if target_list is not None:
+            features, target_list = self.forward_targets(features, target_list)
+
+        features_pen = features.pow(2).mean()
+
+        features = features.transpose([0, 2, 1])
+        features = self.layer_norm(features)
+        unmasked_features = features.clone()
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(features, padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        features = self.dropout_input(features)
+        unmasked_features = self.dropout_features(unmasked_features)
+
+        if mask:
+            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+        else:
+            x = features
+            mask_indices = None
+
+        # feature: (B, T, D), float
+        # target: (B, T), long
+        # x: (B, T, D), float
+        # padding_mask: (B, T), bool
+        # mask_indices: (B, T), bool
+        x, _ = self.encoder(
+            x,
+            padding_mask=padding_mask,
+            layer=None if output_layer is None else output_layer - 1,
+        )
+
+        if features_only:
+            return {"x": x, "padding_mask": padding_mask, "features": features}
+
+        def compute_pred(proj_x, target, label_embs):
+            # compute logits for the i-th label set
+            y = paddle.index_select(label_embs, index=target.astype('int64'), axis=0)
+            negs = paddle.expand(label_embs.unsqueeze(1), [label_embs.shape[0], proj_x.shape[0], label_embs.shape[-1]])
+            if self.target_glu:
+                y = self.target_glu(y)
+                negs = self.target_glu(negs)
+            # proj_x: (S, D)
+            # y: (S, D)
+            # negs: (Neg, S, D)
+            return self.compute_nce(proj_x, y, negs)
+
+        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
+
+        if not self.skip_masked:
+            masked_indices = paddle.logical_and(~padding_mask, mask_indices)
+            proj_x_m = self.final_proj(x[masked_indices])
+            if self.untie_final_proj:
+                proj_x_m_list = proj_x_m.chunk(len(target_list), axis=-1)
+            else:
+                proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
+            logit_m_list = [
+                compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
+                for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
+            ]
+        else:
+            logit_m_list = [None for _ in target_list]
+
+        if not self.skip_nomask:
+            nomask_indices = paddle.logical_and(~padding_mask, ~mask_indices)
+            proj_x_u = 
self.final_proj(x[nomask_indices]) + if self.untie_final_proj: + proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1) + else: + proj_x_u_list = [proj_x_u for _ in range(len(target_list))] + + logit_u_list = [ + compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list)) + ] + else: + logit_u_list = [None for _ in target_list] + + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + return result + + def extract_features( + self, + source: paddle.Tensor, + padding_mask: Optional[paddle.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [paddle.cast(x, 'float32') for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + logits_list = self.get_logits(net_output, is_masked) + targets_list = [paddle.zeros_like(x, dtype='int64') for x in logits_list] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py index 688bf5f84..3457f51a8 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py @@ -1114,7 +1114,6 @@ class Wav2Vec2Model(nn.Layer): class Wav2Vec2ConfigPure(): model_type = "wav2vec2" - def __init__(self, config): self.output_attentions = False self.output_hidden_states = False diff --git a/paddlespeech/s2t/models/wav2vec2/modules/wav2vec2_model.py b/paddlespeech/s2t/models/wav2vec2/modules/wav2vec2_model.py new file mode 100644 index 000000000..6b9d6cb30 --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/modules/wav2vec2_model.py @@ -0,0 +1,2613 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# S3PRL has no contribution to this file +# The file was copied from fairseq to remove the dependency on the entire fairseq package + +import logging +import math +import uuid +from dataclasses import dataclass, field +from enum import Enum, EnumMeta +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import Tensor + +logger = logging.getLogger(__name__) + + + +class GLU(nn.Layer): + r"""Applies the gated linear unit function + :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half + of the input matrices and :math:`b` is the second half. + + Args: + axis (int): the dimension on which to split the input. 
Default: -1 + + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + + Examples:: + + >>> m = nn.GLU() + >>> input = paddle.randn([4, 2]) + >>> output = m(input) + """ + def __init__(self, axis: int = -1) -> None: + super().__init__() + self.axis = axis + + def forward(self, input: Tensor) -> Tensor: + return F.glu(input, self.axis) + +class FairseqIncrementalState(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Layer.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Layer.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) + return cls + + +class FairseqDropout(paddle.nn.Layer): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x): + if self.p > 0 and (self.training or self.apply_during_inference): + return F.dropout(x, p=self.p, training=True) + else: + return x + + def make_generation_fast_( + self, + name: str, + retain_dropout: bool = False, + retain_dropout_modules: Optional[List[str]] = None, + **kwargs, + ): + if retain_dropout: + if retain_dropout_modules is not None and self.module_name is None: + logger.warning( + "Cannot enable dropout during inference for module {} " + "because module_name was not set".format(name) + ) + elif ( + retain_dropout_modules is None # if None, apply to all modules + or self.module_name in retain_dropout_modules + ): + logger.info( + "Enabling dropout during inference for module: {}".format(name) + ) + self.apply_during_inference = True + else: + logger.info("Disabling dropout for module: {}".format(name)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Layer + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Layer weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural 
Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2D)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = len(module.weight.shape) == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.shape[1] % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.weight.shape[2:] == (1, 1): + assert ( + module.weight.shape[1] % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.weight.shape[2] * module.weight.shape[3] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.shape[1] + out_features = weight.shape[0] + + # split weight matrix into blocks and randomly drop selected blocks + mask = paddle.zeros( + [in_features // block_size * out_features], dtype=paddle.bool) + mask.bernoulli_(p) + mask = mask.unsqueeze(1).tile([1, block_size]).reshape([-1, in_features]) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.weight.shape[1] + out_channels = mod.weight.shape[0] + + # split weight matrix into blocks and randomly drop selected blocks + if module.weight.shape[2:] == (1, 1): + mask = paddle.zeros( + [in_channels // block_size * out_channels], dtype=paddle.bool + ) + mask.bernoulli_(p) + mask = mask.unsqueeze(1).tile([1, block_size]).reshape([-1, in_channels]) + else: + mask = paddle.zeros( + weight.shape + ) + mask.bernoulli_(p) + mask = mask.unsqueeze(1).tile([1, in_channels, 1, 1]) + + # scale weights and apply mask + s = 1 / (1 - p) + mod.weight.set_value(s * weight.masked_fill(mask, 0)) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +@with_incremental_state +class MultiheadAttention(nn.Layer): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + # TODO: pass in config rather than string. 
+ # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + paddle.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config + ): + super().__init__() + + def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + xformers_att_config = eval_str_dict(xformers_att_config) + self.use_xformers = xformers_att_config is not None + assert not self.use_xformers, "Do not use xformers in PaddleSpeech" + + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + weight_attr = paddle.ParamAttr(initializer=nn.initializer.XavierUniform) + bias_attr = nn.initializer.Constant(0) + # self.k_proj = quant_noise( + # nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size + # ) + # self.v_proj = quant_noise( + # nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size + # ) + # self.q_proj = quant_noise( + # nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size + # ) + + # self.out_proj = quant_noise( + # nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else bias_attr), q_noise, qn_block_size + # ) + self.k_proj = nn.Linear(self.kdim, embed_dim) + + self.v_proj = nn.Linear(self.vdim, embed_dim) + + self.q_proj = nn.Linear(embed_dim, embed_dim) + + self.out_proj = nn.Linear(embed_dim, embed_dim) + + if add_bias_kv: + self.bias_k = paddle.create_parameter(shape=[1, 1, embed_dim], dtype='float32', initializer=nn.initializer.XavierUniform) + self.bias_v = paddle.create_parameter(shape=[1, 1, embed_dim], dtype='float32', initializer=nn.initializer.XavierUniform) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.beam_size = 1 + # self.reset_parameters() + + self.onnx_trace = False + self.skip_embed_dim_check = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + # def reset_parameters(self): + # if self.qkv_same_dim: + # # Empirically observed the convergence to be much better with + # # the scaled initialization + # nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2)) + # nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2)) + # nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2)) + # else: + # self.k_proj.weight = paddle.ParamAttr() + # nn.initializer.XavierUniform(self.k_proj.weight) + # nn.initializer.XavierUniform(self.v_proj.weight) + # nn.initializer.XavierUniform(self.q_proj.weight) + + # nn.initializer.XavierUniform(self.out_proj.weight) + # if self.out_proj.bias is not None: + # 
nn.initializer.Constant(self.out_proj.bias) + # if self.bias_k is not None: + # nn.initializer.XavierNormal(self.bias_k) + # if self.bias_v is not None: + # nn.initializer.XavierNormal(self.bias_v) + + def _get_reserve_head_index(self, num_heads_to_keep: int): + k_proj_heads_norm = [] + q_proj_heads_norm = [] + v_proj_heads_norm = [] + + for i in range(self.num_heads): + start_idx = i * self.head_dim + end_idx = (i + 1) * self.head_dim + k_proj_heads_norm.append( + paddle.sum( + paddle.abs( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + paddle.sum(paddle.abs(self.k_proj.bias[start_idx:end_idx])).tolist() + ) + q_proj_heads_norm.append( + paddle.sum( + paddle.abs( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + paddle.sum(paddle.abs(self.q_proj.bias[start_idx:end_idx])).tolist() + ) + v_proj_heads_norm.append( + paddle.sum( + paddle.abs( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + paddle.sum(paddle.abs(self.v_proj.bias[start_idx:end_idx])).tolist() + ) + + heads_norm = [] + for i in range(self.num_heads): + heads_norm.append( + k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] + ) + + sorted_head_index = sorted( + range(self.num_heads), key=lambda k: heads_norm[k], reverse=True + ) + reserve_head_index = [] + for i in range(num_heads_to_keep): + start = sorted_head_index[i] * self.head_dim + end = (sorted_head_index[i] + 1) * self.head_dim + reserve_head_index.append((start, end)) + + return reserve_head_index + + def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): + new_q_weight = [] + new_q_bias = [] + new_k_weight = [] + new_k_bias = [] + new_v_weight = [] + new_v_bias = [] + new_out_proj_weight = [] + + for ele in reserve_head_index: + start_idx, end_idx = ele + new_q_weight.append( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) + + new_k_weight.append( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + + new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) + + new_v_weight.append( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) + + new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) + + new_q_weight = paddle.concat(new_q_weight).detach() + new_k_weight = paddle.concat(new_k_weight).detach() + new_v_weight = paddle.concat(new_v_weight).detach() + new_out_proj_weight = paddle.concat(new_out_proj_weight, axis=-1).detach() + new_q_weight.stop_gradient = False + new_k_weight.stop_gradient = False + new_v_weight.stop_gradient = False + new_out_proj_weight.stop_gradient = False + + + new_q_bias = paddle.concat(new_q_bias).detach() + new_q_bias.stop_gradient = False + + new_k_bias = paddle.concat(new_k_bias).detach() + new_k_bias.stop_gradient = False + + new_v_bias = paddle.concat(new_v_bias).detach() + new_v_bias.stop_gradient = False + + self.q_proj.weight = paddle.create_parameter(shape=new_q_weight.shape, dtype=new_q_weight.dtype, default_initializer=paddle.nn.initializer.Assign(new_q_weight)) + self.q_proj.bias = paddle.create_parameter(shape=new_q_bias.shape, dtype=new_q_bias.dtype, default_initializer=paddle.nn.initializer.Assign(new_q_bias)) + + self.k_proj.weight = paddle.create_parameter(shape=new_k_weight.shape, dtype=new_k_weight.dtype, default_initializer=paddle.nn.initializer.Assign(new_k_weight)) + self.k_proj.bias = paddle.create_parameter(shape=new_k_bias.shape, dtype=new_k_bias.dtype, 
default_initializer=paddle.nn.initializer.Assign(new_k_bias)) + + self.v_proj.weight = paddle.create_parameter(shape=new_v_weight.shape, dtype=new_v_weight.dtype, default_initializer=paddle.nn.initializer.Assign(new_v_weight)) + self.v_proj.bias = paddle.create_parameter(shape=new_v_bias.shape, dtype=new_v_bias.dtype, default_initializer=paddle.nn.initializer.Assign(new_v_bias)) + + self.out_proj.weight = paddle.create_parameter(shape=new_out_proj_weight.shape, dtype=new_out_proj_weight.dtype, default_initializer=paddle.nn.initializer.Assign(new_out_proj_weight)) + + self.num_heads = len(reserve_head_index) + self.embed_dim = self.head_dim * self.num_heads + self.q_proj.out_features = self.embed_dim + self.k_proj.out_features = self.embed_dim + self.v_proj.out_features = self.embed_dim + + def _set_skip_embed_dim_check(self): + self.skip_embed_dim_check = True + + def _pad_masks( + self, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + if attn_mask is not None: + shape = attn_mask.shape[:-1] + [1,] + attn_mask = paddle.concat([attn_mask, paddle.zeros(shape, dtype=attn_mask.dtype)], axis=-1) + if key_padding_mask is not None: + shape = key_padding_mask.shape[:-1] + [1,] + key_padding_mask = paddle.concat([key_padding_mask, paddle.zeros(shape, dtype=key_padding_mask.dtype)], axis=-1) + return key_padding_mask, attn_mask + + def _add_bias( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + bsz: int, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + assert self.bias_k is not None + assert self.bias_v is not None + k = paddle.concat([k, self.bias_k.tile([1, bsz, 1])], axis=-1) + v = paddle.concat([v, self.bias_v.tile([1, bsz, 1])], axis=-1) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _append_zero_attn( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + zero_attn_shape = k.shape[:-2] + [1] + k.shape[-1:] + k = paddle.concat( + [k, paddle.zeros(zero_attn_shape, dtype=k.dtype)], axis=-2 + ) + v = paddle.concat( + [v, paddle.zeros(zero_attn_shape, dtype=v.dtype)], axis=-2 + ) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. 
+ need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.place == "xla" + + tgt_len, bsz, embed_dim = query.shape + src_len = tgt_len + if not self.skip_embed_dim_check: + assert ( + embed_dim == self.embed_dim + ), f"query dim {embed_dim} != {self.embed_dim}" + assert list(query.shape) == [tgt_len, bsz, embed_dim] + # if key is not None: + # src_len, key_bsz, _ = key.size() + # if not torch.jit.is_scripting(): + # assert value is not None + # assert src_len, key_bsz == value.shape[:2] + + # if ( + # not self.onnx_trace + # and not is_tpu # don't use PyTorch version on TPUs + # and incremental_state is None + # and not static_kv + # # A workaround for quantization to work. Otherwise JIT compilation + # # treats bias in linear module as method. + # and not torch.jit.is_scripting() + # # The Multihead attention implemented in pytorch forces strong dimension check + # # for input embedding dimention and K,Q,V projection dimension. + # # Since pruning will break the dimension check and it is not easy to modify the pytorch API, + # # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check + # and not self.skip_embed_dim_check + # ): + # assert key is not None and value is not None + + # if self.use_xformers: + # return self._xformers_attn_forward( + # query, key, value, key_padding_mask, need_weights, attn_mask + # ) + + # else: + # return F.multi_head_attention_forward( + # query, + # key, + # value, + # self.embed_dim, + # self.num_heads, + # torch.empty([0]), + # torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + # self.bias_k, + # self.bias_v, + # self.add_zero_attn, + # self.dropout_module.p, + # self.out_proj.weight, + # self.out_proj.bias, + # self.training or self.dropout_module.apply_during_inference, + # key_padding_mask, + # need_weights, + # attn_mask, + # use_separate_proj_weight=True, + # q_proj_weight=self.q_proj.weight, + # k_proj_weight=self.k_proj.weight, + # v_proj_weight=self.v_proj.weight, + # ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ + :, :, 0, : + ] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view( + -1, self.beam_size, key_padding_mask.size(1) + )[:, 0, :] + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + q = paddle.reshape(q, [tgt_len, bsz * 
self.num_heads, self.head_dim]).transpose([1, 0, 2]) + kv_bsz = bsz # need default value for scripting + if k is not None: + kv_bsz = k.shape[1] + k = paddle.reshape(k, [-1, kv_bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + if v is not None: + v = paddle.reshape(v, [-1, kv_bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2]) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + kv_bsz = _prev_key.shape[0] + prev_key = _prev_key.reshape([kv_bsz * self.num_heads, -1, self.head_dim]) + if static_kv: + k = prev_key + else: + assert k is not None + k = paddle.concat([prev_key, k], axis=1) + src_len = k.shape[1] + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + assert kv_bsz == _prev_value.size(0) + prev_value = _prev_value.reshape( + [kv_bsz * self.num_heads, -1, self.head_dim] + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = paddle.concat([prev_value, v], axis=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.shape[1], + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.reshape([kv_bsz, self.num_heads, -1, self.head_dim]) + saved_state["prev_value"] = v.reshape( + [kv_bsz, self.num_heads, -1, self.head_dim] + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.shape[1] == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
+ if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == kv_bsz + assert key_padding_mask.shape[1] == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = paddle.einsum( + "bxhtd,bhsd->bxhts", + q.reshape([kv_bsz, -1, self.num_heads] + q.shape[1:]), + k.reshape([kv_bsz, self.num_heads] + k.shape[1:]), + ) + attn_weights = attn_weights.reshape([-1,] + attn_weights.shape[-2:]) + else: + attn_weights = paddle.bmm(q, k.transpose([0, 2, 1])) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.tile([attn_weights.shape[0], 1, 1]) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len]) + if not is_tpu: + attn_weights = attn_weights.reshape( + [kv_bsz, -1, self.num_heads, tgt_len, src_len] + ) + attn_weights = paddle.where( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .astype('bool'), + float('-inf') * paddle.ones_like(attn_weights), + attn_weights + ) + else: + attn_weights = attn_weights.transpose([2, 1, 0]) + attn_weights = paddle.where(key_padding_mask, float('-inf') * paddle.ones_like(attn_weights), attn_weights) + attn_weights = attn_weights.transpose([2, 1, 0]) + attn_weights = attn_weights.reshape([bsz * self.num_heads, tgt_len, src_len]) + + if before_softmax: + return attn_weights, v + + def softmax_supporting_onnx_trace(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x, axis=dim) + else: + return F.softmax(x, axis=dim, dtype='float32') + + attn_weights_float = softmax_supporting_onnx_trace( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = paddle.cast(attn_weights_float, attn_weights.dtype) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = paddle.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.reshape( + [kv_bsz, + -1, + self.num_heads] + + attn_probs.shape[1:] + ), + v.reshape( + [kv_bsz, + self.num_heads] + + v.shape[1:] + ), + ) + attn = attn.reshape([-1,] + attn.shape[-2:]) + else: + attn = paddle.bmm(attn_probs, v) + assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.shape[1] == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.reshape([tgt_len, bsz, self.embed_dim]) + else: + attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, self.embed_dim]) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.reshape( + [bsz, self.num_heads, tgt_len, src_len] + ).transpose([1, 0, 2, 3]) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(axis=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: 
Optional[Tensor],
+        batch_size: int,
+        src_len: int,
+        static_kv: bool,
+    ) -> Optional[Tensor]:
+        # saved key padding masks have shape (bsz, seq_len)
+        if prev_key_padding_mask is not None and static_kv:
+            new_key_padding_mask = prev_key_padding_mask
+        elif prev_key_padding_mask is not None and key_padding_mask is not None:
+            new_key_padding_mask = paddle.concat(
+                [paddle.cast(prev_key_padding_mask, 'float32'), paddle.cast(key_padding_mask, 'float32')], axis=1
+            )
+        # During incremental decoding, as the padding token enters and
+        # leaves the frame, there will be a time when prev or current
+        # is None
+        elif prev_key_padding_mask is not None:
+            if src_len > prev_key_padding_mask.shape[1]:
+                filler = paddle.zeros(
+                    [batch_size, src_len - prev_key_padding_mask.shape[1]],
+                )
+                new_key_padding_mask = paddle.concat(
+                    [paddle.cast(prev_key_padding_mask, 'float32'), paddle.cast(filler, 'float32')], axis=1
+                )
+            else:
+                new_key_padding_mask = prev_key_padding_mask
+        elif key_padding_mask is not None:
+            if src_len > key_padding_mask.shape[1]:
+                filler = paddle.zeros(
+                    [batch_size, src_len - key_padding_mask.shape[1]],
+                )
+                new_key_padding_mask = paddle.concat(
+                    [paddle.cast(filler, 'float32'), paddle.cast(key_padding_mask, 'float32')], axis=1
+                )
+            else:
+                new_key_padding_mask = paddle.cast(key_padding_mask, 'float32')
+        else:
+            new_key_padding_mask = prev_key_padding_mask
+        return new_key_padding_mask
+
+    @paddle.jit.to_static
+    def reorder_incremental_state(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
+        """Reorder buffered internal state (for incremental generation)."""
+        input_buffer = self._get_input_buffer(incremental_state)
+        if input_buffer is not None:
+            for k in input_buffer.keys():
+                input_buffer_k = input_buffer[k]
+                if input_buffer_k is not None:
+                    if self.encoder_decoder_attention:
+                        if input_buffer_k.shape[0] * self.beam_size == new_order.shape[0]:
+                            return incremental_state
+                        elif self.beam_size > 1:
+                            input_buffer[k] = paddle.index_select(
+                                input_buffer_k,
+                                index=new_order.reshape([-1, self.beam_size])[:, 0] // self.beam_size,
+                                axis=0,
+                            )
+                        else:
+                            input_buffer[k] = paddle.index_select(input_buffer_k, index=new_order, axis=0)
+                    else:
+                        input_buffer[k] = paddle.index_select(input_buffer_k, index=new_order, axis=0)
+            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+        return incremental_state
+
+    def set_beam_size(self, beam_size):
+        """Used for efficient beamable enc-dec attention"""
+        self.beam_size = beam_size
+
+    def _get_input_buffer(
+        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+    ) -> Dict[str, Optional[Tensor]]:
+        result = self.get_incremental_state(incremental_state, "attn_state")
+        if result is not None:
+            return result
+        else:
+            empty_result: Dict[str, Optional[Tensor]] = {}
+            return empty_result
+
+    def _set_input_buffer(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        buffer: Dict[str, Optional[Tensor]],
+    ):
+        return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+        return attn_weights
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        prefix = name + "."
if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim : 2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value + +class GumbelVectorQuantizer(nn.Layer): + def __init__( + self, + dim, + num_vars, + temp, + groups, + combine_groups, + vq_dim, + time_first, + activation=nn.GELU(), + weight_proj_depth=1, + weight_proj_factor=1, + ): + """Vector quantization using gumbel softmax + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor) + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. 
scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = self.create_parameter((1, num_groups * num_vars, var_dim), default_initializer=nn.initializer.Uniform()) + + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[ + block(self.input_dim if i == 0 else inner_dim, inner_dim) + for i in range(weight_proj_depth - 1) + ], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.initializer.Normal(mean=0, std=1)(self.weight_proj.weight) + nn.initializer.Zero()(self.weight_proj.bias) + + if isinstance(temp, str): + import ast + + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max( + self.max_temp * self.temp_decay**num_updates, self.min_temp + ) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = paddle.to_tensor( + inds, dtype='int64', place=self.vars.place + ).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.reshape( + self.num_vars**self.groups, -1 + ) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def codebook(self): + indices = self.get_codebook_indices() + return ( + self.vars.squeeze(0) + .index_select(0, indices) + .reshape(self.num_vars**self.groups, -1) + ) + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.reshape(-1, self.groups) + cb_size = indices.shape[0] + assert ( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = paddle.randint(low=0, high=cb_size, shape=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).reshape(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = paddle.full(indices.shape[:-1], 0, dtype=indices.dtype) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars**exponent) + return res + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + result = {"num_vars": self.num_vars * self.groups} + + if not self.time_first: + x = x.transpose([0, 2, 1]) + + bsz, tsz, fsz = x.shape + x = x.reshape([-1, fsz]) + x = self.weight_proj(x) + x = x.reshape([bsz * tsz * self.groups, -1]) + + _, k = x.max(-1) + hard_x = paddle.zeros_like(x) + hard_x.scatter_(-1, k.reshape([-1, 1]), 1.0) + hard_x = hard_x.reshape([bsz * tsz, self.groups, -1]) + hard_probs = paddle.mean(hard_x.astype('float32'), axis=0) + 
result["code_perplexity"] = paddle.exp( + -paddle.sum(hard_probs * paddle.log(hard_probs + 1e-7), axis=-1) + ).sum() + + avg_probs = F.softmax(x.reshape([bsz * tsz, self.groups, -1]).astype('float32'), axis=-1).mean(axis=0) + result["prob_perplexity"] = paddle.exp( + -paddle.sum(avg_probs * paddle.log(avg_probs + 1e-7), axis=-1) + ).sum() + + + result["temp"] = self.curr_temp + + if self.training: + x = F.gumbel_softmax(x.astype('float32'), tau=self.curr_temp, hard=True).astype(x.dtype) + else: + x = hard_x + + + x = x.reshape([bsz * tsz, -1]) + + vars = self.vars + if self.combine_groups: + vars = vars.tile([1, self.groups, 1]) + + if produce_targets: + result["targets"] = ( + x.reshape([bsz * tsz * self.groups, -1]) + .argmax(axis=-1) + .reshape([bsz, tsz, self.groups]) + .detach() + ) + + x = x.unsqueeze(-1) * vars + x = x.reshape([bsz * tsz, self.groups, self.num_vars, -1]) + x = x.sum(axis=-2) + x = x.reshape([bsz, tsz, -1]) + + if not self.time_first: + x = x.transpose([0, 2, 1]) + + result["x"] = x + + return result + +class GradMultiply(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.numpy().copy() + return paddle.to_tensor(res, dtype=x.dtype) + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Layer): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class TransposeLast(nn.Layer): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + trans_dim = paddle.arange(x.dim()) + trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1] + return x.transpose(trans_dim) + + +def LayerNorm(normalized_shape, eps=1e-5): + return nn.LayerNorm(normalized_shape, epsilon=eps, weight_attr=paddle.ParamAttr(), bias_attr=paddle.ParamAttr()) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + # import pdb + # pdb.set_trace() + output = F.layer_norm( + input.astype('float32'), + self._normalized_shape, + self.weight.astype('float32') if self.weight is not None else None, + self.bias.astype('float32') if self.bias is not None else None, + self._epsilon, + ) + return output.astype(input.dtype) + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__( *args, **kwargs) + def forward(self, input): + # import pdb + # pdb.set_trace() + output = F.group_norm( + input.astype('float32'), + self._num_groups, + self.weight.astype('float32') if self.weight is not None else None, + self.bias.astype('float32') if self.bias is not None else None, + self._epsilon, + ) + return output.astype(input.dtype) + + +class StrEnumMeta(EnumMeta): + # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see + # https://github.com/facebookresearch/hydra/issues/1156 + @classmethod + def __instancecheck__(cls, other): + return "enum" in str(type(other)) + + +class StrEnum(Enum, metaclass=StrEnumMeta): + def __str__(self): + return self.value + + def __eq__(self, other: str): + return self.value == other + + def __repr__(self): + return self.value + + def __hash__(self): + return hash(str(self)) + + 
+def ChoiceEnum(choices: List[str]): + """return the Enum class used to enforce list of choices""" + return StrEnum("Choices", {k: k for k in choices}) + + +def relu_squared(x: paddle.Tensor): + return F.relu(x).pow(2) + + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`""" + + def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 + * x + * (1 + paddle.tanh(gelu_accurate._a * (x + 0.044715 * paddle.pow(x, 3)))) + ) + + def gelu(x: paddle.Tensor) -> paddle.Tensor: + return paddle.nn.functional.gelu(x.astype('float32')).astype(x.dtype) + + if activation == "relu": + return F.relu + elif activation == "relu_squared": + return relu_squared + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return paddle.tanh + elif activation == "linear": + return lambda x: x + elif activation == "swish": + return paddle.nn.Swish + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def get_available_activation_fns() -> List: + return [ + "relu", + "gelu", + "gelu_fast", # deprecated + "gelu_accurate", + "tanh", + "linear", + ] + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[paddle.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len and require_same_masks: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + if mask_dropout > 0: + num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) + mask_idc = np.random.choice( + mask_idc, len(mask_idc) - num_holes, replace=False + ) + + mask[i, mask_idc] = True + + return mask + + +def index_put(tensor, indices, value): + tensor[indices] = value + return tensor + + +# ToDo if faster? 
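+# buffered_arange caches a tensor on the function object and re-creates it with paddle.arange
+# only when a size larger than the cached buffer is requested, then returns a slice of the buffer.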
+def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = paddle.empty([max], dtype='int64') + if max > buffered_arange.buf.numel(): + buffered_arange.buf = paddle.arange(max) + return buffered_arange.buf[:max] + + +def pad_to_multiple(x, multiple, dim=-1, value=0): + # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41 + if x is None: + return None, 0 + tsz = x.shape[dim] + m = tsz / multiple + remainder = math.ceil(m) * multiple - tsz + if m.is_integer(): + return x, 0 + pad_offset = (0,) * (-1 - dim) * 2 + return F.pad(x, pad=[*pad_offset, 0, remainder, *pad_offset], value=value, data_format='NLC'), remainder + + +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) +LAYER_TYPE_CHOICES = ChoiceEnum(["transformer"]) # ToDo: conformer + + +@dataclass +class Wav2Vec2Config: + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group norm with d " + "groups in the first conv block, whereas layer_norm has layer norms in " + "every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + # dropouts + dropout: float = field( + default=0.1, metadata={"help": "dropout probability for the transformer"} + ) + attention_dropout: float = field( + default=0.1, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN"} + ) + encoder_layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a tarnsformer layer"} + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many dimensions." 
+ "set to encoder_embed_dim is <= 0" + }, + ) + layer_norm_first: bool = field( + default=False, metadata={"help": "apply layernorm first in the transformer"} + ) + conv_feature_layers: str = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": "string describing convolutional feature extraction layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + quantize_targets: bool = field( + default=False, metadata={"help": "use quantized targets"} + ) + quantize_input: bool = field( + default=False, metadata={"help": "use quantized inputs"} + ) + same_quantizer: bool = field( + default=False, metadata={"help": "use same quantizer for inputs and targets"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, metadata={"help": "multiply feature extractor var grads by this"} + ) + quantizer_depth: int = field( + default=1, + metadata={"help": "number of quantizer layers"}, + ) + quantizer_factor: int = field( + default=3, + metadata={ + "help": "dimensionality increase for inner quantizer layers (if depth > 1)" + }, + ) + latent_vars: int = field( + default=320, + metadata={"help": "number of latent variables V in each group of the codebook"}, + ) + latent_groups: int = field( + default=2, + metadata={"help": "number of groups G of latent variables in the codebook"}, + ) + latent_dim: int = field( + default=0, + metadata={ + "help": "if > 0, uses this dimensionality for latent variables. 
" + "otherwise uses final_dim / latent_groups" + }, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, metadata={"help": "probability of replacing a token with mask"} + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + require_same_masks: bool = field( + default=True, + metadata={ + "help": "whether to number of masked timesteps must be the same across all " + "examples in a batch" + }, + ) + mask_dropout: float = field( + default=0.0, + metadata={"help": "percent of masks to unmask for each sample"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} + ) + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} + ) + mask_channel_before: bool = False + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, metadata={"help": "whether to allow channel masks to overlap"} + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # negative selection + num_negatives: int = field( + default=100, + metadata={"help": "number of negative examples from the same sample"}, + ) + negatives_from_everywhere: bool = field( + default=False, + metadata={"help": "sample negatives from everywhere, not just masked states"}, + ) + cross_sample_negatives: int = field( + default=0, metadata={"help": "number of negative examples from the any sample"} + ) + codebook_negatives: int = field( + default=0, metadata={"help": "number of negative examples codebook"} + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + pos_conv_depth: int = field( + default=1, + metadata={"help": "depth of positional encoder network"}, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={ + "help": "temperature for latent variable sampling. 
" + "can be tuple of 3 values (start, end, decay)" + }, + ) + max_positions: int = field(default=100000, metadata={"help": "Max positions"}) + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + crop_seq_to_multiple: int = field( + default=1, + metadata={ + "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + +class Wav2Vec2Model(nn.Layer): + def __init__(self, cfg: Wav2Vec2Config): + super().__init__() + self.cfg = cfg + + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input + else None + ) + + self.crop_seq_to_multiple = cfg.crop_seq_to_multiple + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_before = cfg.mask_channel_before + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = cfg.num_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + + self.logit_temp = cfg.logit_temp + + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + if cfg.quantize_targets: + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + self.project_q = nn.Linear(self.embed, final_dim) + + if cfg.quantize_input: + if cfg.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = 
cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, + ) + self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) + + self.mask_emb = self.create_parameter( + shape=[cfg.encoder_embed_dim], + default_initializer=paddle.nn.initializer.Uniform(), + dtype='float32', + ) + + encoder_cls = TransformerEncoder + + self.encoder = encoder_cls(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), GLU() + ) + + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + @classmethod + def build_model(cls, cfg: Wav2Vec2Config, task=None): + """Build a new model instance.""" + return cls(cfg) + + def apply_mask( + self, + x, + padding_mask, + mask_indices=None, + mask_channel_indices=None, + ): + B, T, C = x.shape + + if self.mask_channel_prob > 0 and self.mask_channel_before: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + paddle.to_tensor(mask_channel_indices, plcae=x.plcae) + .unsqueeze(1) + .expand([-1, T, -1]) + ) + x[mask_channel_indices] = 0 + + if self.mask_prob > 0: + if mask_indices is None: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + require_same_masks=self.cfg.require_same_masks, + mask_dropout=self.cfg.mask_dropout, + ) + mask_indices = paddle.to_tensor(mask_indices, place=x.place) + x = index_put(x, mask_indices, self.mask_emb) + else: + mask_indices = None + + if self.mask_channel_prob > 0 and not self.mask_channel_before: + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + paddle.to_tensor(mask_channel_indices, place=x.place) + .unsqueeze(1) + .expand([-1, T, -1]) + ) + x = index_put(x, mask_channel_indices, 0) + + return x, mask_indices + + def sample_negatives(self, y, num, padding_count=None): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return paddle.empty([0], dtype=y.dtype) + + bsz, tsz, fsz = y.shape + y = y.reshape([-1, fsz]) # BTC => (BxT)C + + # FIXME: what happens if padding_count is specified? 
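+ # Sample distractor frames for the contrastive objective: n_negatives indices are drawn per
+ # target frame from the same utterance (indices that land on or after the positive frame are
+ # shifted by one so the positive itself is never drawn), plus optionally cross_sample_negatives
+ # indices drawn from anywhere in the batch.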
+ cross_high = tsz * bsz + high = tsz - (padding_count or 0) + with paddle.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand([-1, self.n_negatives]) + .flatten() + ) + + neg_idxs = paddle.randint( + low=0, high=high - 1, shape=[bsz, self.n_negatives * num] + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand([-1, self.cross_sample_negatives]) + .flatten() + ) + + cross_neg_idxs = paddle.randint( + low=0, + high=cross_high - 1, + shape=[bsz, self.cross_sample_negatives * num], + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + neg_idxs = neg_idxs + (paddle.arange(bsz).unsqueeze(1) * high) + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = paddle.concat([neg_idxs, cross_neg_idxs], axis=1) + + negs = y[neg_idxs.reshape([-1])] + negs = negs.reshape( + [bsz, num, self.n_negatives + self.cross_sample_negatives, fsz] + ).transpose( + [2, 0, 1, 3] + ) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = paddle.concat([y, negatives], axis=0) + + logits = paddle.nn.functional.cosine_similarity(x, targets, axis=-1) + logits = logits / self.logit_temp + logits = logits.astype(x.dtype) + + return logits + + def _get_feat_extract_output_lengths(self, input_lengths: paddle.Tensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return paddle.floor((input_length - kernel_size) / stride + 1) + + conv_cfg_list = eval(self.cfg.conv_feature_layers) + + for i in range(len(conv_cfg_list)): + input_lengths = _conv_out_length( + input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2] + ) + + return paddle.cast(input_lengths, 'int64') + + def forward( + self, + source, + padding_mask=None, + mask=True, + features_only=False, + layer=None, + mask_indices=None, + mask_channel_indices=None, + padding_count=None, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with paddle.no_grad(): + features = self.feature_extractor(source) + + features_pen = features.pow(2).mean() + + features = features.transpose([0, 2, 1]) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None and padding_mask.any(): + input_lengths = (1 - paddle.cast(padding_mask, 'int64')).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = paddle.zeros( + features.shape[:2], dtype=features.dtype + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[ + ( + paddle.arange(padding_mask.shape[0]), + output_lengths - 1, + ) + ] = 1 + padding_mask = paddle.cast((1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])), 'bool') + else: + padding_mask = None + + time_steps_to_drop = features.shape[1] % self.crop_seq_to_multiple + if time_steps_to_drop != 0: + features = features[:, :-time_steps_to_drop] + unmasked_features = unmasked_features[:, :-time_steps_to_drop] + if padding_mask is not None: + padding_mask = padding_mask[:, :-time_steps_to_drop] + + if 
self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + num_vars = None + code_ppl = None + prob_ppl = None + curr_temp = None + + if self.input_quantizer: + q = self.input_quantizer(features, produce_targets=False) + features = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + features = self.project_inp(features) + + if mask: + x, mask_indices = self.apply_mask( + features, + padding_mask, + mask_indices=mask_indices, + mask_channel_indices=mask_channel_indices, + ) + if mask_indices is not None: + y = unmasked_features[mask_indices].reshape( + [unmasked_features.shape[0], -1, unmasked_features.shape[-1]] + ) + else: + x = features + y = unmasked_features + mask_indices = None + + x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer) + + if features_only: + return { + "x": x, + "padding_mask": padding_mask, + "features": unmasked_features, + "layer_results": layer_results, + } + + if self.quantizer: + if self.negatives_from_everywhere: + q = self.quantizer(unmasked_features, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + y = self.project_q(y) + + negs, _ = self.sample_negatives( + y, + mask_indices[0].sum(), + padding_count=padding_count, + ) + y = y[mask_indices].reshape([y.shape[0], -1, y.shape[-1]]) + + else: + q = self.quantizer(y, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + + y = self.project_q(y) + + negs, _ = self.sample_negatives( + y, + y.shape[1], + padding_count=padding_count, + ) + + if self.codebook_negatives > 0: + cb_negs = self.quantizer.sample_from_codebook( + y.shape[0] * y.shape[1], self.codebook_negatives + ) + cb_negs = cb_negs.reshape( + [self.codebook_negatives, y.shape[0], y.shape[1], -1] + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + negs = paddle.concat([negs, cb_negs], axis=0) + else: + y = self.project_q(y) + + if self.negatives_from_everywhere: + negs, _ = self.sample_negatives( + unmasked_features, + y.shape[1], + padding_count=padding_count, + ) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives( + y, + y.shape[1], + padding_count=padding_count, + ) + + x = x[mask_indices].reshape([x.shape[0], -1, x.shape[-1]]) + + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + x = self.final_proj(x) + x = self.compute_preds(x, y, negs) + + result = { + "x": x, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + + if prob_ppl is not None: + result["prob_perplexity"] = prob_ppl + result["code_perplexity"] = code_ppl + result["num_vars"] = num_vars + result["temp"] = curr_temp + + return result + + def quantize(self, x): + assert self.quantizer is not None + x = self.feature_extractor(x) + x = x.transpose([0, 2, 1]) + x = self.layer_norm(x) + return self.quantizer.forward_idx(x) + + def extract_features(self, source, padding_mask, mask=False, layer=None): + res = self.forward( + source, padding_mask, mask=mask, features_only=True, layer=layer + ) + return res + + def get_logits(self, net_output): + logits = net_output["x"] + logits = logits.transpose([2, 1, 0]) + logits = logits.reshape([-1, logits.shape[-1]]) + return logits + + def 
get_targets(self, sample, net_output, expand_steps=True): + x = net_output["x"] + return paddle.zeros(x.shape[1] * x.shape[2], dtype='int64') + + def get_extra_losses(self, net_output): + pen = [] + + if "prob_perplexity" in net_output: + pen.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + + if "features_pen" in net_output: + pen.append(net_output["features_pen"]) + + return pen + + def remove_pretraining_modules(self, last_layer=None): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + + if last_layer is not None: + self.encoder.layers = nn.LayerList( + l for i, l in enumerate(self.encoder.layers) if i <= last_layer + ) + + +class ConvFeatureExtractionModel(nn.Layer): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1D(n_in, n_out, k, stride=stride, bias_attr=conv_bias if not conv_bias else paddle.ParamAttr()) + # nn.initializer.KaimingNormal()(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.LayerList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + + # BxT -> BxCxT + x = x.unsqueeze(1) + # import pdb + # pdb.set_trace() + for conv in self.conv_layers: + x = conv(x) + + return x + + +def make_conv_pos(e, k, g): + pos_conv = nn.Conv1D( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (k * e)) + nn.initializer.Normal(mean=0, std=std)(pos_conv.weight) + nn.initializer.Constant(0)(pos_conv.bias) + + pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2) + pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU()) + + return pos_conv + + +class TransformerEncoder(nn.Layer): + def build_encoder_layer(self, args: Wav2Vec2Config): + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + return layer + + def __init__(self, args: Wav2Vec2Config): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.required_seq_len_multiple = args.required_seq_len_multiple + + pos_conv_depth = getattr(args, "pos_conv_depth", 1) + if pos_conv_depth > 1: + num_layers = 
args.pos_conv_depth + k = max(3, args.conv_pos // num_layers) + + def make_conv_block(e, k, g, l): + return nn.Sequential( + *[ + nn.Sequential( + nn.Conv1D( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ), + SamePad(k), + TransposeLast(), + LayerNorm(e, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ) + for _ in range(l) + ] + ) + + self.pos_conv = make_conv_block( + self.embedding_dim, k, args.conv_pos_groups, num_layers + ) + + else: + self.pos_conv = make_conv_pos( + self.embedding_dim, + args.conv_pos, + args.conv_pos_groups, + ) + + self.layers = nn.LayerList( + [self.build_encoder_layer(args) for _ in range(args.encoder_layers)] + ) + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + # import pdb + # pdb.set_trace() + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features( + self, + x, + padding_mask=None, + tgt_layer=None, + min_layer=0, + ): + + # import pdb + # pdb.set_trace() + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + x_conv = self.pos_conv(x.transpose([0, 2, 1])) + x_conv = x_conv.transpose([0, 2, 1]) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + # pad to the sequence length dimension + x, pad_length = pad_to_multiple( + x, self.required_seq_len_multiple, dim=-2, value=0 + ) + if pad_length > 0 and padding_mask is None: + padding_mask = paddle.zeros([x.shape[0], x.shape[1]], dtype='bool') + padding_mask[:, -pad_length:] = True + else: + padding_mask, _ = pad_to_multiple( + padding_mask, self.required_seq_len_multiple, dim=-1, value=True + ) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose([1, 0, 2]) + + layer_results = [] + r = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() if self.layerdrop > 0 else 1 + if not self.training or (dropout_probability > self.layerdrop): + x, (z, lr) = layer( + x, self_attn_padding_mask=padding_mask, need_weights=False + ) + if i >= min_layer: + layer_results.append((x, z, lr)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose([1, 0, 2]) + + # undo paddding + if pad_length > 0: + x = x[:, :-pad_length] + + def undo_pad(a, b, c): + return ( + a[:-pad_length], + b[:-pad_length] if b is not None else b, + c[:-pad_length], + ) + + layer_results = [undo_pad(*u) for u in layer_results] + + return x, layer_results + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + +class TransformerSentenceEncoderLayer(nn.Layer): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: paddle.Tensor, + self_attn_mask: paddle.Tensor = None, + self_attn_padding_mask: paddle.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask, + need_weights=False, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, (attn, layer_result) + + +@dataclass +class AudioPretrainingConfig: + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. 
audio files will be up/down sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, metadata={"help": "pad shorter samples instead of cropping"} + ) + max_sample_size: Optional[int] = field( + default=None, metadata={"help": "max sample size to crop to for batching"} + ) + min_sample_size: Optional[int] = field( + default=None, metadata={"help": "min sample size to skip small examples"} + ) diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py index 7468fdce0..36d7f744d 100755 --- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py +++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py @@ -28,6 +28,9 @@ from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC from paddlespeech.s2t.modules.initializer import DefaultInitializerContext from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.utility import log_add +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() class Wav2vec2ASR(nn.Layer): @@ -55,6 +58,6 @@ class Wav2vec2ASR(nn.Layer): reduction='mean') def forward(self, wav, wavs_lens_rate, target, target_lens): if self.normalize_wav: wav = F.layer_norm(wav, wav.shape)