@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 # S3PRL has no contribution to this file
 # The file was copied from fairseq to remove the dependency on the entire fairseq package
-import logging
 import math
 import uuid
 from dataclasses import dataclass
@@ -16,15 +15,19 @@ from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple

 import numpy as np
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import Tensor
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.modules.align import Conv1D
 from paddlespeech.s2t.modules.align import Conv2D
 from paddlespeech.s2t.modules.align import Embedding
 from paddlespeech.s2t.utils.log import Log

-logger = logging.getLogger(__name__)
+logger = Log(__name__).getlog()


 class GLU(nn.Layer):
     r"""Applies the gated linear unit function
@@ -153,15 +156,19 @@ def quant_noise(module, p, block_size):
         return module

     # supported modules
-    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2D))
+    assert isinstance(module, (Linear, Embedding, Conv2D))

     # test whether module.weight has the right sizes wrt block_size
     is_conv = len(module.weight.shape) == 4

     # 2D matrix
     if not is_conv:
+        if isinstance(module, Linear):
+            features_weight = module.weight.shape[0]
+        else:
+            features_weight = module.weight.shape[1]
         assert (
-            module.weight.shape[1] %
+            features_weight %
             block_size == 0), "Input features must be a multiple of block sizes"

     # 4D matrix
@@ -181,14 +188,20 @@ def quant_noise(module, p, block_size):
             if not is_conv:
                 # gather weight and sizes
                 weight = mod.weight
-                in_features = weight.shape[1]
-                out_features = weight.shape[0]
+                if isinstance(module, Linear):
+                    in_features = weight.shape[0]
+                    out_features = weight.shape[1]
+                else:
+                    in_features = weight.shape[1]
+                    out_features = weight.shape[0]

                 # split weight matrix into blocks and randomly drop selected blocks
                 mask = paddle.zeros(
                     [in_features // block_size * out_features],
                     dtype=paddle.bool)
-                mask.bernoulli_(p)
+                # the implementation of bernoulli_, p=0.5
+                mask = paddle.ones_like(mask) * 0.5
+                mask = paddle.bernoulli(mask)
                 mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                     [-1, in_features])
@@ -203,12 +216,18 @@ def quant_noise(module, p, block_size):
                     mask = paddle.zeros(
                         [in_channels // block_size * out_channels],
                         dtype=paddle.bool)
-                    mask.bernoulli_(p)
+                    # the implementation of bernoulli_, p=0.5
+                    mask = paddle.ones_like(mask) * 0.5
+                    mask = paddle.bernoulli(mask)
                     mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                         [-1, in_channels])
                 else:
                     mask = paddle.zeros(weight.shape)
-                    mask.bernoulli_(p)
+                    # the implementation of bernoulli_, p=0.5
+                    mask = paddle.ones_like(mask) * 0.5
+                    mask = paddle.bernoulli(mask)
                     mask = mask.unsqueeze(1).tile([1, in_channels, 1, 1])

             # scale weights and apply mask
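Note: Paddle has no in-place `Tensor.bernoulli_(p)`, so the quant_noise hunks above draw the block mask out-of-place with `paddle.bernoulli`. As committed, the replacement hard-codes a probability of 0.5 (`paddle.ones_like(mask) * 0.5`) rather than reusing the `p` argument. A minimal sketch of a general-purpose substitute (the helper name `bernoulli_like` is ours, not part of the patch):

    import paddle

    def bernoulli_like(x, p):
        # 1 with probability p, 0 otherwise, for every element of x's shape;
        # this is what torch's in-place x.bernoulli_(p) produces.
        probs = paddle.full(x.shape, p, dtype='float32')
        return paddle.bernoulli(probs)

    mask = bernoulli_like(paddle.zeros([16]), p=0.2)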
@@ -282,28 +301,52 @@ class MultiheadAttention(nn.Layer):
             "Self-attention requires query, key and "
             "value to be of the same size")

-        weight_attr = paddle.ParamAttr(initializer=nn.initializer.XavierUniform)
-        bias_attr = nn.initializer.Constant(0)
-        # self.k_proj = quant_noise(
-        # nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
-        # )
-        # self.v_proj = quant_noise(
-        # nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
-        # )
-        # self.q_proj = quant_noise(
-        # nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
-        # )
-        # self.out_proj = quant_noise(
-        # nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else bias_attr), q_noise, qn_block_size
-        # )
-        self.k_proj = nn.Linear(self.kdim, embed_dim)
-        self.v_proj = nn.Linear(self.vdim, embed_dim)
-        self.q_proj = nn.Linear(embed_dim, embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        # Todo scaled initialization
+        # Empirically observed the convergence to be much better with
+        # the scaled initialization
+        weight_attr = nn.initializer.XavierUniform()
+        kv_proj_bias_attr = nn.initializer.XavierUniform()
+        out_proj_bias_attr = nn.initializer.Constant(0)
+
+        self.k_proj = quant_noise(
+            nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size
+        )
+        self.v_proj = quant_noise(
+            nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size
+        )
+        self.q_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
+        )
+
+        self.out_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else out_proj_bias_attr), q_noise, qn_block_size
+        )
+
+        # nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
+        # nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
+        # nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
+        # else:
+        # self.k_proj.weight = paddle.ParamAttr()
+        # nn.initializer.XavierUniform(self.k_proj.weight)
+        # nn.initializer.XavierUniform(self.v_proj.weight)
+        # nn.initializer.XavierUniform(self.q_proj.weight)
+
+        # nn.initializer.XavierUniform(self.out_proj.weight)
+        # if self.out_proj.bias is not None:
+        # nn.initializer.Constant(self.out_proj.bias)
+        # if self.bias_k is not None:
+        # nn.initializer.XavierNormal(self.bias_k)
+        # if self.bias_v is not None:
+        # nn.initializer.XavierNormal(self.bias_v)
+
+        # self.k_proj = Linear(self.kdim, embed_dim)
+        # self.v_proj = Linear(self.vdim, embed_dim)
+        # self.q_proj = Linear(embed_dim, embed_dim)
+        # self.out_proj = Linear(embed_dim, embed_dim)

         if add_bias_kv:
             self.bias_k = paddle.create_parameter(
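Note: the rewritten constructor wires every projection through `quant_noise(...)` and passes initializers at construction time via `weight_attr`/`bias_attr` instead of touching `.weight` afterwards; Paddle accepts either a `paddle.ParamAttr` or a bare initializer object (the latter is wrapped into a ParamAttr). A minimal sketch of the same pattern, with arbitrary example sizes:

    import paddle
    import paddle.nn as nn

    # Xavier-uniform weight and zero bias, fixed when the layer is built.
    proj = nn.Linear(
        256, 256,
        weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierUniform()),
        bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)))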
@@ -327,26 +370,26 @@ class MultiheadAttention(nn.Layer):
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

-    # def reset_parameters(self):
-    # if self.qkv_same_dim:
-    # # Empirically observed the convergence to be much better with
-    # # the scaled initialization
-    # nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
-    # nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
-    # nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
-    # else:
-    # self.k_proj.weight = paddle.ParamAttr()
-    # nn.initializer.XavierUniform(self.k_proj.weight)
-    # nn.initializer.XavierUniform(self.v_proj.weight)
-    # nn.initializer.XavierUniform(self.q_proj.weight)
-
-    # nn.initializer.XavierUniform(self.out_proj.weight)
-    # if self.out_proj.bias is not None:
-    # nn.initializer.Constant(self.out_proj.bias)
-    # if self.bias_k is not None:
-    # nn.initializer.XavierNormal(self.bias_k)
-    # if self.bias_v is not None:
-    # nn.initializer.XavierNormal(self.bias_v)
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
+            nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
+            nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
+        else:
+            self.k_proj.weight = paddle.ParamAttr()
+            nn.initializer.XavierUniform(self.k_proj.weight)
+            nn.initializer.XavierUniform(self.v_proj.weight)
+            nn.initializer.XavierUniform(self.q_proj.weight)
+
+        nn.initializer.XavierUniform(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.initializer.Constant(self.out_proj.bias)
+        if self.bias_k is not None:
+            nn.initializer.XavierNormal(self.bias_k)
+        if self.bias_v is not None:
+            nn.initializer.XavierNormal(self.bias_v)

     def _get_reserve_head_index(self, num_heads_to_keep: int):
         k_proj_heads_norm = []
@@ -357,15 +400,15 @@ class MultiheadAttention(nn.Layer):
             start_idx = i * self.head_dim
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.k_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.k_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
             q_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.q_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.q_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
             v_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.v_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.v_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.v_proj.bias[start_idx:end_idx])).tolist())
@@ -395,24 +438,24 @@ class MultiheadAttention(nn.Layer):

         for ele in reserve_head_index:
             start_idx, end_idx = ele
-            new_q_weight.append(self.q_proj.weight[start_idx:end_idx, ])
+            new_q_weight.append(self.q_proj.weight[:, start_idx:end_idx])
             new_q_bias.append(self.q_proj.bias[start_idx:end_idx])

-            new_k_weight.append(self.k_proj.weight[start_idx:end_idx, ])
+            new_k_weight.append(self.k_proj.weight[:, start_idx:end_idx])

             new_k_bias.append(self.k_proj.bias[start_idx:end_idx])

-            new_v_weight.append(self.v_proj.weight[start_idx:end_idx, ])
+            new_v_weight.append(self.v_proj.weight[:, start_idx:end_idx])
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])

             new_out_proj_weight.append(
-                self.out_proj.weight[:, start_idx:end_idx])
+                self.out_proj.weight[start_idx:end_idx, ])

-        new_q_weight = paddle.concat(new_q_weight).detach()
-        new_k_weight = paddle.concat(new_k_weight).detach()
-        new_v_weight = paddle.concat(new_v_weight).detach()
+        new_q_weight = paddle.concat(new_q_weight, axis=-1).detach()
+        new_k_weight = paddle.concat(new_k_weight, axis=-1).detach()
+        new_v_weight = paddle.concat(new_v_weight, axis=-1).detach()
         new_out_proj_weight = paddle.concat(
-            new_out_proj_weight, axis=-1).detach()
+            new_out_proj_weight).detach()
         new_q_weight.stop_gradient = False
         new_k_weight.stop_gradient = False
         new_v_weight.stop_gradient = False
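Note: the head-slicing and concat changes in the two hunks above follow from the weight layout. `paddle.nn.Linear` stores `weight` as [in_features, out_features], the transpose of torch's [out_features, in_features], so one head's q/k/v projection lives in a column slice (and in a row slice of `out_proj.weight`), and kept q/k/v heads are re-joined along the last axis. A quick check:

    import paddle.nn as nn

    layer = nn.Linear(8, 16)      # in_features=8, out_features=16
    print(layer.weight.shape)     # [8, 16]
    head_dim = 4
    head0 = layer.weight[:, 0:head_dim]
    print(head0.shape)            # [8, 4], one head's projection columns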
@@ -566,11 +609,11 @@ class MultiheadAttention(nn.Layer):
         assert (embed_dim == self.embed_dim
                 ), f"query dim {embed_dim} != {self.embed_dim}"
         assert list(query.shape) == [tgt_len, bsz, embed_dim]
-        # if key is not None:
-        # src_len, key_bsz, _ = key.size()
-        # if not torch.jit.is_scripting():
-        # assert value is not None
-        # assert src_len, key_bsz == value.shape[:2]
+        if key is not None:
+            src_len, key_bsz, _ = key.shape
+            # if not torch.jit.is_scripting():
+            # assert value is not None
+            # assert src_len, key_bsz == value.shape[:2]

         # if (
         # not self.onnx_trace
@@ -848,7 +891,7 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = paddle.concat([
                 paddle.cast(prev_key_padding_mask, 'float32'),
                 paddle.cast(key_padding_mask, 'float32')
-            ], axis == 1)
+            ], axis=1)
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None
@@ -859,7 +902,7 @@ class MultiheadAttention(nn.Layer):
                 new_key_padding_mask = paddle.concat([
                     paddle.cast(prev_key_padding_mask, 'float32'),
                     paddle.cast(filler, 'float32')
-                ], axis == 1)
+                ], axis=1)
             else:
                 new_key_padding_mask = prev_key_padding_mask
         elif key_padding_mask is not None:
@@ -869,7 +912,7 @@ class MultiheadAttention(nn.Layer):
                 new_key_padding_mask = paddle.concat([
                     paddle.cast(filler, 'float32'),
                     paddle.cast(key_padding_mask, 'float32')
-                ], axis == 1)
+                ], axis=1)
             else:
                 new_key_padding_mask = paddle.cast(key_padding_mask, 'float32')
         else:
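Note: the three hunks above fix the same call-site slip: `axis == 1` (a comparison) was written where the keyword argument `axis=1` was intended. Corrected usage:

    import paddle

    a = paddle.zeros([2, 3])
    b = paddle.ones([2, 5])
    mask = paddle.concat([a, b], axis=1)   # concatenate along axis 1, result shape [2, 8]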
@@ -1022,7 +1065,7 @@ class GumbelVectorQuantizer(nn.Layer):

             def block(input_dim, output_dim):
                 return nn.Sequential(
-                    nn.Linear(input_dim, output_dim), activation)
+                    Linear(input_dim, output_dim), activation)

             inner_dim = self.input_dim * weight_proj_factor
             self.weight_proj = nn.Sequential(
@@ -1030,11 +1073,9 @@ class GumbelVectorQuantizer(nn.Layer):
                     block(self.input_dim if i == 0 else inner_dim, inner_dim)
                     for i in range(weight_proj_depth - 1)
                 ],
-                nn.Linear(inner_dim, groups * num_vars), )
+                Linear(inner_dim, groups * num_vars), )
         else:
-            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
-            nn.initializer.Normal(mean=0, std=1)(self.weight_proj.weight)
-            nn.initializer.Zero()(self.weight_proj.bias)
+            self.weight_proj = Linear(self.input_dim, groups * num_vars, weight_attr=nn.initializer.Normal(mean=0, std=1), bias_attr=nn.initializer.Zero())

         if isinstance(temp, str):
             import ast
@@ -1125,7 +1166,7 @@ class GumbelVectorQuantizer(nn.Layer):

         if self.training:
             x = F.gumbel_softmax(
-                x.astype('float32'), tau=self.curr_temp,
+                x.astype('float32'), temperature=self.curr_temp,
                 hard=True).astype(x.dtype)
         else:
             x = hard_x
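Note: Paddle's `paddle.nn.functional.gumbel_softmax` names the temperature argument `temperature`, while torch's `F.gumbel_softmax` calls it `tau`; `hard=True` keeps the straight-through one-hot behaviour. Example call:

    import paddle
    import paddle.nn.functional as F

    logits = paddle.randn([4, 320])
    one_hot = F.gumbel_softmax(logits, temperature=2.0, hard=True)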
@@ -1192,22 +1233,11 @@ class TransposeLast(nn.Layer):
         trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
         return x.transpose(trans_dim)


-def LayerNorm(normalized_shape, eps=1e-5):
-    return nn.LayerNorm(
-        normalized_shape,
-        epsilon=eps,
-        weight_attr=paddle.ParamAttr(),
-        bias_attr=paddle.ParamAttr())
-
-
-class Fp32LayerNorm(nn.LayerNorm):
+class Fp32LayerNorm(LayerNorm):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def forward(self, input):
-        # import pdb
-        # pdb.set_trace()
         output = F.layer_norm(
             input.astype('float32'),
             self._normalized_shape,
@@ -1222,8 +1252,6 @@ class Fp32GroupNorm(nn.GroupNorm):
         super().__init__(*args, **kwargs)

     def forward(self, input):
-        # import pdb
-        # pdb.set_trace()
         output = F.group_norm(
             input.astype('float32'),
             self._num_groups,
@@ -1724,7 +1752,7 @@ class Wav2Vec2Model(nn.Layer):
             mode=cfg.extractor_mode,
             conv_bias=cfg.conv_bias, )

-        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
+        self.post_extract_proj = (Linear(self.embed, cfg.encoder_embed_dim)
                                   if self.embed != cfg.encoder_embed_dim and
                                   not cfg.quantize_input else None)
@@ -1774,9 +1802,9 @@ class Wav2Vec2Model(nn.Layer):
                 time_first=True,
                 weight_proj_depth=cfg.quantizer_depth,
                 weight_proj_factor=cfg.quantizer_factor, )
-            self.project_q = nn.Linear(vq_dim, final_dim)
+            self.project_q = Linear(vq_dim, final_dim)
         else:
-            self.project_q = nn.Linear(self.embed, final_dim)
+            self.project_q = Linear(self.embed, final_dim)

         if cfg.quantize_input:
             if cfg.same_quantizer and self.quantizer is not None:
@@ -1794,7 +1822,7 @@ class Wav2Vec2Model(nn.Layer):
                 time_first=True,
                 weight_proj_depth=cfg.quantizer_depth,
                 weight_proj_factor=cfg.quantizer_factor, )
-            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)
+            self.project_inp = Linear(vq_dim, cfg.encoder_embed_dim)

         self.mask_emb = self.create_parameter(
             shape=[cfg.encoder_embed_dim],
@@ -1809,9 +1837,9 @@ class Wav2Vec2Model(nn.Layer):
         self.target_glu = None
         if cfg.target_glu:
             self.target_glu = nn.Sequential(
-                nn.Linear(final_dim, final_dim * 2), GLU())
+                Linear(final_dim, final_dim * 2), GLU())

-        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+        self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)

     def upgrade_state_dict_named(self, state_dict, name):
         super().upgrade_state_dict_named(state_dict, name)
@@ -2194,7 +2222,7 @@ class ConvFeatureExtractionModel(nn.Layer):
                   is_group_norm=False,
                   conv_bias=False, ):
             def make_conv():
-                conv = nn.Conv1D(
+                conv = Conv1D(
                     n_in,
                     n_out,
                     k,
@@ -2256,17 +2284,16 @@ class ConvFeatureExtractionModel(nn.Layer):


 def make_conv_pos(e, k, g):
-    pos_conv = nn.Conv1D(
+    dropout = 0
+    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
+    pos_conv = Conv1D(
         e,
         e,
         kernel_size=k,
         padding=k // 2,
-        groups=g, )
-    dropout = 0
-    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
-    nn.initializer.Normal(mean=0, std=std)(pos_conv.weight)
-    nn.initializer.Constant(0)(pos_conv.bias)
-
+        groups=g,
+        weight_attr=nn.initializer.Normal(mean=0, std=std),
+        bias_attr=nn.initializer.Constant(0))
     pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
     pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
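Note: `make_conv_pos` now computes `std` up front and hands the `Normal`/`Constant` initializers to the convolution through `weight_attr`/`bias_attr`, rather than calling initializer objects on already-built parameters. A minimal sketch of the same construction with example sizes (`SamePad` and the trailing `nn.GELU` wrapper from this file are omitted):

    import math
    import paddle.nn as nn

    e, k, g = 768, 128, 16              # embed dim, kernel size, groups (example values)
    std = math.sqrt(4 / (k * e))        # dropout fixed at 0, as in the patch
    pos_conv = nn.Conv1D(
        e, e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
        weight_attr=nn.initializer.Normal(mean=0, std=std),
        bias_attr=nn.initializer.Constant(0))
    # Conv1D weight is [out_channels, in_channels // groups, kernel_size];
    # dim=2 keeps one norm per kernel position, matching fairseq's dim=2.
    pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)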
@@ -2301,7 +2328,7 @@ class TransformerEncoder(nn.Layer):
             def make_conv_block(e, k, g, l):
                 return nn.Sequential(*[
                     nn.Sequential(
-                        nn.Conv1D(
+                        Conv1D(
                             e,
                             e,
                             kernel_size=k,
@@ -2454,8 +2481,8 @@ class TransformerSentenceEncoderLayer(nn.Layer):

         # layer norm associated with the self attention layer
         self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
-        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
-        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+        self.fc1 = Linear(self.embedding_dim, ffn_embedding_dim)
+        self.fc2 = Linear(ffn_embedding_dim, self.embedding_dim)

         # layer norm associated with the position wise feed-forward NN
         self.final_layer_norm = LayerNorm(self.embedding_dim)