@@ -6,40 +6,38 @@
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import logging
from typing import List, Optional, Tuple
import math
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import LayerNorm
from paddle import Tensor
from .modules.modules import (
    MultiheadAttention,
    SamePad,
    get_activation_fn,
    TransposeLast,
    GLU_Linear,
)
from paddle.nn import LayerNorm
from .modules.modules import get_activation_fn
from .modules.modules import GLU_Linear
from .modules.modules import MultiheadAttention
from .modules.modules import SamePad
from .modules.modules import TransposeLast

logger = logging.getLogger(__name__)
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> np.ndarray:
        shape: Tuple[int, int],
        padding_mask: Optional[Tensor],
        mask_prob: float,
        mask_length: int,
        mask_type: str = "static",
        mask_other: float = 0.0,
        min_masks: int = 0,
        no_overlap: bool = False,
        min_space: int = 0, ) -> np.ndarray:
    """
    Computes random mask spans for a given shape
@@ -65,9 +63,7 @@ def compute_mask_indices(
    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )
        mask_prob * all_sz / float(mask_length) + np.random.rand())
    all_num_mask = max(min_masks, all_num_mask)
@@ -77,9 +73,7 @@ def compute_mask_indices(
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
                mask_prob * sz / float(mask_length) + np.random.rand())
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
@@ -88,7 +82,8 @@ def compute_mask_indices(
        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
            lengths = np.random.randint(
                mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
@@ -119,9 +114,9 @@ def compute_mask_indices(
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    np.int,
                )
                    (e - s if e - s >= length + min_space else 0
                     for s, e in parts),
                    np.int_, )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
@@ -137,13 +132,10 @@ def compute_mask_indices(
            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )
            mask_idc = np.asarray([
                mask_idc[j] + offset
                for j in range(len(mask_idc)) for offset in range(lengths[j])
            ])
        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
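compute_mask_indices returns a boolean (B, T) numpy array marking which frames fall inside masked spans. A minimal, illustrative sketch of how it is typically driven (the batch size, frame count, and probabilities below are examples, not values fixed by this file):

# Hypothetical usage of compute_mask_indices; all numbers are illustrative.
B, T = 4, 100                      # batch size, number of feature frames
mask = compute_mask_indices(
    shape=(B, T),
    padding_mask=None,             # no padded frames in this toy example
    mask_prob=0.65,                # approximate fraction of frames to mask
    mask_length=10,                # each span covers 10 consecutive frames
    mask_type="static",
    min_masks=2,
)
assert mask.shape == (B, T)        # boolean array; True = frame is masked
print("masked frames per utterance:", mask.sum(axis=1))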
@@ -158,54 +150,54 @@ def compute_mask_indices(
class WavLMConfig:
    def __init__(self, cfg=None):
        self.extractor_mode: str = "default"  # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.extractor_mode: str = "default"  # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"  # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
        self.conv_bias: bool = False  # include bias in conv encoder
        self.feature_grad_mult: float = 1.0  # multiply feature extractor var grads by this
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"  # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
        self.conv_bias: bool = False  # include bias in conv encoder
        self.feature_grad_mult: float = 1.0  # multiply feature extractor var grads by this
        self.normalize: bool = False  # normalize input to have 0 mean and unit variance during training
        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
        self.dropout_features: float = 0.0  # dropout to apply to the features (after feat extr)
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
        self.dropout_features: float = 0.0  # dropout to apply to the features (after feat extr)
        # masking
        self.mask_length: int = 10  # mask length
        self.mask_prob: float = 0.65  # probability of replacing a token with mask
        self.mask_selection: str = "static"  # how to choose mask length
        self.mask_other: float = 0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        self.no_mask_overlap: bool = False  # whether to allow masks to overlap
        self.mask_min_space: int = 1  # min space between spans (if no overlap is enabled)
        self.mask_length: int = 10  # mask length
        self.mask_prob: float = 0.65  # probability of replacing a token with mask
        self.mask_selection: str = "static"  # how to choose mask length
        self.mask_other: float = 0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        self.no_mask_overlap: bool = False  # whether to allow masks to overlap
        self.mask_min_space: int = 1  # min space between spans (if no overlap is enabled)
        # channel masking
        self.mask_channel_length: int = 10  # length of the mask for features (channels)
        self.mask_channel_prob: float = 0.0  # probability of replacing a feature with 0
        self.mask_channel_selection: str = "static"  # how to choose mask length for channel masking
        self.mask_channel_other: float = 0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        self.no_mask_channel_overlap: bool = False  # whether to allow channel masks to overlap
        self.mask_channel_min_space: int = 1  # min space between spans (if no overlap is enabled)
        self.mask_channel_length: int = 10  # length of the mask for features (channels)
        self.mask_channel_prob: float = 0.0  # probability of replacing a feature with 0
        self.mask_channel_selection: str = "static"  # how to choose mask length for channel masking
        self.mask_channel_other: float = 0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        self.no_mask_channel_overlap: bool = False  # whether to allow channel masks to overlap
        self.mask_channel_min_space: int = 1  # min space between spans (if no overlap is enabled)
        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding
        # relative position embedding
        self.relative_position_embedding: bool = True  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = True  # apply gated relative position embedding
        self.relative_position_embedding: bool = True  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = True  # apply gated relative position embedding
        if cfg is not None:
            self.update(cfg)
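A hedged sketch of how this config is usually constructed and handed to the model defined below; the checkpoint path and the "cfg"/"model" keys are assumptions about the saved file layout, not something this module defines:

# Illustrative only: the file name and checkpoint keys are assumptions.
checkpoint = paddle.load("wavlm_base.pdparams")   # hypothetical path
cfg = WavLMConfig(checkpoint.get("cfg"))          # dict of saved hyperparameters
model = WavLM(cfg)
model.set_state_dict(checkpoint["model"])         # hypothetical key
model.eval()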
@@ -216,9 +208,8 @@ class WavLMConfig:
class WavLM(nn.Layer):
    def __init__(
        self,
        cfg: WavLMConfig,
    ) -> None:
            self,
            cfg: WavLMConfig, ) -> None:
        super().__init__()
        logger.info(f"WavLM Config: {cfg.__dict__}")
@@ -230,14 +221,11 @@ class WavLM(nn.Layer):
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
            conv_bias=cfg.conv_bias, )
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )
        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
                                  if self.embed != cfg.encoder_embed_dim else
                                  None)
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
@@ -260,8 +248,7 @@ class WavLM(nn.Layer):
        self.mask_emb = self.create_parameter(
            shape=[cfg.encoder_embed_dim],
            default_initializer=nn.initializer.Uniform(),
        )
            default_initializer=nn.initializer.Uniform(), )
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)
@@ -278,8 +265,7 @@ class WavLM(nn.Layer):
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
                min_space=self.mask_min_space, )
            # mask_indices = torch.from_numpy(mask_indices).to(x.device)
            mask_indices = paddle.to_tensor(mask_indices, dtype='int64')
            x[mask_indices] = self.mask_emb
@@ -295,40 +281,35 @@ class WavLM(nn.Layer):
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
                min_space=self.mask_channel_min_space, )
            mask_channel_indices = (
                # torch.from_numpy(mask_channel_indices)
                paddle.to_tensor(mask_channel_indices, dtype='int64')
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
                .to(x.device).unsqueeze(1).expand(-1, T, -1))
            x[mask_channel_indices] = 0
        return x, mask_indices
    def forward_padding_mask(
        self, features: Tensor, padding_mask: Tensor,
    ) -> Tensor:
            self,
            features: Tensor,
            padding_mask: Tensor, ) -> Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
            padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask
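The reshape above folds the sample-level padding mask down to one flag per feature frame: a frame counts as padded only when every sample it covers is padded. A small illustrative sketch of the same logic with explicit shapes (the two-utterance batch and lengths are made up):

# Illustrative sketch of frame-level padding-mask computation.
B, n_samples, n_frames = 2, 16000, 49
padding_mask = paddle.zeros([B, n_samples], dtype='bool')
padding_mask[1, 12000:] = True                    # second utterance is shorter
extra = padding_mask.shape[1] % n_frames          # samples that do not fill a frame
if extra > 0:
    padding_mask = padding_mask[:, :-extra]
frame_mask = padding_mask.reshape([B, n_frames, -1]).all(-1)   # (B, n_frames) bool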
    def extract_features(
        self,
        source: Tensor,
        padding_mask: Optional[Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
        ret_layer_results: bool = False,
    ):
            self,
            source: Tensor,
            padding_mask: Optional[Tensor] = None,
            mask: bool = False,
            ret_conv: bool = False,
            output_layer: Optional[int] = None,
            ret_layer_results: bool = False, ):
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
@@ -339,7 +320,7 @@ class WavLM(nn.Layer):
            with paddle.no_grad():
                features = self.feature_extractor(source)
        features = features.transpose([0, 2, 1])  # [1, 49, 512]
        features = features.transpose([0, 2, 1])  # [1, 49, 512]
        features = self.layer_norm(features)
        if padding_mask is not None:
@@ -351,9 +332,7 @@ class WavLM(nn.Layer):
        features = self.dropout_input(features)
        if mask:
            x, mask_indices = self.apply_mask(
                features, padding_mask
            )
            x, mask_indices = self.apply_mask(features, padding_mask)
        else:
            x = features
@@ -362,33 +341,35 @@ class WavLM(nn.Layer):
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1
        )
            layer=None if output_layer is None else output_layer - 1)
        # print(f"Debugging: x.shape: {x.shape}, x.mean(): {x.mean()}, x.std(): {x.std()}")
        res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
        res = {
            "x": x,
            "padding_mask": padding_mask,
            "features": features,
            "layer_results": layer_results
        }
        feature = res["features"] if ret_conv else res["x"]
        if ret_layer_results:
            feature = (feature, res["layer_results"])
        return feature, res["padding_mask"]

    def forward(self, x):
        return self.extract_features(x)[0]
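An end-to-end sketch of feature extraction with this class; the 16 kHz batch shape and the requested layer are examples, and a default, untrained WavLMConfig is assumed rather than a real checkpoint:

# Illustrative usage; shapes and options are examples only.
cfg = WavLMConfig()
model = WavLM(cfg)
model.eval()
wav = paddle.randn([1, 16000])                    # (batch, samples) at 16 kHz
with paddle.no_grad():
    feats, _ = model.extract_features(wav)        # last-layer features (1, T', D)
    (feats, layer_results), _ = model.extract_features(
        wav, ret_layer_results=True, output_layer=cfg.encoder_layers)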
class ConvFeatureExtractionModel(nn.Layer):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        conv_type: str = "default"
    ):
    def __init__(self,
                 conv_layers: List[Tuple[int, int, int]],
                 dropout: float = 0.0,
                 mode: str = "default",
                 conv_bias: bool = False,
                 conv_type: str = "default"):
        super().__init__()
        assert mode in {"default", "layer_norm"}
@@ -400,17 +381,20 @@ class ConvFeatureExtractionModel(nn.Layer):
                stride,
                is_layer_norm=False,
                is_group_norm=False,
                conv_bias=False,
        ):
                conv_bias=False, ):
            def make_conv():
                conv = nn.Conv1D(n_in, n_out, k, stride=stride, bias_attr=conv_bias,
                                 weight_attr=nn.initializer.KaimingNormal())
                conv = nn.Conv1D(
                    n_in,
                    n_out,
                    k,
                    stride=stride,
                    bias_attr=conv_bias,
                    weight_attr=nn.initializer.KaimingNormal())
                # nn.init.kaiming_normal_(conv.weight)
                return conv

            assert (
                is_layer_norm and is_group_norm
            ) == False, "layer norm and group norm are exclusive"
            assert (is_layer_norm and is_group_norm
                    ) == False, "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
@@ -419,19 +403,18 @@ class ConvFeatureExtractionModel(nn.Layer):
                    nn.Sequential(
                        TransposeLast(),
                        nn.LayerNorm(normalized_shape=dim, epsilon=1e-5),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
                        TransposeLast(), ),
                    nn.GELU(), )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.GroupNorm(num_groups=dim, num_channels=dim, epsilon=1e-5),
                    nn.GELU(),
                )
                    nn.GroupNorm(
                        num_groups=dim, num_channels=dim, epsilon=1e-5),
                    nn.GELU(), )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
                return nn.Sequential(
                    make_conv(), nn.Dropout(p=dropout), nn.GELU())
        self.conv_type = conv_type
        if self.conv_type == "default":
@@ -449,9 +432,7 @@ class ConvFeatureExtractionModel(nn.Layer):
                        stride,
                        is_layer_norm=mode == "layer_norm",
                        is_group_norm=mode == "default" and i == 0,
                        conv_bias=conv_bias,
                    )
                )
                        conv_bias=conv_bias, ))
                in_d = dim
elif self . conv_type == " conv2d " :
in_d = 1
@@ -460,9 +441,7 @@ class ConvFeatureExtractionModel(nn.Layer):
                assert len(cl) == 3
                (dim, k, stride) = cl
                self.conv_layers.append(
                    paddle.nn.Conv2D(in_d, dim, k, stride)
                )
                self.conv_layers.append(paddle.nn.Conv2D(in_d, dim, k, stride))
                self.conv_layers.append(paddle.nn.ReLU())
                in_d = dim
        elif self.conv_type == "custom":
@@ -473,17 +452,13 @@ class ConvFeatureExtractionModel(nn.Layer):
                assert len(cl) == 3
                (dim, k, stride) = cl
                self.conv_layers.append(
                    paddle.nn.Conv2D(in_d, dim, k, stride, padding=1)
                )
                self.conv_layers.append(
                    paddle.nn.LayerNorm([dim, idim])
                )
                    paddle.nn.Conv2D(in_d, dim, k, stride, padding=1))
                self.conv_layers.append(paddle.nn.LayerNorm([dim, idim]))
                self.conv_layers.append(paddle.nn.ReLU())
                in_d = dim
                if (i + 1) % 2 == 0:
                    self.conv_layers.append(
                        paddle.nn.MaxPool2D(2, stride=2, ceil_mode=True)
                    )
                        paddle.nn.MaxPool2D(2, stride=2, ceil_mode=True))
                    idim = int(math.ceil(idim / 2))
        else:
            pass
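The conv_feature_layers string in WavLMConfig is a Python literal giving the (dim, kernel, stride) of each Conv1D block in the default extractor; a short sketch of how such a spec is usually parsed and what overall downsampling the default implies (evaluating a trusted config string, as the upstream WavLM code does):

# Illustrative: parse the layer spec and compute the total stride.
conv_spec = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
feature_enc_layers = eval(conv_spec)   # [(512, 10, 5), (512, 3, 2), ..., (512, 2, 2)]
total_stride = 1
for _dim, _kernel, stride in feature_enc_layers:
    total_stride *= stride
print(total_stride)                    # 320 -> one frame per 320 samples (20 ms at 16 kHz)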
@@ -518,8 +493,8 @@ class TransformerEncoder(nn.Layer):
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        std = math.sqrt(
            (4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        self.pos_conv = nn.Conv1D(
            self.embedding_dim,
@@ -528,15 +503,16 @@ class TransformerEncoder(nn.Layer):
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
            weight_attr=nn.initializer.Normal(mean=0, std=std),
            bias_attr=True
        )
            bias_attr=True)
        # nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        # nn.init.constant_(self.pos_conv.bias, 0)
        # self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        # self.pos_conv.weight_g = self.pos_conv.weight_g.unsqueeze(0).unsqueeze(0)
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
        self.pos_conv = nn.utils.weight_norm(
            self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv,
                                      SamePad(args.conv_pos), nn.GELU())
        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
@@ -547,25 +523,23 @@ class TransformerEncoder(nn.Layer):
            self.num_buckets = 0
            self.max_distance = 0
        self.layers = nn.LayerList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    has_relative_attention_bias=(self.relative_position_embedding and i == 0),
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                )
                for i in range(args.encoder_layers)
            ]
        )
        self.layers = nn.LayerList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
                has_relative_attention_bias=(
                    self.relative_position_embedding and i == 0),
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
                gru_rel_pos=args.gru_rel_pos, )
            for i in range(args.encoder_layers)
        ])
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
@@ -574,14 +548,19 @@ class TransformerEncoder(nn.Layer):
        # self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
        x, layer_results = self.extract_features(x, padding_mask,
                                                 streaming_mask, layer)
        # print("x.shape", x.shape)
        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)
        return x, layer_results

    def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
    def extract_features(self,
                         x,
                         padding_mask=None,
                         streaming_mask=None,
                         tgt_layer=None):
        if padding_mask is not None:
            x[padding_mask] = 0
@@ -598,7 +577,6 @@ class TransformerEncoder(nn.Layer):
        # x = x.transpose(0, 1)
        x = x.transpose([1, 0, 2])
        layer_results = []
        z = None
        if tgt_layer is not None:
@@ -608,7 +586,12 @@ class TransformerEncoder(nn.Layer):
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, self_attn_mask=streaming_mask, pos_bias=pos_bias)
                x, z, pos_bias = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    self_attn_mask=streaming_mask,
                    pos_bias=pos_bias)
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
@@ -633,20 +616,19 @@ class TransformerSentenceEncoderLayer(nn.Layer):
    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        has_relative_attention_bias: bool = True,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = True,
    ) -> None:
            embedding_dim: float = 768,
            ffn_embedding_dim: float = 3072,
            num_attention_heads: float = 8,
            dropout: float = 0.1,
            attention_dropout: float = 0.1,
            activation_dropout: float = 0.1,
            activation_fn: str = "relu",
            layer_norm_first: bool = False,
            has_relative_attention_bias: bool = True,
            num_buckets: int = 0,
            max_distance: int = 0,
            rescale_init: bool = False,
            gru_rel_pos: bool = True, ) -> None:
        super().__init__()
        # Initialize parameters
@@ -666,8 +648,7 @@ class TransformerSentenceEncoderLayer(nn.Layer):
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )
            gru_rel_pos=gru_rel_pos, )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
@@ -679,7 +660,8 @@ class TransformerSentenceEncoderLayer(nn.Layer):
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim,
                                  "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
@@ -687,21 +669,19 @@ class TransformerSentenceEncoderLayer(nn.Layer):
        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: Tensor,
        self_attn_mask: Tensor = None,
        self_attn_padding_mask: Tensor = None,
        need_weights: bool = False,
        pos_bias=None
    ):
    def forward(self,
                x: Tensor,
                self_attn_mask: Tensor = None,
                self_attn_padding_mask: Tensor = None,
                need_weights: bool = False,
                pos_bias=None):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
@@ -710,8 +690,7 @@ class TransformerSentenceEncoderLayer(nn.Layer):
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )
                position_bias=pos_bias)
            # import pdb; pdb.set_trace()
            x = self.dropout1(x)
            x = residual + x
@@ -734,8 +713,7 @@ class TransformerSentenceEncoderLayer(nn.Layer):
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )
                position_bias=pos_bias)
            x = self.dropout1(x)
            x = residual + x