# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Paddle Wav2Vec2 model."""
import math
import uuid
from dataclasses import dataclass
from dataclasses import field
from enum import Enum
from enum import EnumMeta
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor

from paddlespeech.s2t.modules.align import Conv1D
from paddlespeech.s2t.modules.align import Conv2D
from paddlespeech.s2t.modules.align import Embedding
from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()


class GLU(nn.Layer):
    r"""Applies the gated linear unit function
    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        axis (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = GLU()
        >>> input = paddle.randn([4, 2])
        >>> output = m(input)
    """

    def __init__(self, axis: int=-1) -> None:
        super().__init__()
        self.axis = axis

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.axis)
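

# Illustrative sketch (not part of the original port): GLU halves the size of
# the chosen axis, e.g. a [4, 2] input becomes a [4, 1] output with axis=-1.
# Wrapped in a function so importing this module does not execute it.
def _glu_shape_example():
    m = GLU(axis=-1)
    x = paddle.randn([4, 2])
    y = m(x)  # y.shape == [4, 1]
    return y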


class FairseqIncrementalState(object):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
            self,
            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
            key: str, ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Layer."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
            self,
            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
            key: str,
            value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Layer."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state


def with_incremental_state(cls):
    cls.__bases__ = (FairseqIncrementalState, ) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState)
    return cls
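

# Illustrative sketch (assumption, not in the original file): how the
# incremental-state helpers above are typically used.  `with_incremental_state`
# rebases a layer onto FairseqIncrementalState so every instance caches its
# tensors under a unique uuid-prefixed key inside a shared state dict.
def _incremental_state_example():
    attn = MultiheadAttention(embed_dim=8, num_heads=2, self_attention=True)
    state: Dict[str, Dict[str, Optional[Tensor]]] = {}
    # store a buffer under this instance's private key ...
    attn.set_incremental_state(state, "attn_state", {"prev_key": None})
    # ... and read it back later in the same decoding loop
    return attn.get_incremental_state(state, "attn_state")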


class FairseqDropout(paddle.nn.Layer):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x):
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True)
        else:
            return x

    def make_generation_fast_(
            self,
            name: str,
            retain_dropout: bool=False,
            retain_dropout_modules: Optional[List[str]]=None,
            **kwargs, ):
        if retain_dropout:
            if retain_dropout_modules is not None and self.module_name is None:
                logger.warning(
                    "Cannot enable dropout during inference for module {} "
                    "because module_name was not set".format(name))
            elif (retain_dropout_modules is
                  None  # if None, apply to all modules
                  or self.module_name in retain_dropout_modules):
                logger.info("Enabling dropout during inference for module: {}".
                            format(name))
                self.apply_during_inference = True
            else:
                logger.info("Disabling dropout for module: {}".format(name))


def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Layer
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Layer weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (Linear, Embedding, Conv2D))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = len(module.weight.shape) == 4

    # 2D matrix
    if not is_conv:
        if isinstance(module, Linear):
            features_weight = module.weight.shape[0]
        else:
            features_weight = module.weight.shape[1]
        assert (
            features_weight %
            block_size == 0), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.weight.shape[2:] == (1, 1):
            assert (module.weight.shape[1] % block_size == 0
                    ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.weight.shape[2] * module.weight.shape[3]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                if isinstance(module, Linear):
                    in_features = weight.shape[0]
                    out_features = weight.shape[1]
                else:
                    in_features = weight.shape[1]
                    out_features = weight.shape[0]

                # split weight matrix into blocks and randomly drop selected blocks
                mask = paddle.zeros(
                    [in_features // block_size * out_features],
                    dtype=paddle.bool)
                # the implementation of bernoulli_, p=0.5
                mask = paddle.ones_like(mask) * 0.5
                mask = paddle.bernoulli(mask)
                mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                    [-1, in_features])

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.weight.shape[1]
                out_channels = mod.weight.shape[0]

                # split weight matrix into blocks and randomly drop selected blocks
                if module.weight.shape[2:] == (1, 1):
                    mask = paddle.zeros(
                        [in_channels // block_size * out_channels],
                        dtype=paddle.bool)

                    # the implementation of bernoulli_, p=0.5
                    mask = paddle.ones_like(mask) * 0.5
                    mask = paddle.bernoulli(mask)
                    mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                        [-1, in_channels])
                else:
                    mask = paddle.zeros(weight.shape)

                    # the implementation of bernoulli_, p=0.5
                    mask = paddle.ones_like(mask) * 0.5
                    mask = paddle.bernoulli(mask)
                    mask = mask.unsqueeze(1).tile([1, in_channels, 1, 1])

            # scale weights and apply mask
            s = 1 / (1 - p)
            mod.weight.set_value(s * weight.masked_fill(mask, 0))

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
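

# Illustrative sketch (assumption, not part of the original fairseq port):
# quant_noise only registers a forward pre-hook, so the wrapped layer is used
# exactly like the plain layer; during training a random subset of weight
# blocks is zeroed and the remaining weights are rescaled by 1 / (1 - p).
def _quant_noise_example():
    layer = Linear(16, 8)  # in_features must be a multiple of block_size
    noisy = quant_noise(layer, p=0.1, block_size=8)
    noisy.eval()           # the pre-hook only perturbs weights in training mode
    x = paddle.randn([4, 16])
    return noisy(x)        # shape [4, 8]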


@with_incremental_state
class MultiheadAttention(nn.Layer):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
            self,
            embed_dim,
            num_heads,
            kdim=None,
            vdim=None,
            dropout=0.0,
            bias=True,
            add_bias_kv=False,
            add_zero_attn=False,
            self_attention=False,
            encoder_decoder_attention=False,
            q_noise=0.0,
            qn_block_size=8,
            # TODO: pass in config rather than string.
            # config defined in xformers.components.attention.AttentionConfig
            xformers_att_config: Optional[str]=None,
            xformers_blocksparse_layout: Optional[
                paddle.Tensor]=None,  # This should be part of the config
            xformers_blocksparse_blocksize: Optional[
                int]=16,  # This should be part of the config
    ):
        super().__init__()

        def eval_str_dict(x, type=dict):
            if x is None:
                return None
            if isinstance(x, str):
                x = eval(x)
            return x

        xformers_att_config = eval_str_dict(xformers_att_config)
        self.use_xformers = xformers_att_config is not None
        assert not self.use_xformers, "Do not use xformers in PaddleSpeech"

        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__)

        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and "
            "value to be of the same size")

        # Todo scaled initialization
        # Empirically observed the convergence to be much better with
        # the scaled initialization
        weight_attr = nn.initializer.XavierUniform()
        kv_proj_bias_attr = nn.initializer.XavierUniform()
        out_proj_bias_attr = nn.initializer.Constant(0)

        self.k_proj = quant_noise(
            nn.Linear(
                self.kdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias
                if not bias else kv_proj_bias_attr), q_noise, qn_block_size)
        self.v_proj = quant_noise(
            nn.Linear(
                self.vdim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias
                if not bias else kv_proj_bias_attr), q_noise, qn_block_size)
        self.q_proj = quant_noise(
            nn.Linear(
                embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias),
            q_noise, qn_block_size)

        self.out_proj = quant_noise(
            nn.Linear(
                embed_dim,
                embed_dim,
                weight_attr=weight_attr,
                bias_attr=bias
                if not bias else out_proj_bias_attr), q_noise, qn_block_size)

        # nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
        # nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
        # nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
        # else:
        #     self.k_proj.weight = paddle.ParamAttr()
        #     nn.initializer.XavierUniform(self.k_proj.weight)
        #     nn.initializer.XavierUniform(self.v_proj.weight)
        #     nn.initializer.XavierUniform(self.q_proj.weight)

        # nn.initializer.XavierUniform(self.out_proj.weight)
        # if self.out_proj.bias is not None:
        #     nn.initializer.Constant(self.out_proj.bias)
        # if self.bias_k is not None:
        #     nn.initializer.XavierNormal(self.bias_k)
        # if self.bias_v is not None:
        #     nn.initializer.XavierNormal(self.bias_v)

        # self.k_proj = Linear(self.kdim, embed_dim)

        # self.v_proj = Linear(self.vdim, embed_dim)

        # self.q_proj = Linear(embed_dim, embed_dim)

        # self.out_proj = Linear(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = paddle.create_parameter(
                shape=[1, 1, embed_dim],
                dtype='float32',
                initializer=nn.initializer.XavierUniform)
            self.bias_v = paddle.create_parameter(
                shape=[1, 1, embed_dim],
                dtype='float32',
                initializer=nn.initializer.XavierUniform)
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        self.beam_size = 1
        # self.reset_parameters()

        self.onnx_trace = False
        self.skip_embed_dim_check = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.initializer.XavierUniform(
                self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.initializer.XavierUniform(
                self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.initializer.XavierUniform(
                self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            self.k_proj.weight = paddle.ParamAttr()
            nn.initializer.XavierUniform(self.k_proj.weight)
            nn.initializer.XavierUniform(self.v_proj.weight)
            nn.initializer.XavierUniform(self.q_proj.weight)

        nn.initializer.XavierUniform(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.initializer.Constant(self.out_proj.bias)
        if self.bias_k is not None:
            nn.initializer.XavierNormal(self.bias_k)
        if self.bias_v is not None:
            nn.initializer.XavierNormal(self.bias_v)

    def _get_reserve_head_index(self, num_heads_to_keep: int):
        k_proj_heads_norm = []
        q_proj_heads_norm = []
        v_proj_heads_norm = []

        for i in range(self.num_heads):
            start_idx = i * self.head_dim
            end_idx = (i + 1) * self.head_dim
            k_proj_heads_norm.append(
                paddle.sum(
                    paddle.abs(self.k_proj.weight[:, start_idx:end_idx]))
                .tolist() + paddle.sum(
                    paddle.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
            q_proj_heads_norm.append(
                paddle.sum(
                    paddle.abs(self.q_proj.weight[:, start_idx:end_idx]))
                .tolist() + paddle.sum(
                    paddle.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
            v_proj_heads_norm.append(
                paddle.sum(
                    paddle.abs(self.v_proj.weight[:, start_idx:end_idx]))
                .tolist() + paddle.sum(
                    paddle.abs(self.v_proj.bias[start_idx:end_idx])).tolist())

        heads_norm = []
        for i in range(self.num_heads):
            heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] +
                              v_proj_heads_norm[i])

        sorted_head_index = sorted(
            range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
        reserve_head_index = []
        for i in range(num_heads_to_keep):
            start = sorted_head_index[i] * self.head_dim
            end = (sorted_head_index[i] + 1) * self.head_dim
            reserve_head_index.append((start, end))

        return reserve_head_index

    def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
        new_q_weight = []
        new_q_bias = []
        new_k_weight = []
        new_k_bias = []
        new_v_weight = []
        new_v_bias = []
        new_out_proj_weight = []

        for ele in reserve_head_index:
            start_idx, end_idx = ele
            new_q_weight.append(self.q_proj.weight[:, start_idx:end_idx])
            new_q_bias.append(self.q_proj.bias[start_idx:end_idx])

            new_k_weight.append(self.k_proj.weight[:, start_idx:end_idx])

            new_k_bias.append(self.k_proj.bias[start_idx:end_idx])

            new_v_weight.append(self.v_proj.weight[:, start_idx:end_idx])
            new_v_bias.append(self.v_proj.bias[start_idx:end_idx])

            new_out_proj_weight.append(
                self.out_proj.weight[start_idx:end_idx, ])

        new_q_weight = paddle.concat(new_q_weight, axis=-1).detach()
        new_k_weight = paddle.concat(new_k_weight, axis=-1).detach()
        new_v_weight = paddle.concat(new_v_weight, axis=-1).detach()
        new_out_proj_weight = paddle.concat(new_out_proj_weight).detach()
        new_q_weight.stop_gradient = False
        new_k_weight.stop_gradient = False
        new_v_weight.stop_gradient = False
        new_out_proj_weight.stop_gradient = False

        new_q_bias = paddle.concat(new_q_bias).detach()
        new_q_bias.stop_gradient = False

        new_k_bias = paddle.concat(new_k_bias).detach()
        new_k_bias.stop_gradient = False

        new_v_bias = paddle.concat(new_v_bias).detach()
        new_v_bias.stop_gradient = False

        self.q_proj.weight = paddle.create_parameter(
            shape=new_q_weight.shape,
            dtype=new_q_weight.dtype,
            default_initializer=paddle.nn.initializer.Assign(new_q_weight))
        self.q_proj.bias = paddle.create_parameter(
            shape=new_q_bias.shape,
            dtype=new_q_bias.dtype,
            default_initializer=paddle.nn.initializer.Assign(new_q_bias))

        self.k_proj.weight = paddle.create_parameter(
            shape=new_k_weight.shape,
            dtype=new_k_weight.dtype,
            default_initializer=paddle.nn.initializer.Assign(new_k_weight))
        self.k_proj.bias = paddle.create_parameter(
            shape=new_k_bias.shape,
            dtype=new_k_bias.dtype,
            default_initializer=paddle.nn.initializer.Assign(new_k_bias))

        self.v_proj.weight = paddle.create_parameter(
            shape=new_v_weight.shape,
            dtype=new_v_weight.dtype,
            default_initializer=paddle.nn.initializer.Assign(new_v_weight))
        self.v_proj.bias = paddle.create_parameter(
            shape=new_v_bias.shape,
            dtype=new_v_bias.dtype,
            default_initializer=paddle.nn.initializer.Assign(new_v_bias))

        self.out_proj.weight = paddle.create_parameter(
            shape=new_out_proj_weight.shape,
            dtype=new_out_proj_weight.dtype,
            default_initializer=paddle.nn.initializer.Assign(
                new_out_proj_weight))

        self.num_heads = len(reserve_head_index)
        self.embed_dim = self.head_dim * self.num_heads
        self.q_proj.out_features = self.embed_dim
        self.k_proj.out_features = self.embed_dim
        self.v_proj.out_features = self.embed_dim

    def _set_skip_embed_dim_check(self):
        self.skip_embed_dim_check = True

    def _pad_masks(
            self,
            key_padding_mask: Optional[Tensor],
            attn_mask: Optional[Tensor],
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        if attn_mask is not None:
            shape = attn_mask.shape[:-1] + [
                1,
            ]
            attn_mask = paddle.concat(
                [attn_mask, paddle.zeros(shape, dtype=attn_mask.dtype)],
                axis=-1)
        if key_padding_mask is not None:
            shape = key_padding_mask.shape[:-1] + [
                1,
            ]
            key_padding_mask = paddle.concat(
                [
                    key_padding_mask, paddle.zeros(
                        shape, dtype=key_padding_mask.dtype)
                ],
                axis=-1)
        return key_padding_mask, attn_mask

    def _add_bias(
            self,
            k: Tensor,
            v: Tensor,
            key_padding_mask: Optional[Tensor],
            attn_mask: Optional[Tensor],
            bsz: int,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        assert self.bias_k is not None
        assert self.bias_v is not None
        k = paddle.concat([k, self.bias_k.tile([1, bsz, 1])], axis=-1)
        v = paddle.concat([v, self.bias_v.tile([1, bsz, 1])], axis=-1)
        key_padding_mask, attn_mask = self._pad_masks(
            key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return k, v, key_padding_mask, attn_mask

    def _append_zero_attn(
            self,
            k: Tensor,
            v: Tensor,
            key_padding_mask: Optional[Tensor],
            attn_mask: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        zero_attn_shape = k.shape[:-2] + [1] + k.shape[-1:]
        k = paddle.concat(
            [k, paddle.zeros(zero_attn_shape, dtype=k.dtype)], axis=-2)
        v = paddle.concat(
            [v, paddle.zeros(zero_attn_shape, dtype=v.dtype)], axis=-2)
        key_padding_mask, attn_mask = self._pad_masks(
            key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return k, v, key_padding_mask, attn_mask

    def forward(
            self,
            query,
            key: Optional[Tensor],
            value: Optional[Tensor],
            key_padding_mask: Optional[Tensor]=None,
            incremental_state: Optional[Dict[str, Dict[str, Optional[
                Tensor]]]]=None,
            need_weights: bool=True,
            static_kv: bool=False,
            attn_mask: Optional[Tensor]=None,
            before_softmax: bool=False,
            need_head_weights: bool=False, ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.place == "xla"

        tgt_len, bsz, embed_dim = query.shape
        src_len = tgt_len
        if not self.skip_embed_dim_check:
            assert (embed_dim == self.embed_dim
                    ), f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.shape) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.shape
            # if not torch.jit.is_scripting():
            #     assert value is not None
            #     assert src_len, key_bsz == value.shape[:2]

        # if (
        #         not self.onnx_trace
        #         and not is_tpu  # don't use PyTorch version on TPUs
        #         and incremental_state is None
        #         and not static_kv
        #         # A workaround for quantization to work. Otherwise JIT compilation
        #         # treats bias in linear module as method.
        #         and not torch.jit.is_scripting()
        #         # The Multihead attention implemented in pytorch forces strong dimension check
        #         # for input embedding dimension and K,Q,V projection dimension.
        #         # Since pruning will break the dimension check and it is not easy to modify the pytorch API,
        #         # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check
        #         and not self.skip_embed_dim_check
        # ):
        #     assert key is not None and value is not None

        #     if self.use_xformers:
        #         return self._xformers_attn_forward(
        #             query, key, value, key_padding_mask, need_weights, attn_mask
        #         )

        #     else:
        #         return F.multi_head_attention_forward(
        #             query,
        #             key,
        #             value,
        #             self.embed_dim,
        #             self.num_heads,
        #             torch.empty([0]),
        #             torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
        #             self.bias_k,
        #             self.bias_v,
        #             self.add_zero_attn,
        #             self.dropout_module.p,
        #             self.out_proj.weight,
        #             self.out_proj.bias,
        #             self.training or self.dropout_module.apply_during_inference,
        #             key_padding_mask,
        #             need_weights,
        #             attn_mask,
        #             use_separate_proj_weight=True,
        #             q_proj_weight=self.q_proj.weight,
        #             k_proj_weight=self.k_proj.weight,
        #             v_proj_weight=self.v_proj.weight,
        #         )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                if self.beam_size > 1 and bsz == key.size(1):
                    # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
                    key = key.view(
                        key.size(0), -1, self.beam_size,
                        key.size(2))[:, :, 0, :]
                    if key_padding_mask is not None:
                        key_padding_mask = key_padding_mask.view(
                            -1, self.beam_size,
                            key_padding_mask.size(1))[:, 0, :]
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k, v, attn_mask, key_padding_mask = self._add_bias(
                k, v, attn_mask, key_padding_mask, bsz)

        q = paddle.reshape(
            q, [tgt_len, bsz * self.num_heads, self.head_dim]).transpose(
                [1, 0, 2])
        kv_bsz = bsz  # need default value for scripting
        if k is not None:
            kv_bsz = k.shape[1]
            k = paddle.reshape(
                k, [-1, kv_bsz * self.num_heads, self.head_dim]).transpose(
                    [1, 0, 2])
        if v is not None:
            v = paddle.reshape(
                v, [-1, kv_bsz * self.num_heads, self.head_dim]).transpose(
                    [1, 0, 2])

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                kv_bsz = _prev_key.shape[0]
                prev_key = _prev_key.reshape(
                    [kv_bsz * self.num_heads, -1, self.head_dim])
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = paddle.concat([prev_key, k], axis=1)
                    src_len = k.shape[1]
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                assert kv_bsz == _prev_value.size(0)
                prev_value = _prev_value.reshape(
                    [kv_bsz * self.num_heads, -1, self.head_dim])
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = paddle.concat([prev_value, v], axis=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=kv_bsz,
                src_len=k.shape[1],
                static_kv=static_kv, )

            saved_state["prev_key"] = k.reshape(
                [kv_bsz, self.num_heads, -1, self.head_dim])
            saved_state["prev_value"] = v.reshape(
                [kv_bsz, self.num_heads, -1, self.head_dim])
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        assert k.shape[1] == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.shape[0] == kv_bsz
            assert key_padding_mask.shape[1] == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k, v, key_padding_mask, attn_mask = self._append_zero_attn(
                k=k,
                v=v,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask)

        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn_weights = paddle.einsum(
                "bxhtd,bhsd->bxhts",
                q.reshape([kv_bsz, -1, self.num_heads] + q.shape[1:]),
                k.reshape([kv_bsz, self.num_heads] + k.shape[1:]), )
            attn_weights = attn_weights.reshape([
                -1,
            ] + attn_weights.shape[-2:])
        else:
            attn_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.tile([attn_weights.shape[0], 1, 1])
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len])
            if not is_tpu:
                attn_weights = attn_weights.reshape(
                    [kv_bsz, -1, self.num_heads, tgt_len, src_len])
                attn_weights = paddle.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
                    .astype('bool'),
                    float('-inf') * paddle.ones_like(attn_weights),
                    attn_weights)
            else:
                attn_weights = attn_weights.transpose([2, 1, 0])
                attn_weights = paddle.where(key_padding_mask,
                                            float('-inf') *
                                            paddle.ones_like(attn_weights),
                                            attn_weights)
                attn_weights = attn_weights.transpose([2, 1, 0])
            attn_weights = attn_weights.reshape(
                [bsz * self.num_heads, tgt_len, src_len])

        if before_softmax:
            return attn_weights, v

        def softmax_supporting_onnx_trace(x, dim: int, onnx_trace: bool=False):
            if onnx_trace:
                return F.softmax(x, axis=dim)
            else:
                return F.softmax(x, axis=dim, dtype='float32')

        attn_weights_float = softmax_supporting_onnx_trace(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = paddle.cast(attn_weights_float, attn_weights.dtype)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn = paddle.einsum(
                "bxhts,bhsd->bxhtd",
                attn_probs.reshape([kv_bsz, -1, self.num_heads] +
                                   attn_probs.shape[1:]),
                v.reshape([kv_bsz, self.num_heads] + v.shape[1:]), )
            attn = attn.reshape([
                -1,
            ] + attn.shape[-2:])
        else:
            attn = paddle.bmm(attn_probs, v)
        assert list(
            attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.shape[1] == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.reshape([tgt_len, bsz, self.embed_dim])
        else:
            attn = attn.transpose([1, 0, 2]).reshape(
                [tgt_len, bsz, self.embed_dim])
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.reshape(
                [bsz, self.num_heads, tgt_len, src_len]).transpose([1, 0, 2, 3])
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(axis=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
            key_padding_mask: Optional[Tensor],
            prev_key_padding_mask: Optional[Tensor],
            batch_size: int,
            src_len: int,
            static_kv: bool, ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = paddle.concat(
                [
                    paddle.cast(prev_key_padding_mask, 'float32'),
                    paddle.cast(key_padding_mask, 'float32')
                ],
                axis=1)
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.shape[1]:
                filler = paddle.zeros(
                    [batch_size, src_len - prev_key_padding_mask.shape[1]], )
                new_key_padding_mask = paddle.concat(
                    [
                        paddle.cast(prev_key_padding_mask, 'float32'),
                        paddle.cast(filler, 'float32')
                    ],
                    axis=1)
            else:
                new_key_padding_mask = prev_key_padding_mask
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.shape[1]:
                filler = paddle.zeros(
                    [batch_size, src_len - key_padding_mask.shape[1]], )
                new_key_padding_mask = paddle.concat(
                    [
                        paddle.cast(filler, 'float32'),
                        paddle.cast(key_padding_mask, 'float32')
                    ],
                    axis=1)
            else:
                new_key_padding_mask = paddle.cast(key_padding_mask, 'float32')
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @paddle.jit.to_static
    def reorder_incremental_state(
            self,
            incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
            new_order: Tensor, ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention:
                        if input_buffer_k.shape[
                                0] * self.beam_size == new_order.shape[0]:
                            return incremental_state
                        elif self.beam_size > 1:
                            input_buffer[k] = paddle.index_select(
                                input_buffer_k,
                                index=new_order.reshape(
                                    [-1, self.beam_size])[:, 0] //
                                self.beam_size,
                                axis=0, )
                        else:
                            input_buffer[k] = paddle.index_select(
                                input_buffer_k, index=new_order, axis=0)
                    else:
                        input_buffer[k] = paddle.index_select(
                            input_buffer_k, index=new_order, axis=0)
            incremental_state = self._set_input_buffer(incremental_state,
                                                       input_buffer)
        return incremental_state

    def set_beam_size(self, beam_size):
        """Used for efficient beamable enc-dec attention"""
        self.beam_size = beam_size

    def _get_input_buffer(
            self,
            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
            self,
            incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
            buffer: Dict[str, Optional[Tensor]], ):
        return self.set_incremental_state(incremental_state, "attn_state",
                                          buffer)

    def apply_sparse_mask(self,
                          attn_weights,
                          tgt_len: int,
                          src_len: int,
                          bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix +
                             "k_proj.weight"] = state_dict[k][dim:2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix +
                                 "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                        dim:2 * dim]
                    items_to_add[prefix +
                                 "v_proj.bias"] = state_dict[k_bias][2 * dim:]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
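

# Illustrative sketch (assumption, not in the original file): the attention
# module expects (time, batch, channel) inputs; for self-attention the same
# tensor is passed as query, key and value.  Shapes below are hypothetical.
def _multihead_attention_example():
    mha = MultiheadAttention(embed_dim=64, num_heads=4, self_attention=True)
    x = paddle.randn([10, 2, 64])  # (tgt_len, batch, embed_dim)
    out, weights = mha(query=x, key=x, value=x)
    # out: [10, 2, 64]; weights: [2, 10, 10] (averaged over the 4 heads)
    return out, weights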


class GumbelVectorQuantizer(nn.Layer):
    def __init__(
            self,
            dim,
            num_vars,
            temp,
            groups,
            combine_groups,
            vq_dim,
            time_first,
            activation=nn.GELU(),
            weight_proj_depth=1,
            weight_proj_factor=1, ):
        """Vector quantization using gumbel softmax

        Args:
            dim: input dimension (channels)
            num_vars: number of quantized vectors per group
            temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor)
            groups: number of groups for vector quantization
            combine_groups: whether to use the vectors for all groups
            vq_dim: dimensionality of the resulting quantized vector
            time_first: if true, expect input in BxTxC format, otherwise in BxCxT
            activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1
            weight_proj_depth: number of layers (with activation in between) to project input before computing logits
            weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of
                projections by this factor
        """
        super().__init__()

        self.groups = groups
        self.combine_groups = combine_groups
        self.input_dim = dim
        self.num_vars = num_vars
        self.time_first = time_first

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        var_dim = vq_dim // groups
        num_groups = groups if not combine_groups else 1

        self.vars = self.create_parameter(
            (1, num_groups * num_vars, var_dim),
            default_initializer=nn.initializer.Uniform())

        if weight_proj_depth > 1:

            def block(input_dim, output_dim):
                return nn.Sequential(Linear(input_dim, output_dim), activation)

            inner_dim = self.input_dim * weight_proj_factor
            self.weight_proj = nn.Sequential(
                *[
                    block(self.input_dim if i == 0 else inner_dim, inner_dim)
                    for i in range(weight_proj_depth - 1)
                ],
                Linear(inner_dim, groups * num_vars), )
        else:
            self.weight_proj = Linear(
                self.input_dim,
                groups * num_vars,
                weight_attr=nn.initializer.Normal(mean=0, std=1),
                bias_attr=nn.initializer.Zero())

        if isinstance(temp, str):
            import ast

            temp = ast.literal_eval(temp)
        assert len(temp) == 3, f"{temp}, {len(temp)}"

        self.max_temp, self.min_temp, self.temp_decay = temp
        self.curr_temp = self.max_temp
        self.codebook_indices = None

    def set_num_updates(self, num_updates):
        self.curr_temp = max(self.max_temp * self.temp_decay**num_updates,
                             self.min_temp)

    def get_codebook_indices(self):
        if self.codebook_indices is None:
            from itertools import product

            p = [range(self.num_vars)] * self.groups
            inds = list(product(*p))
            self.codebook_indices = paddle.to_tensor(
                inds, dtype='int64', place=self.vars.place).flatten()

            if not self.combine_groups:
                self.codebook_indices = self.codebook_indices.reshape(
                    self.num_vars**self.groups, -1)
                for b in range(1, self.groups):
                    self.codebook_indices[:, b] += self.num_vars * b
                self.codebook_indices = self.codebook_indices.flatten()
        return self.codebook_indices

    def codebook(self):
        indices = self.get_codebook_indices()
        return (self.vars.squeeze(0).index_select(0, indices)
                .reshape(self.num_vars**self.groups, -1))

    def sample_from_codebook(self, b, n):
        indices = self.get_codebook_indices()
        indices = indices.reshape(-1, self.groups)
        cb_size = indices.shape[0]
        assert (n < cb_size
                ), f"sample size {n} is greater than size of codebook {cb_size}"
        sample_idx = paddle.randint(low=0, high=cb_size, shape=(b * n, ))
        indices = indices[sample_idx]

        z = self.vars.squeeze(0).index_select(0, indices.flatten()).reshape(
            b, n, -1)
        return z

    def to_codebook_index(self, indices):
        res = paddle.full(indices.shape[:-1], 0, dtype=indices.dtype)
        for i in range(self.groups):
            exponent = self.groups - i - 1
            res += indices[..., i] * (self.num_vars**exponent)
        return res

    def forward_idx(self, x):
        res = self.forward(x, produce_targets=True)
        return res["x"], res["targets"]

    def forward(self, x, produce_targets=False):
        result = {"num_vars": self.num_vars * self.groups}

        if not self.time_first:
            x = x.transpose([0, 2, 1])

        bsz, tsz, fsz = x.shape
        x = x.reshape([-1, fsz])
        x = self.weight_proj(x)
        x = x.reshape([bsz * tsz * self.groups, -1])

        _, k = x.max(-1)
        hard_x = paddle.zeros_like(x)
        hard_x.scatter_(-1, k.reshape([-1, 1]), 1.0)
        hard_x = hard_x.reshape([bsz * tsz, self.groups, -1])
        hard_probs = paddle.mean(hard_x.astype('float32'), axis=0)
        result["code_perplexity"] = paddle.exp(-paddle.sum(
            hard_probs * paddle.log(hard_probs + 1e-7), axis=-1)).sum()

        avg_probs = F.softmax(
            x.reshape([bsz * tsz, self.groups, -1]).astype('float32'),
            axis=-1).mean(axis=0)
        result["prob_perplexity"] = paddle.exp(-paddle.sum(
            avg_probs * paddle.log(avg_probs + 1e-7), axis=-1)).sum()

        result["temp"] = self.curr_temp

        if self.training:
            x = F.gumbel_softmax(
                x.astype('float32'), temperature=self.curr_temp,
                hard=True).astype(x.dtype)
        else:
            x = hard_x

        x = x.reshape([bsz * tsz, -1])

        vars = self.vars
        if self.combine_groups:
            vars = vars.tile([1, self.groups, 1])

        if produce_targets:
            result["targets"] = (x.reshape([bsz * tsz * self.groups, -1])
                                 .argmax(axis=-1)
                                 .reshape([bsz, tsz, self.groups]).detach())

        x = x.unsqueeze(-1) * vars
        x = x.reshape([bsz * tsz, self.groups, self.num_vars, -1])
        x = x.sum(axis=-2)
        x = x.reshape([bsz, tsz, -1])

        if not self.time_first:
            x = x.transpose([0, 2, 1])

        result["x"] = x

        return result
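

# Illustrative sketch (assumption, not in the original file): the quantizer's
# Gumbel-softmax temperature follows max(start * decay**step, stop), so with
# temp=(2.0, 0.5, 0.999) it decays from 2.0 towards the floor of 0.5.
def _gumbel_quantizer_temperature_example():
    quantizer = GumbelVectorQuantizer(
        dim=16,
        num_vars=32,
        temp=(2.0, 0.5, 0.999),
        groups=2,
        combine_groups=False,
        vq_dim=16,
        time_first=True)
    quantizer.set_num_updates(1000)
    return quantizer.curr_temp  # 2.0 * 0.999**1000 ~= 0.74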


class GradMultiply(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.numpy().copy()
        return paddle.to_tensor(res, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


class SamePad(nn.Layer):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, :-self.remove]
        return x


class TransposeLast(nn.Layer):
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        trans_dim = paddle.arange(x.dim())
        trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
        return x.transpose(trans_dim)


class Fp32LayerNorm(LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.astype('float32'),
            self._normalized_shape,
            self.weight.astype('float32') if self.weight is not None else None,
            self.bias.astype('float32') if self.bias is not None else None,
            self._epsilon, )
        return output.astype(input.dtype)


# Todo: change this when paddle supports F.group_norm
class Fp32GroupNorm(nn.Layer):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.group_norm = paddle.nn.GroupNorm(*args, **kwargs)
        fp32_weight = paddle.create_parameter(
            shape=self.group_norm.weight.shape,
            dtype='float32',
            default_initializer=paddle.nn.initializer.Assign(
                self.group_norm.weight))
        fp32_bias = paddle.create_parameter(
            shape=self.group_norm.bias.shape,
            dtype='float32',
            default_initializer=paddle.nn.initializer.Assign(
                self.group_norm.bias))
        self.group_norm.weight = fp32_weight
        self.group_norm.bias = fp32_bias

    def forward(self, input):
        output = self.group_norm(input.astype('float32'))
        return output.astype(input.dtype)


class StrEnumMeta(EnumMeta):
    # this is a workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see
    # https://github.com/facebookresearch/hydra/issues/1156
    @classmethod
    def __instancecheck__(cls, other):
        return "enum" in str(type(other))


class StrEnum(Enum, metaclass=StrEnumMeta):
    def __str__(self):
        return self.value

    def __eq__(self, other: str):
        return self.value == other

    def __repr__(self):
        return self.value

    def __hash__(self):
        return hash(str(self))


def ChoiceEnum(choices: List[str]):
    """return the Enum class used to enforce list of choices"""
    return StrEnum("Choices", {k: k for k in choices})


def relu_squared(x: paddle.Tensor):
    return F.relu(x).pow(2)


def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`"""

    def gelu_accurate(x):
        if not hasattr(gelu_accurate, "_a"):
            gelu_accurate._a = math.sqrt(2 / math.pi)
        return (0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
                                           (x + 0.044715 * paddle.pow(x, 3)))))

    def gelu(x: paddle.Tensor) -> paddle.Tensor:
        return paddle.nn.functional.gelu(x.astype('float32')).astype(x.dtype)

    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return paddle.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        return paddle.nn.Swish
    else:
        raise RuntimeError(
            "--activation-fn {} not supported".format(activation))


def get_available_activation_fns() -> List:
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated
        "gelu_accurate",
        "tanh",
        "linear",
    ]


def compute_mask_indices(
        shape: Tuple[int, int],
        padding_mask: Optional[paddle.Tensor],
        mask_prob: float,
        mask_length: int,
        mask_type: str="static",
        mask_other: float=0.0,
        min_masks: int=0,
        no_overlap: bool=False,
        min_space: int=0,
        require_same_masks: bool=True,
        mask_dropout: float=0.0, ) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length) + np.random.rand())

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length) + np.random.rand())
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(
                mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0
                     for s, e in parts),
                    int, )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray([
                mask_idc[j] + offset
                for j in range(len(mask_idc)) for offset in range(lengths[j])
            ])

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False)

        mask[i, mask_idc] = True

    return mask
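

# Illustrative sketch (assumption, not in the original file): with hypothetical
# sizes, roughly mask_prob of the timesteps end up inside spans of mask_length
# consecutive masked frames (overlaps make the true ratio a bit smaller).
def _compute_mask_indices_example():
    mask = compute_mask_indices(
        shape=(2, 100),  # (batch, timesteps)
        padding_mask=None,
        mask_prob=0.65,
        mask_length=10)
    return mask          # np.ndarray of bools, shape (2, 100)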
|
||
|
|
||
|
|
||
|
def index_put(tensor, indices, value):
|
||
|
tensor[indices] = value
|
||
|
return tensor
|
||
|
|
||
|
|
||
|
# ToDo if faster?
|
||
|
def buffered_arange(max):
|
||
|
if not hasattr(buffered_arange, "buf"):
|
||
|
buffered_arange.buf = paddle.empty([max], dtype='int64')
|
||
|
if max > buffered_arange.buf.numel():
|
||
|
buffered_arange.buf = paddle.arange(max)
|
||
|
return buffered_arange.buf[:max]
|
||
|
|
||
|
|
||
|
def pad_to_multiple(x, multiple, dim=-1, value=0):
|
||
|
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
|
||
|
if x is None:
|
||
|
return None, 0
|
||
|
tsz = x.shape[dim]
|
||
|
m = tsz / multiple
|
||
|
remainder = math.ceil(m) * multiple - tsz
|
||
|
if m.is_integer():
|
||
|
return x, 0
|
||
|
pad_offset = (0, ) * (-1 - dim) * 2
|
||
|
return F.pad(
|
||
|
x,
|
||
|
pad=[*pad_offset, 0, remainder, *pad_offset],
|
||
|
value=value,
|
||
|
data_format='NLC'), remainder
|
||
|
|
||
|
|
||
|
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
|
||
|
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
|
||
|
["static", "uniform", "normal", "poisson"])
|
||
|
LAYER_TYPE_CHOICES = ChoiceEnum(["transformer"]) # ToDo: conformer
|
||
|
|
||
|
|
||
|
@dataclass
class Wav2Vec2Config:
    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help":
            "mode for feature extractor. default has a single group norm with d "
            "groups in the first conv block, whereas layer_norm has layer norms in "
            "every block (meant to use with normalize=True)"
        }, )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"})
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"})
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"})
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"})
    activation_fn: ChoiceEnum(get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"})
    layer_type: LAYER_TYPE_CHOICES = field(
        default="transformer", metadata={"help": "layer type in encoder"})
    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"})
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"})
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"})
    encoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a transformer layer"})
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"}, )
    dropout_features: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    final_dim: int = field(
        default=0,
        metadata={
            "help":
            "project final representations and targets to this many dimensions. "
            "set to encoder_embed_dim if <= 0"
        }, )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"})
    conv_feature_layers: str = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help":
            "string describing convolutional feature extraction layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        }, )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"})
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"})
    quantize_targets: bool = field(
        default=False, metadata={"help": "use quantized targets"})
    quantize_input: bool = field(
        default=False, metadata={"help": "use quantized inputs"})
    same_quantizer: bool = field(
        default=False,
        metadata={"help": "use same quantizer for inputs and targets"})
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"})
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"})
    quantizer_depth: int = field(
        default=1,
        metadata={"help": "number of quantizer layers"}, )
    quantizer_factor: int = field(
        default=3,
        metadata={
            "help":
            "dimensionality increase for inner quantizer layers (if depth > 1)"
        }, )
    latent_vars: int = field(
        default=320,
        metadata={
            "help": "number of latent variables V in each group of the codebook"
        }, )
    latent_groups: int = field(
        default=2,
        metadata={
            "help": "number of groups G of latent variables in the codebook"
        }, )
    latent_dim: int = field(
        default=0,
        metadata={
            "help":
            "if > 0, uses this dimensionality for latent variables. "
            "otherwise uses final_dim / latent_groups"
        }, )

    # masking
    mask_length: int = field(default=10, metadata={"help": "mask length"})
    mask_prob: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"})
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"})
    mask_other: float = field(
        default=0,
        metadata={
            "help":
            "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        }, )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"})
    mask_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )
    require_same_masks: bool = field(
        default=True,
        metadata={
            "help":
            "whether the number of masked timesteps must be the same across all "
            "examples in a batch"
        }, )
    mask_dropout: float = field(
        default=0.0,
        metadata={"help": "percent of masks to unmask for each sample"}, )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"})
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"})
    mask_channel_before: bool = False
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"}, )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help":
            "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        }, )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"})
    mask_channel_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # negative selection
    num_negatives: int = field(
        default=100,
        metadata={"help": "number of negative examples from the same sample"}, )
    negatives_from_everywhere: bool = field(
        default=False,
        metadata={
            "help": "sample negatives from everywhere, not just masked states"
        }, )
    cross_sample_negatives: int = field(
        default=0,
        metadata={"help": "number of negative examples from any sample"})
    codebook_negatives: int = field(
        default=0,
        metadata={"help": "number of negative examples from the codebook"})

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={
            "help": "number of filters for convolutional positional embeddings"
        }, )
    conv_pos_groups: int = field(
        default=16,
        metadata={
            "help": "number of groups for convolutional positional embedding"
        }, )
    pos_conv_depth: int = field(
        default=1,
        metadata={"help": "depth of positional encoder network"}, )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={
            "help":
            "temperature for latent variable sampling. "
            "can be tuple of 3 values (start, end, decay)"
        }, )
    max_positions: int = field(
        default=100000, metadata={"help": "Max positions"})
    checkpoint_activations: bool = field(
        default=False,
        metadata={
            "help": "recompute activations and save memory for extra compute"
        }, )

    # FP16 optimization
    required_seq_len_multiple: int = field(
        default=2,
        metadata={
            "help":
            "pad the input to encoder such that the sequence length is divisible by multiple"
        }, )
    crop_seq_to_multiple: int = field(
        default=1,
        metadata={
            "help":
            "crop convolutional feature extractor output such that the sequence length is divisible by multiple"
        }, )

    # Conformer
    depthwise_conv_kernel_size: int = field(
        default=31,
        metadata={
            "help":
            "depthwise-conv-kernel-size for convolution in conformer layer"
        }, )
    attn_type: str = field(
        default="",
        metadata={"help": "if 'espnet', use the ESPnet multi-head attention"}, )
    pos_enc_type: str = field(
        default="abs",
        metadata={"help": "Positional encoding type to use in conformer"}, )
    fp16: bool = field(
        default=False, metadata={"help": "If fp16 is being used"})

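# Illustrative sketch (not part of the original API): `conv_feature_layers` is a
# string that the model eval()s into (dim, kernel_size, stride) tuples. With the
# default spec the strides multiply to 5 * 2**6 = 320, i.e. 16 kHz waveforms are
# downsampled to roughly 50 feature frames per second.
def _conv_feature_layers_demo():  # documentation-only sketch
    cfg = Wav2Vec2Config()
    layers = eval(cfg.conv_feature_layers)
    total_stride = 1
    for _dim, _kernel, stride in layers:
        total_stride *= stride
    assert layers[-1][0] == 512 and total_stride == 320
    return layers, total_stride
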
class Wav2Vec2Model(nn.Layer):
    def __init__(self, cfg: Wav2Vec2Config):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias, )

        self.post_extract_proj = (Linear(self.embed, cfg.encoder_embed_dim)
                                  if self.embed != cfg.encoder_embed_dim and
                                  not cfg.quantize_input else None)

        self.crop_seq_to_multiple = cfg.crop_seq_to_multiple

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=cfg.quantizer_depth,
                weight_proj_factor=cfg.quantizer_factor, )
            self.project_q = Linear(vq_dim, final_dim)
        else:
            self.project_q = Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                    weight_proj_depth=cfg.quantizer_depth,
                    weight_proj_factor=cfg.quantizer_factor, )
            self.project_inp = Linear(vq_dim, cfg.encoder_embed_dim)

        self.mask_emb = self.create_parameter(
            shape=[cfg.encoder_embed_dim],
            default_initializer=paddle.nn.initializer.Uniform(),
            dtype='float32', )

        encoder_cls = TransformerEncoder

        self.encoder = encoder_cls(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                Linear(final_dim, final_dim * 2), GLU())

        self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        # nn.Layer has no parent implementation of this hook, so there is
        # nothing to delegate to; the state dict is returned unchanged.
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Config, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(
            self,
            x,
            padding_mask,
            mask_indices=None,
            mask_channel_indices=None, ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space, )
            mask_channel_indices = (
                paddle.to_tensor(mask_channel_indices, place=x.place)
                .unsqueeze(1).expand([-1, T, -1]))
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=2,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout, )
                mask_indices = paddle.to_tensor(mask_indices, place=x.place)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space, )
                mask_channel_indices = (
                    paddle.to_tensor(mask_channel_indices, place=x.place)
                    .unsqueeze(1).expand([-1, T, -1]))
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def sample_negatives(self, y, num, padding_count=None):

        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return paddle.empty([0], dtype=y.dtype)

        bsz, tsz, fsz = y.shape
        y = y.reshape([-1, fsz])  # BTC => (BxT)C

        # FIXME: what happens if padding_count is specified?
        cross_high = tsz * bsz
        high = tsz - (padding_count or 0)
        with paddle.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1)
                        .expand([-1, self.n_negatives]).flatten())

                neg_idxs = paddle.randint(
                    low=0, high=high - 1, shape=[bsz, self.n_negatives * num])
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1)
                        .expand([-1, self.cross_sample_negatives]).flatten())

                cross_neg_idxs = paddle.randint(
                    low=0,
                    high=cross_high - 1,
                    shape=[bsz, self.cross_sample_negatives * num], )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            neg_idxs = neg_idxs + (paddle.arange(bsz).unsqueeze(1) * high)
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = paddle.concat([neg_idxs, cross_neg_idxs], axis=1)

        negs = y[neg_idxs.reshape([-1])]
        negs = negs.reshape(
            [bsz, num, self.n_negatives + self.cross_sample_negatives,
             fsz]).transpose([2, 0, 1, 3])  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = paddle.concat([y, negatives], axis=0)

        logits = paddle.nn.functional.cosine_similarity(x, targets, axis=-1)
        logits = logits / self.logit_temp
        logits = logits.astype(x.dtype)

        return logits

    def _get_feat_extract_output_lengths(self, input_lengths: paddle.Tensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return paddle.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(input_lengths, conv_cfg_list[i][1],
                                             conv_cfg_list[i][2])

        return paddle.cast(input_lengths, 'int64')

    def forward(
            self,
            source,
            padding_mask=None,
            mask=True,
            features_only=False,
            layer=None,
            mask_indices=None,
            mask_channel_indices=None,
            padding_count=None, ):

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with paddle.no_grad():
                features = self.feature_extractor(source)

        features_pen = features.pow(2).mean()

        features = features.transpose([0, 2, 1])
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - paddle.cast(padding_mask, 'int64')).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(
                input_lengths)

            padding_mask = paddle.zeros(
                features.shape[:2], dtype=features.dtype)

            # these two operations make sure that all values
            # before the output length indices are attended to
            padding_mask[(paddle.arange(padding_mask.shape[0]),
                          output_lengths - 1, )] = 1
            padding_mask = paddle.cast(
                (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])), 'bool')
        else:
            padding_mask = None

        time_steps_to_drop = features.shape[1] % self.crop_seq_to_multiple
        if time_steps_to_drop != 0:
            features = features[:, :-time_steps_to_drop]
            unmasked_features = unmasked_features[:, :-time_steps_to_drop]
            if padding_mask is not None:
                padding_mask = padding_mask[:, :-time_steps_to_drop]

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices, )
            if mask_indices is not None:
                y = unmasked_features[mask_indices].reshape([
                    unmasked_features.shape[0], -1, unmasked_features.shape[-1]
                ])
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x, layer_results = self.encoder(
            x, padding_mask=padding_mask, layer=layer)

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "features": unmasked_features,
                "layer_results": layer_results,
            }

        if self.quantizer:
            if self.negatives_from_everywhere:
                q = self.quantizer(unmasked_features, produce_targets=False)
                y = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]
                y = self.project_q(y)

                negs, _ = self.sample_negatives(
                    y,
                    mask_indices[0].sum(),
                    padding_count=padding_count, )
                y = y[mask_indices].reshape([y.shape[0], -1, y.shape[-1]])

            else:
                q = self.quantizer(y, produce_targets=False)
                y = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]

                y = self.project_q(y)

                negs, _ = self.sample_negatives(
                    y,
                    y.shape[1],
                    padding_count=padding_count, )

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.shape[0] * y.shape[1], self.codebook_negatives)
                cb_negs = cb_negs.reshape(
                    [self.codebook_negatives, y.shape[0], y.shape[1],
                     -1])  # order doesn't matter
                cb_negs = self.project_q(cb_negs)
                negs = paddle.concat([negs, cb_negs], axis=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(
                    unmasked_features,
                    y.shape[1],
                    padding_count=padding_count, )
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(
                    y,
                    y.shape[1],
                    padding_count=padding_count, )

        x = x[mask_indices].reshape([x.shape[0], -1, x.shape[-1]])

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose([0, 2, 1])
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False, layer=None):
        res = self.forward(
            source, padding_mask, mask=mask, features_only=True, layer=layer)
        return res

    def get_logits(self, net_output):
        logits = net_output["x"]
        logits = logits.transpose([2, 1, 0])
        logits = logits.reshape([-1, logits.shape[-1]])
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        x = net_output["x"]
        return paddle.zeros([x.shape[1] * x.shape[2]], dtype='int64')

    def get_extra_losses(self, net_output):
        pen = []

        if "prob_perplexity" in net_output:
            pen.append((net_output["num_vars"] - net_output["prob_perplexity"])
                       / net_output["num_vars"])

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self, last_layer=None):
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None

        if last_layer is not None:
            self.encoder.layers = nn.LayerList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer)

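# Illustrative sketch (not part of the original API): minimal feature extraction
# with a small, unquantized configuration. One second of 16 kHz audio (16000
# samples) comes out as roughly 49 frames of size encoder_embed_dim.
def _wav2vec2_extract_features_demo():  # documentation-only sketch
    cfg = Wav2Vec2Config(encoder_layers=2)  # smaller than the 12-layer default
    model = Wav2Vec2Model(cfg)
    model.eval()
    wav = paddle.randn([2, 16000])  # (batch, samples)
    with paddle.no_grad():
        out = model.extract_features(wav, padding_mask=None, mask=False)
    # out["x"] has shape [batch, frames, encoder_embed_dim], frames ~= samples / 320
    return out["x"].shape
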
class ConvFeatureExtractionModel(nn.Layer):
    def __init__(
            self,
            conv_layers: List[Tuple[int, int, int]],
            dropout: float=0.0,
            mode: str="default",
            conv_bias: bool=False, ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
                n_in,
                n_out,
                k,
                stride,
                is_layer_norm=False,
                is_group_norm=False,
                conv_bias=False, ):
            def make_conv():
                conv = Conv1D(
                    n_in,
                    n_out,
                    k,
                    stride=stride,
                    bias_attr=conv_bias
                    if not conv_bias else paddle.ParamAttr())
                # nn.initializer.KaimingNormal()(conv.weight)
                return conv

            assert (is_layer_norm and is_group_norm
                    ) is False, "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim),
                        TransposeLast(), ),
                    nn.GELU(), )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim),
                    nn.GELU(), )
            else:
                return nn.Sequential(
                    make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.LayerList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias, ))
            in_d = dim

    def forward(self, x):

        # BxT -> BxCxT
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)

        return x

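# Illustrative sketch (not part of the original API): the extractor maps a raw
# waveform batch of shape (B, T) to features of shape (B, C, T'), where C is the
# last conv dim and T' is roughly T / 320 for the default layer specification.
def _conv_feature_extractor_demo():  # documentation-only sketch
    layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    extractor = ConvFeatureExtractionModel(conv_layers=layers)
    feats = extractor(paddle.randn([1, 16000]))
    assert feats.shape == [1, 512, 49]
    return feats.shape
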
def make_conv_pos(e, k, g):
    dropout = 0
    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    pos_conv = Conv1D(
        e,
        e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
        weight_attr=nn.initializer.Normal(mean=0, std=std),
        bias_attr=nn.initializer.Constant(0))
    pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
    pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())

    return pos_conv

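# Illustrative sketch (not part of the original API): the positional convolution
# pads by k // 2, and for an even kernel SamePad trims the single extra frame,
# so the sequence length of the (B, C, T) input is preserved.
def _make_conv_pos_demo():  # documentation-only sketch
    pos_conv = make_conv_pos(e=16, k=4, g=2)
    x = paddle.randn([1, 16, 50])  # (batch, channels, time)
    y = pos_conv(x)
    assert y.shape == [1, 16, 50]
    return y.shape
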
class TransformerEncoder(nn.Layer):
    def build_encoder_layer(self, args: Wav2Vec2Config):
        layer = TransformerSentenceEncoderLayer(
            embedding_dim=self.embedding_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=self.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            activation_fn=args.activation_fn,
            layer_norm_first=args.layer_norm_first, )
        return layer

    def __init__(self, args: Wav2Vec2Config):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple

        pos_conv_depth = getattr(args, "pos_conv_depth", 1)
        if pos_conv_depth > 1:
            num_layers = args.pos_conv_depth
            k = max(3, args.conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(*[
                    nn.Sequential(
                        Conv1D(
                            e,
                            e,
                            kernel_size=k,
                            padding=k // 2,
                            groups=g, ),
                        SamePad(k),
                        TransposeLast(),
                        LayerNorm(e, elementwise_affine=False),
                        TransposeLast(),
                        nn.GELU(), ) for _ in range(l)
                ])

            self.pos_conv = make_conv_block(self.embedding_dim, k,
                                            args.conv_pos_groups, num_layers)

        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                args.conv_pos,
                args.conv_pos_groups, )

        self.layers = nn.LayerList([
            self.build_encoder_layer(args) for _ in range(args.encoder_layers)
        ])
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)
        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
            self,
            x,
            padding_mask=None,
            tgt_layer=None,
            min_layer=0, ):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose([0, 2, 1]))
        x_conv = x_conv.transpose([0, 2, 1])
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0)
        if pad_length > 0 and padding_mask is None:
            padding_mask = paddle.zeros([x.shape[0], x.shape[1]], dtype='bool')
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask,
                self.required_seq_len_multiple,
                dim=-1,
                value=True)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose([1, 0, 2])

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(
                    x, self_attn_padding_mask=padding_mask, need_weights=False)
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose([1, 0, 2])

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (a[:-pad_length], b[:-pad_length]
                        if b is not None else b, c[:-pad_length], )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

class TransformerSentenceEncoderLayer(nn.Layer):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
            self,
            embedding_dim: float=768,
            ffn_embedding_dim: float=3072,
            num_attention_heads: int=8,
            dropout: float=0.1,
            attention_dropout: float=0.1,
            activation_dropout: float=0.1,
            activation_fn: str="relu",
            layer_norm_first: bool=False, ) -> None:

        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True, )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
            self,
            x: paddle.Tensor,
            self_attn_mask: paddle.Tensor=None,
            self_attn_padding_mask: paddle.Tensor=None,
            need_weights: bool=False,
            att_args=None, ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False, )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False, )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, (attn, layer_result)

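# Illustrative sketch (not part of the original API): a single encoder layer
# consumes (T, B, C) input and returns the output together with the attention
# weights (None when need_weights=False) and the pre-dropout FFN activations.
# layer_norm_first=True selects the pre-norm variant used when
# Wav2Vec2Config.layer_norm_first is set.
def _encoder_layer_demo():  # documentation-only sketch
    layer = TransformerSentenceEncoderLayer(
        embedding_dim=64, ffn_embedding_dim=128, num_attention_heads=4)
    x = paddle.randn([50, 2, 64])  # (time, batch, channels)
    y, (attn, layer_result) = layer(x)
    assert y.shape == [50, 2, 64] and layer_result.shape == [50, 2, 64]
    return y.shape
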
@dataclass
class AudioPretrainingConfig:
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help":
            "target sample rate. audio files will be up/down sampled to this rate"
        }, )
    normalize: bool = field(
        default=False,
        metadata={
            "help": "if set, normalizes input to have 0 mean and unit variance"
        }, )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"})
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to crop to for batching"})
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to skip small examples"})