# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Paddle Wav2Vec2 model.""" from dataclasses import dataclass from typing import Optional from typing import Tuple from typing import Union import numpy as np import paddle from paddle import nn from paddlespeech.s2t.models.wav2vec2.modules.activations import ACT2FN from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import BaseModelOutput from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import ModelOutput from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import Wav2Vec2BaseModelOutput from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @dataclass class Wav2Vec2ForPreTrainingOutput(ModelOutput): """ Output type of [`Wav2Vec2ForPreTraining`], with potential hidden states and attentions. Args: loss (*optional*, returned when `sample_negative_indices` are passed, `paddle.Tensor` of shape `(1,)`): Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. projected_states (`paddle.Tensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked projected quantized states. projected_quantized_states (`paddle.Tensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive target vectors for contrastive loss. hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `paddle.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `paddle.Tensor` of shape `(1,)`): The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `paddle.Tensor` of shape `(1,)`): The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . """ loss: Optional[paddle.Tensor] = None projected_states: paddle.Tensor = None projected_quantized_states: paddle.Tensor = None codevector_perplexity: paddle.Tensor = None hidden_states: Optional[Tuple[paddle.Tensor]] = None attentions: Optional[Tuple[paddle.Tensor]] = None contrastive_loss: Optional[paddle.Tensor] = None diversity_loss: Optional[paddle.Tensor] = None def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, attention_mask: Optional[paddle.Tensor]=None, min_masks: int=0, ) -> np.ndarray: """ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on CPU as part of the preprocessing during training. Args: shape: The shape for which to compute masks. This should be of a tuple of size 2 where the first element is the batch size and the second element is the length of the axis to span. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of independently generated mask spans of length `mask_length` is computed by `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the actual percentage will be smaller. mask_length: size of the mask min_masks: minimum number of masked spans attention_mask: A (right-padded) attention mask which independently shortens the feature axis of each batch dimension. """ batch_size, sequence_length = shape if mask_length < 1: raise ValueError("`mask_length` has to be bigger than 0.") if mask_length > sequence_length: raise ValueError( f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" f" and `sequence_length`: {sequence_length}`") # epsilon is used for probabilistic rounding epsilon = np.random.rand(1).item() def compute_num_masked_span(input_length): """Given input length, compute how many spans should be masked""" num_masked_span = int(mask_prob * input_length / mask_length + epsilon) num_masked_span = max(num_masked_span, min_masks) # make sure num masked span <= sequence_length if num_masked_span * mask_length > sequence_length: num_masked_span = sequence_length // mask_length # make sure num_masked span is also <= input_length - (mask_length - 1) if input_length - (mask_length - 1) < num_masked_span: num_masked_span = max(input_length - (mask_length - 1), 0) return num_masked_span # compute number of masked spans in batch input_lengths = (attention_mask.sum(-1).detach().tolist() if attention_mask is not None else [sequence_length for _ in range(batch_size)]) # SpecAugment mask to fill spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool) spec_aug_mask_idxs = [] max_num_masked_span = compute_num_masked_span(sequence_length) if max_num_masked_span == 0: return spec_aug_mask for input_length in input_lengths: # compute num of masked spans for this input num_masked_span = compute_num_masked_span(input_length) # get random indices to mask spec_aug_mask_idx = np.random.choice( np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False) # pick first sampled index that will serve as a dummy index to pad vector # to ensure same dimension for all batches due to probabilistic rounding # Picking first sample just pads those vectors twice. if len(spec_aug_mask_idx) == 0: # this case can only happen if `input_length` is strictly smaller then # `sequence_length` in which case the last token has to be a padding # token which we can use as a dummy mask id dummy_mask_idx = sequence_length - 1 else: dummy_mask_idx = spec_aug_mask_idx[0] spec_aug_mask_idx = np.concatenate([ spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx ]) spec_aug_mask_idxs.append(spec_aug_mask_idx) spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) # expand masked indices to masked spans spec_aug_mask_idxs = np.broadcast_to( spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)) spec_aug_mask_idxs = spec_aug_mask_idxs.reshape( (batch_size, max_num_masked_span * mask_length)) # add offset to the starting indexes so that indexes now create a span offsets = np.arange(mask_length)[None, None, :] offsets = np.broadcast_to(offsets, ( batch_size, max_num_masked_span, mask_length)).reshape( (batch_size, max_num_masked_span * mask_length)) spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length if spec_aug_mask_idxs.max() > sequence_length - 1: spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 # scatter indices to mask np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) return spec_aug_mask def _sample_negative_indices(features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray]=None): """ Sample `num_negatives` vectors from feature vectors. """ batch_size, sequence_length = features_shape # generate indices of the positive vectors themselves, repeat them `num_negatives` times sequence_length_range = np.arange(sequence_length) # get `num_negatives` random vector indices from the same utterance sampled_negative_indices = np.zeros( shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) mask_time_indices = (mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)) for batch_idx in range(batch_size): high = mask_time_indices[batch_idx].sum() - 1 mapped_masked_indices = sequence_length_range[mask_time_indices[ batch_idx]] feature_indices = np.broadcast_to( np.arange(high + 1)[:, None], (high + 1, num_negatives)) sampled_indices = np.random.randint( 0, high, size=(high + 1, num_negatives)) # avoid sampling the same positive vector, but keep the distribution uniform sampled_indices[sampled_indices >= feature_indices] += 1 # remap to actual indices sampled_negative_indices[batch_idx][mask_time_indices[ batch_idx]] = mapped_masked_indices[sampled_indices] # correct for batch size sampled_negative_indices[batch_idx] += batch_idx * sequence_length return sampled_negative_indices class Wav2Vec2NoLayerNormConvLayer(nn.Layer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 self.out_conv_dim = config.conv_dim[layer_id] self.conv = nn.Conv1D( self.in_conv_dim, self.out_conv_dim, kernel_size=config.conv_kernel[layer_id], stride=config.conv_stride[layer_id], bias_attr=config.conv_bias, ) self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): hidden_states = self.conv(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states class Wav2Vec2LayerNormConvLayer(nn.Layer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 self.out_conv_dim = config.conv_dim[layer_id] self.conv = nn.Conv1D( self.in_conv_dim, self.out_conv_dim, kernel_size=config.conv_kernel[layer_id], stride=config.conv_stride[layer_id], bias_attr=config.conv_bias, ) self.layer_norm = nn.LayerNorm(self.out_conv_dim) self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): hidden_states = self.conv(hidden_states) hidden_states = hidden_states.transpose([0, 2, 1]) hidden_states = self.layer_norm(hidden_states) hidden_states = hidden_states.transpose([0, 2, 1]) hidden_states = self.activation(hidden_states) return hidden_states class Wav2Vec2GroupNormConvLayer(nn.Layer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 self.out_conv_dim = config.conv_dim[layer_id] self.conv = nn.Conv1D( self.in_conv_dim, self.out_conv_dim, kernel_size=config.conv_kernel[layer_id], stride=config.conv_stride[layer_id], bias_attr=config.conv_bias, ) self.activation = ACT2FN[config.feat_extract_activation] self.layer_norm = nn.GroupNorm( num_groups=self.out_conv_dim, num_channels=self.out_conv_dim) def forward(self, hidden_states): hidden_states = self.conv(hidden_states) hidden_states = self.layer_norm(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states class Wav2Vec2PositionalConvEmbedding(nn.Layer): def __init__(self, config): super().__init__() self.conv = nn.Conv1D( config.hidden_size, config.hidden_size, kernel_size=config.num_conv_pos_embeddings, padding=config.num_conv_pos_embeddings // 2, groups=config.num_conv_pos_embedding_groups, ) self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings) self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): hidden_states = hidden_states.transpose([0, 2, 1]) hidden_states = self.conv(hidden_states) hidden_states = self.padding(hidden_states) hidden_states = self.activation(hidden_states) hidden_states = hidden_states.transpose([0, 2, 1]) return hidden_states class Wav2Vec2SamePadLayer(nn.Layer): def __init__(self, num_conv_pos_embeddings): super().__init__() self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 def forward(self, hidden_states): if self.num_pad_remove > 0: hidden_states = hidden_states[:, :, :-self.num_pad_remove] return hidden_states class Wav2Vec2FeatureEncoder(nn.Layer): """Construct the features from raw audio waveform""" def __init__(self, config): super().__init__() if config.feat_extract_norm == "group": conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [ Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) ] elif config.feat_extract_norm == "layer": conv_layers = [ Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) ] else: raise ValueError( f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" ) self.conv_layers = nn.LayerList(conv_layers) self.gradient_checkpointing = False def _freeze_parameters(self): for param in self.parameters(): param.trainable = False def forward(self, input_values): hidden_states = input_values[:, None] for conv_layer in self.conv_layers: hidden_states = conv_layer(hidden_states) return hidden_states class Wav2Vec2FeatureProjection(nn.Layer): def __init__(self, config): super().__init__() self.layer_norm = nn.LayerNorm( config.conv_dim[-1], epsilon=config.layer_norm_eps) self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) self.dropout = nn.Dropout(config.feat_proj_dropout) def forward(self, hidden_states): # non-projected hidden states are needed for quantization norm_hidden_states = self.layer_norm(hidden_states) hidden_states = self.projection(norm_hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states, norm_hidden_states # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2 class Wav2Vec2Attention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, embed_dim: int, num_heads: int, dropout: float=0.0, is_decoder: bool=False, bias: bool=True, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.k_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias) def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): return paddle.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)).transpose([0, 2, 1, 3]) def forward( self, hidden_states: paddle.Tensor, key_value_states: Optional[paddle.Tensor]=None, past_key_value: Optional[Tuple[paddle.Tensor]]=None, attention_mask: Optional[paddle.Tensor]=None, layer_head_mask: Optional[paddle.Tensor]=None, output_attentions: bool=False, ) -> Tuple[paddle.Tensor, Optional[ paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.shape # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = paddle.concat([past_key_value[0], key_states], axis=2) value_states = paddle.concat( [past_key_value[1], value_states], axis=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(paddle.Tensor, paddle.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(paddle.Tensor, paddle.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape) key_states = key_states.reshape(proj_shape) value_states = value_states.reshape(proj_shape) src_len = key_states.shape[1] attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1])) if attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]: raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f" {attn_weights.shape}") if attention_mask is not None: if attention_mask.shape != [bsz, 1, tgt_len, src_len]: raise ValueError( f"Attention mask should be of size {[bsz, 1, tgt_len, src_len]}, but is {attention_mask.shape}" ) attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, axis=-1) if layer_head_mask is not None: if layer_head_mask.shape != [ self.num_heads, ]: raise ValueError( f"Head mask for a single layer should be of size {[self.num_heads,]}, but is" f" {layer_head_mask.shape}") attn_weights = layer_head_mask.reshape( (1, -1, 1, 1)) * attn_weights.reshape( (bsz, self.num_heads, tgt_len, src_len)) attn_weights = attn_weights.reshape( (bsz * self.num_heads, tgt_len, src_len)) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following attn_weights_reshaped = attn_weights.reshape( (bsz, self.num_heads, tgt_len, src_len)) attn_weights = attn_weights_reshaped.reshape( (bsz * self.num_heads, tgt_len, src_len)) else: attn_weights_reshaped = None attn_probs = nn.functional.dropout( attn_weights, p=self.dropout, training=self.training) attn_output = paddle.bmm(attn_probs, value_states) if attn_output.shape != [bsz * self.num_heads, tgt_len, self.head_dim]: raise ValueError( f"`attn_output` should be of size {[bsz, self.num_heads, tgt_len, self.head_dim]}, but is" f" {attn_output.shape}") attn_output = attn_output.reshape( (bsz, self.num_heads, tgt_len, self.head_dim)) attn_output = attn_output.transpose([0, 2, 1, 3]) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape((bsz, tgt_len, self.embed_dim)) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped, past_key_value class Wav2Vec2FeedForward(nn.Layer): def __init__(self, config): super().__init__() self.intermediate_dropout = nn.Dropout(config.activation_dropout) self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) self.output_dropout = nn.Dropout(config.hidden_dropout) def forward(self, hidden_states): hidden_states = self.intermediate_dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_dropout(hidden_states) hidden_states = self.output_dense(hidden_states) hidden_states = self.output_dropout(hidden_states) return hidden_states class Wav2Vec2EncoderLayer(nn.Layer): def __init__(self, config): super().__init__() self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm( config.hidden_size, epsilon=config.layer_norm_eps) self.feed_forward = Wav2Vec2FeedForward(config) self.final_layer_norm = nn.LayerNorm( config.hidden_size, epsilon=config.layer_norm_eps) def forward(self, hidden_states, attention_mask=None, output_attentions=False): attn_residual = hidden_states hidden_states, attn_weights, _ = self.attention( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions) hidden_states = self.dropout(hidden_states) hidden_states = attn_residual + hidden_states hidden_states = self.layer_norm(hidden_states) hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states, ) if output_attentions: outputs += (attn_weights, ) return outputs class Wav2Vec2EncoderLayerStableLayerNorm(nn.Layer): def __init__(self, config): super().__init__() self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm( config.hidden_size, epsilon=config.layer_norm_eps) self.feed_forward = Wav2Vec2FeedForward(config) self.final_layer_norm = nn.LayerNorm( config.hidden_size, epsilon=config.layer_norm_eps) def forward(self, hidden_states, attention_mask=None, output_attentions=False): attn_residual = hidden_states hidden_states = self.layer_norm(hidden_states) hidden_states, attn_weights, _ = self.attention( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions) hidden_states = self.dropout(hidden_states) hidden_states = attn_residual + hidden_states hidden_states = hidden_states + self.feed_forward( self.final_layer_norm(hidden_states)) outputs = (hidden_states, ) if output_attentions: outputs += (attn_weights, ) return outputs class Wav2Vec2Encoder(nn.Layer): def __init__(self, config): super().__init__() self.config = config self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) self.layer_norm = nn.LayerNorm( config.hidden_size, epsilon=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.LayerList([ Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers) ]) self.gradient_checkpointing = False def forward( self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if attention_mask is not None: # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat( 1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 # extend attention_mask attention_mask = 1.0 - attention_mask[:, None, None, :].to( dtype=hidden_states.dtype) attention_mask = attention_mask * np.iinfo(np.float32).min attention_mask = attention_mask.expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) #deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() for layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = np.random.uniform(0, 1) skip_the_layer = True if self.training and ( dropout_probability < self.config.layerdrop) else False if not skip_the_layer: # or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, output_attentions) return custom_forward else: layer_outputs = layer( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1], ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple( v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class Wav2Vec2EncoderStableLayerNorm(nn.Layer): def __init__(self, config): super().__init__() self.config = config self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) self.layer_norm = nn.LayerNorm( config.hidden_size, epsilon=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.LayerList([ Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers) ]) self.gradient_checkpointing = False def forward( self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if attention_mask is not None: # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze( -1).repeat_interleave( hidden_states.shape[2], axis=2) hidden_states[~expand_attention_mask] = 0 # extend attention_mask attention_mask = 1.0 - attention_mask[:, None, None, :].to( dtype=hidden_states.dtype) attention_mask = attention_mask * np.iinfo(np.float32).min attention_mask = attention_mask.expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) for layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = np.random.uniform(0, 1) skip_the_layer = True if self.training and ( dropout_probability < self.config.layerdrop) else False if not skip_the_layer: # or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, output_attentions) return custom_forward else: layer_outputs = layer( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1], ) hidden_states = self.layer_norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple( v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class Wav2Vec2GumbelVectorQuantizer(nn.Layer): """ Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. """ def __init__(self, config): super().__init__() self.num_groups = config.num_codevector_groups self.num_vars = config.num_codevectors_per_group if config.codevector_dim % self.num_groups != 0: raise ValueError( f"`config.codevector_dim {config.codevector_dim} must be divisible " f"by `config.num_codevector_groups` {self.num_groups} for concatenation" ) # storage for codebook variables (codewords) self.codevectors = paddle.static.create_parameter( shape=[ 1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups ], dtype='float32') self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 @staticmethod def _compute_perplexity(probs, mask=None): if mask is not None: mask_extended = mask.flatten()[:, None, None].expand(probs.shape) probs = paddle.where(mask_extended, probs, paddle.zeros_like(probs)) marginal_probs = probs.sum(dim=0) / mask.sum() else: marginal_probs = probs.mean(dim=0) perplexity = paddle.exp(-paddle.sum( marginal_probs * paddle.log(marginal_probs + 1e-7), dim=-1)).sum() return perplexity def forward(self, hidden_states, mask_time_indices=None): batch_size, sequence_length, hidden_size = hidden_states.shape # project to codevector dim hidden_states = self.weight_proj(hidden_states) hidden_states = hidden_states.reshape( (batch_size * sequence_length * self.num_groups, -1)) if self.training: # sample code vector probs via gumbel in differentiateable way codevector_probs = nn.functional.gumbel_softmax( hidden_states.float(), tau=self.temperature, hard=True).type_as(hidden_states) # compute perplexity codevector_soft_dist = paddle.softmax( hidden_states.reshape((batch_size * sequence_length, self.num_groups, -1)).float(), axis=-1) perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) else: # take argmax in non-differentiable way # comptute hard codevector distribution (one hot) codevector_idx = hidden_states.argmax(dim=-1) codevector_probs = hidden_states.new_zeros( *hidden_states.shape).scatter_(-1, codevector_idx.reshape((-1, 1)), 1.0) codevector_probs = codevector_probs.reshape( (batch_size * sequence_length, self.num_groups, -1)) perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) codevector_probs = codevector_probs.reshape( (batch_size * sequence_length, -1)) # use probs to retrieve codevectors codevectors_per_group = codevector_probs.unsqueeze( -1) * self.codevectors codevectors = codevectors_per_group.reshape( (batch_size * sequence_length, self.num_groups, self.num_vars, -1)) codevectors = codevectors.sum(-2).reshape( (batch_size, sequence_length, -1)) return codevectors, perplexity class Wav2Vec2Adapter(nn.Layer): def __init__(self, config): super().__init__() # feature dim might need to be down-projected if config.output_hidden_size != config.hidden_size: self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) else: self.proj = self.proj_layer_norm = None self.layers = nn.LayerList( Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers)) self.layerdrop = config.layerdrop def forward(self, hidden_states): # down project hidden_states if necessary if self.proj is not None and self.proj_layer_norm is not None: hidden_states = self.proj(hidden_states) hidden_states = self.proj_layer_norm(hidden_states) hidden_states = hidden_states.transpose([0, 2, 1]) for layer in self.layers: layerdrop_prob = np.random.random() if not self.training or (layerdrop_prob > self.layerdrop): hidden_states = layer(hidden_states) hidden_states = hidden_states.transpose([0, 2, 1]) return hidden_states class Wav2Vec2AdapterLayer(nn.Layer): def __init__(self, config): super().__init__() self.conv = nn.Conv1D( config.output_hidden_size, 2 * config.output_hidden_size, config.adapter_kernel_size, stride=config.adapter_stride, padding=1, ) def forward(self, hidden_states): hidden_states = self.conv(hidden_states) hidden_states = nn.functional.glu(hidden_states, axis=1) return hidden_states class Wav2Vec2Model(nn.Layer): def __init__(self, config): super().__init__() self.config = config self.feature_extractor = Wav2Vec2FeatureEncoder(config) self.feature_projection = Wav2Vec2FeatureProjection(config) # model only needs masking vector if mask prob is > 0.0 if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: # self.masked_spec_embed = nn.Parameter(paddle.Tensor(config.hidden_size).uniform_()) #self.masked_spec_embed = paddle.uniform([config.hidden_size]) self.masked_spec_embed = paddle.static.create_parameter( shape=[config.hidden_size], dtype='float32', default_initializer=paddle.nn.initializer.Uniform( low=0, high=1.0)) if config.do_stable_layer_norm: self.encoder = Wav2Vec2EncoderStableLayerNorm(config) else: self.encoder = Wav2Vec2Encoder(config) self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None # Initialize weights and apply final processing self.post_init() def freeze_feature_encoder(self): """ Calling this function will disable the gradient computation for the feature encoder so that its parameter will not be updated during training. """ self.feature_extractor._freeze_parameters() def _mask_hidden_states( self, hidden_states: paddle.Tensor, mask_time_indices: Optional[paddle.Tensor]=None, attention_mask: Optional[paddle.Tensor]=None, ): """ Masks extracted features along time axis and/or along feature axis according to [SpecAugment](https://arxiv.org/abs/1904.08779). """ # `config.apply_spec_augment` can set masking to False if not getattr(self.config, "apply_spec_augment", True): return hidden_states # generate indices & apply SpecAugment along time axis batch_size, sequence_length, hidden_size = hidden_states.shape if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices hidden_states[mask_time_indices] = self.masked_spec_embed.to( hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, attention_mask=attention_mask, min_masks=self.config.mask_time_min_masks, ) mask_time_indices = paddle.to_tensor( mask_time_indices, dtype=paddle.bool) hidden_states[mask_time_indices] = self.masked_spec_embed.to( hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis mask_feature_indices = _compute_mask_indices( (batch_size, hidden_size), mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, min_masks=self.config.mask_feature_min_masks, ) mask_feature_indices = paddle.to_tensor( mask_feature_indices, dtype=paddle.bool) mask_feature_indices = mask_feature_indices[:, None].expand( -1, sequence_length, -1) hidden_states[mask_feature_indices] = 0 return hidden_states def forward( self, input_values: Optional[paddle.Tensor], attention_mask: Optional[paddle.Tensor]=None, mask_time_indices: Optional[paddle.Tensor]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict extract_features = self.feature_extractor(input_values) extract_features = extract_features.transpose([0, 2, 1]) if attention_mask is not None: # compute reduced attention_mask corresponding to feature vectors attention_mask = self._get_feature_vector_attention_mask( extract_features.shape[1], attention_mask, add_adapter=False) hidden_states, extract_features = self.feature_projection( extract_features) hidden_states = self._mask_hidden_states( hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask) encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = encoder_outputs[0] if self.adapter is not None: hidden_states = self.adapter(hidden_states) if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] return Wav2Vec2BaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's modules properly initialized (such as weight initialization). """ # self.init_weights() # self._backward_compatibility_gradient_checkpointing() pass class Wav2Vec2ConfigPure(): model_type = "wav2vec2" def __init__(self, config): self.output_attentions = False self.output_hidden_states = False self.use_return_dict = True self.hidden_size = config.hidden_size self.feat_extract_norm = config.feat_extract_norm self.feat_extract_activation = config.feat_extract_activation self.conv_dim = config.conv_dim self.conv_stride = config.conv_stride self.conv_kernel = config.conv_kernel self.conv_bias = config.conv_bias self.num_conv_pos_embeddings = config.num_conv_pos_embeddings self.num_conv_pos_embedding_groups = config.num_conv_pos_embedding_groups self.num_feat_extract_layers = len(self.conv_dim) self.num_hidden_layers = config.num_hidden_layers self.intermediate_size = config.intermediate_size self.hidden_act = config.hidden_act self.num_attention_heads = config.num_attention_heads self.hidden_dropout = config.hidden_dropout self.attention_dropout = config.attention_dropout self.activation_dropout = config.activation_dropout self.feat_proj_dropout = config.feat_proj_dropout self.final_dropout = config.final_dropout self.layerdrop = config.layerdrop self.layer_norm_eps = config.layer_norm_eps self.initializer_range = config.initializer_range self.do_stable_layer_norm = config.do_stable_layer_norm self.use_weighted_layer_sum = config.use_weighted_layer_sum if ((len(self.conv_stride) != self.num_feat_extract_layers) or (len(self.conv_kernel) != self.num_feat_extract_layers) or (len(self.conv_dim) != self.num_feat_extract_layers)): raise ValueError( "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` ==" " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," f" `len(config.conv_kernel) = {len(self.conv_kernel)}`.") # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 self.apply_spec_augment = config.apply_spec_augment self.mask_time_prob = config.mask_time_prob self.mask_time_length = config.mask_time_length self.mask_time_min_masks = config.mask_time_min_masks self.mask_feature_prob = config.mask_feature_prob self.mask_feature_length = config.mask_feature_length self.mask_feature_min_masks = config.mask_feature_min_masks # parameters for pretraining with codevector quantized representations self.num_codevectors_per_group = config.num_codevectors_per_group self.num_codevector_groups = config.num_codevector_groups self.contrastive_logits_temperature = config.contrastive_logits_temperature self.feat_quantizer_dropout = config.feat_quantizer_dropout self.num_negatives = config.num_negatives self.codevector_dim = config.codevector_dim self.proj_codevector_dim = config.proj_codevector_dim self.diversity_loss_weight = config.diversity_loss_weight # adapter self.add_adapter = config.add_adapter self.adapter_kernel_size = config.adapter_kernel_size self.adapter_stride = config.adapter_stride self.num_adapter_layers = config.num_adapter_layers self.output_hidden_size = config.output_hidden_size or config.hidden_size @property def inputs_to_logits_ratio(self): return functools.reduce(operator.mul, self.conv_stride, 1)