You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
587 lines
22 KiB
587 lines
22 KiB
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" Paddle Hubert model."""
|
|
from dataclasses import dataclass
|
|
from dataclasses import field
|
|
from typing import Any
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.nn as nn
|
|
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import ChoiceEnum
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import compute_mask_indices
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import ConvFeatureExtractionModel
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import EXTRACTOR_MODE_CHOICES
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import get_available_activation_fns
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import GLU
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import GradMultiply
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import LAYER_TYPE_CHOICES
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import MASKING_DISTRIBUTION_CHOICES
|
|
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import TransformerEncoder
|
|
from paddlespeech.s2t.modules.align import LayerNorm
|
|
from paddlespeech.s2t.modules.align import Linear
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
@dataclass
class HubertPretrainingConfig:
    """Task/data options that accompany a HuBERT model.

    Every attribute is a dataclass field whose ``metadata["help"]`` entry
    carries its human-readable description.  Within this module,
    ``HubertModel`` reads ``sample_rate`` from this object.
    """

    # Label frame rate; the help text reserves -1.0 for sequence labels.
    label_rate: float = field(
        metadata=dict(help="label frame rate. -1.0 for sequence label"),
        default=-1.0)

    # Sample rate the audio is expected to be resampled to.
    sample_rate: int = field(
        metadata=dict(
            help="target sample rate. audio files will be up/down sampled to this rate"
        ),
        default=16000)

    # Whether the raw input is normalized to zero mean / unit variance.
    normalize: bool = field(
        metadata=dict(
            help="if set, normalizes input to have 0 mean and unit variance"),
        default=False)

    # Pad short samples rather than cropping.
    enable_padding: bool = field(
        metadata=dict(help="pad shorter samples instead of cropping"),
        default=False)

    # Size limits used when filtering / cropping samples.
    max_keep_size: Optional[int] = field(
        metadata=dict(help="exclude sample longer than this"), default=None)
    max_sample_size: Optional[int] = field(
        metadata=dict(help="max sample size to crop to for batching"),
        default=None)
    min_sample_size: Optional[int] = field(
        metadata=dict(help="min sample size to crop to for batching"),
        default=None)

    # Cropping / padding strategy switches.
    random_crop: Optional[bool] = field(
        metadata=dict(help="always crop from the beginning if false"),
        default=True)
    pad_audio: Optional[bool] = field(
        metadata=dict(help="pad audio to the longest one in the batch if true"),
        default=False)
|
|
|
|
|
|
@dataclass
class HubertConfig:
    """Architecture and training options consumed by ``HubertModel``.

    ``label_rate`` is the only field without a default and must therefore be
    supplied by the caller.  Every other attribute is a dataclass field whose
    ``metadata["help"]`` entry carries its description.
    """

    label_rate: float

    # ---- feature extractor / encoder layout ----
    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        metadata=dict(
            help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)"
        ),
        default="default")
    encoder_layers: int = field(
        metadata=dict(help="num encoder layers in the transformer"),
        default=12)
    encoder_embed_dim: int = field(
        metadata=dict(help="encoder embedding dimension"), default=768)
    encoder_ffn_embed_dim: int = field(
        metadata=dict(help="encoder embedding dimension for FFN"),
        default=3072)
    encoder_attention_heads: int = field(
        metadata=dict(help="num encoder attention heads"), default=12)
    activation_fn: ChoiceEnum(get_available_activation_fns()) = field(
        metadata=dict(help="activation function to use"), default="gelu")
    layer_type: LAYER_TYPE_CHOICES = field(
        metadata=dict(help="layer type in encoder"), default="transformer")

    # ---- dropout probabilities ----
    dropout: float = field(
        metadata=dict(help="dropout probability for the transformer"),
        default=0.1)
    attention_dropout: float = field(
        metadata=dict(help="dropout probability for attention weights"),
        default=0.1)
    activation_dropout: float = field(
        metadata=dict(help="dropout probability after activation in FFN"),
        default=0.0)
    encoder_layerdrop: float = field(
        metadata=dict(help="probability of dropping a tarnsformer layer"),
        default=0.0)
    dropout_input: float = field(
        metadata=dict(help="dropout to apply to the input (after feat extr)"),
        default=0.0)
    dropout_features: float = field(
        metadata=dict(
            help="dropout to apply to the features (after feat extr)"),
        default=0.0)

    # ---- projections, convolutional front-end, logits ----
    final_dim: int = field(
        metadata=dict(
            help="project final representations and targets to this many dimensions. set to encoder_embed_dim is <= 0"
        ),
        default=0)
    untie_final_proj: bool = field(
        metadata=dict(help="use separate projection for each target"),
        default=False)
    layer_norm_first: bool = field(
        metadata=dict(help="apply layernorm first in the transformer"),
        default=False)
    conv_feature_layers: str = field(
        metadata=dict(
            help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]"
        ),
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
    conv_bias: bool = field(
        metadata=dict(help="include bias in conv encoder"), default=False)
    logit_temp: float = field(
        metadata=dict(help="temperature to divide logits by"), default=0.1)
    target_glu: bool = field(
        metadata=dict(help="adds projection + glu to targets"), default=False)
    feature_grad_mult: float = field(
        metadata=dict(help="multiply feature extractor var grads by this"),
        default=1.0)

    # ---- time-step masking ----
    mask_length: int = field(metadata=dict(help="mask length"), default=10)
    mask_prob: float = field(
        metadata=dict(help="probability of replacing a token with mask"),
        default=0.65)
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        metadata=dict(help="how to choose mask length"), default="static")
    mask_other: float = field(
        metadata=dict(
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh"
        ),
        default=0)
    no_mask_overlap: bool = field(
        metadata=dict(help="whether to allow masks to overlap"),
        default=False)
    mask_min_space: int = field(
        metadata=dict(
            help="min space between spans (if no overlap is enabled)"),
        default=1)

    # ---- channel (feature-dimension) masking ----
    mask_channel_length: int = field(
        metadata=dict(help="length of the mask for features (channels)"),
        default=10)
    mask_channel_prob: float = field(
        metadata=dict(help="probability of replacing a feature with 0"),
        default=0.0)
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        metadata=dict(help="how to choose mask length for channel masking"),
        default="static")
    mask_channel_other: float = field(
        metadata=dict(
            help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh"
        ),
        default=0)
    no_mask_channel_overlap: bool = field(
        metadata=dict(help="whether to allow channel masks to overlap"),
        default=False)
    mask_channel_min_space: int = field(
        metadata=dict(
            help="min space between spans (if no overlap is enabled)"),
        default=1)

    # ---- convolutional positional embeddings ----
    conv_pos: int = field(
        metadata=dict(
            help="number of filters for convolutional positional embeddings"),
        default=128)
    conv_pos_groups: int = field(
        metadata=dict(
            help="number of groups for convolutional positional embedding"),
        default=16)

    latent_temp: Tuple[float, float, float] = field(
        metadata=dict(help="legacy (to be removed)"),
        default=(2, 0.5, 0.999995))

    # ---- which frames contribute to the loss ----
    skip_masked: bool = field(
        metadata=dict(help="skip computing losses over masked frames"),
        default=False)
    skip_nomask: bool = field(
        metadata=dict(help="skip computing losses over unmasked frames"),
        default=False)

    checkpoint_activations: bool = field(
        metadata=dict(
            help="recompute activations and save memory for extra compute"),
        default=False)

    # ---- FP16 optimization ----
    required_seq_len_multiple: int = field(
        metadata=dict(
            help="pad the input to encoder such that the sequence length is divisible by multiple"
        ),
        default=2)

    # ---- Conformer-specific options ----
    depthwise_conv_kernel_size: int = field(
        metadata=dict(
            help="depthwise-conv-kernel-size for convolution in conformer layer"
        ),
        default=31)
    attn_type: str = field(
        metadata=dict(help="if espnet use ESPNET MHA"), default="")
    pos_enc_type: str = field(
        metadata=dict(help="Positional encoding type to use in conformer"),
        default="abs")
    fp16: bool = field(
        metadata=dict(help="If fp16 is being used"), default=False)
|
|
|
|
|
|
class HubertModel(nn.Layer):
    """HuBERT model: convolutional feature extractor + transformer encoder.

    With ``features_only=True`` the forward pass returns encoder outputs.
    Otherwise it additionally scores the projected encoder outputs against a
    table of label embeddings for masked and/or unmasked frames.
    """

    def __init__(
            self,
            cfg: HubertConfig,
            task_cfg: HubertPretrainingConfig,
            dictionaries: List[Any], ) -> None:
        """Build all sub-modules from the configuration.

        Args:
            cfg: model configuration.
            task_cfg: task configuration; only ``sample_rate`` is read here.
            dictionaries: one entry per label set.  If any entry is ``None``
                the label-embedding table (``label_embs_concat``) and
                ``num_classes`` are not created, which the log message
                describes as the fine-tuning setup.
        """
        super().__init__()
        logger.info(f"HubertModel Config: {cfg}")

        # SECURITY NOTE: ``eval`` executes arbitrary Python, so
        # ``cfg.conv_feature_layers`` must come from a trusted configuration.
        # ``ast.literal_eval`` is not a drop-in replacement because the
        # default value uses ``+`` and ``*`` on lists.
        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        # Output dimension of the last conv layer: (dim, kernel_size, stride).
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias, )
        # Product of all conv strides = total down-sampling factor.
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # Number of label frames per extracted feature frame.
        self.feat2tar_ratio = (
            cfg.label_rate * feature_ds_rate / task_cfg.sample_rate)

        # Only project when the extractor and encoder dimensions differ.
        self.post_extract_proj = (
            Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim else None)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        final_dim = (cfg.final_dim
                     if cfg.final_dim > 0 else cfg.encoder_embed_dim)

        # Learned vector written into masked time steps.
        self.mask_emb = paddle.create_parameter(
            shape=[cfg.encoder_embed_dim],
            dtype='float32',
            default_initializer=paddle.nn.initializer.Uniform(low=0), )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                Linear(final_dim, final_dim * 2), GLU())

        self.untie_final_proj = cfg.untie_final_proj
        if self.untie_final_proj:
            # One projection slice per label set, split again in forward().
            self.final_proj = Linear(cfg.encoder_embed_dim,
                                     final_dim * len(dictionaries))
        else:
            self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)

        # modules below are not needed during fine-tuning
        if any(d is None for d in dictionaries):
            logger.info(
                "cannot find dictionary. assume will be used for fine-tuning")
        else:
            self.num_classes = [len(d) for d in dictionaries]
            # Embeddings of all label sets stacked along axis 0.
            self.label_embs_concat = paddle.create_parameter(
                shape=[sum(self.num_classes), final_dim],
                dtype='float32',
                default_initializer=paddle.nn.initializer.Uniform(low=0), )

    @classmethod
    def build_model(cls, cfg: HubertConfig, task):
        """Build a new model instance from ``cfg`` and a task object.

        ``task`` must expose ``cfg`` and ``dictionaries`` attributes.
        """
        model = HubertModel(cfg, task.cfg, task.dictionaries)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        """Mask time steps and/or channels of ``x`` in place.

        Args:
            x: tensor unpacked as ``(B, T, C)``.
            padding_mask: forwarded to ``compute_mask_indices``.
            target_list: unused here; kept for interface compatibility.

        Returns:
            ``(x, mask_indices)`` where ``mask_indices`` is a ``(B, T)`` bool
            tensor of masked time steps, or ``None`` when ``mask_prob <= 0``.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space, )

            # The mask must be boolean: an integer tensor used as an index
            # selects rows by position instead of masking, and forward()
            # applies ``~`` to this tensor expecting a logical NOT.
            mask_indices = paddle.to_tensor(
                mask_indices, dtype='bool', place=x.place)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space, )
            # Broadcast the per-(batch, channel) mask over every time step.
            mask_channel_indices = (paddle.to_tensor(
                mask_channel_indices, dtype='bool', place=x.place).unsqueeze(1)
                                    .expand([-1, T, -1]))
            x[mask_channel_indices] = 0

        return x, mask_indices

    def compute_nce(self, x, pos, negs):
        """Return temperature-scaled cosine-similarity logits.

        Args:
            x: projected features, ``(S, D)``.
            pos: positive targets, ``(S, D)``.
            negs: candidate targets, ``(Neg, S, D)``.

        Returns:
            Logits transposed to ``(num_x, num_cls + 1)``; the positive is
            at column 0.
        """
        # Candidates identical to the positive must not act as negatives.
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = paddle.concat([pos, negs], axis=0)

        logits = paddle.nn.functional.cosine_similarity(
            x.astype('float32'), targets.astype('float32'), axis=-1)
        logits /= self.logit_temp
        if paddle.any(neg_is_pos):
            # NOTE(review): this chained assignment only changes ``logits``
            # if ``logits[1:]`` is a view rather than a copy -- confirm for
            # the Paddle version in use.
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose([1, 0])  # (num_x, num_cls+1)
        return logits

    def forward_features(self, source: paddle.Tensor) -> paddle.Tensor:
        """Run the convolutional feature extractor on ``source``.

        Gradients are scaled by ``feature_grad_mult`` when it is positive and
        not 1.0; when it is not positive the extractor runs under
        ``paddle.no_grad()``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with paddle.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
            self,
            features: paddle.Tensor,
            target_list: List[paddle.Tensor],
    ) -> Tuple[paddle.Tensor, List[paddle.Tensor]]:
        """Trim features so labels exist for every frame, then align labels.

        Args:
            features: extractor output; axis 2 is treated as time.
            target_list: label tensors; axis 1 is treated as time.

        Returns:
            ``(features, target_list)`` with features truncated along axis 2
            and each target re-indexed to one label per feature frame.
        """
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.shape[2]
        targ_tsz = min([t.shape[1] for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[:, :, :feat_tsz]
        # Label index for each feature frame (truncated to int below).
        target_inds = paddle.arange(feat_tsz).astype(
            'float32') * self.feat2tar_ratio
        target_list = [t[:, target_inds.astype('int64')] for t in target_list]
        return features, target_list

    def forward_padding_mask(
            self,
            features: paddle.Tensor,
            padding_mask: paddle.Tensor, ) -> paddle.Tensor:
        """Down-sample a padding mask to one value per feature frame.

        Trailing entries that do not fill a whole frame are dropped; a frame
        counts as padding only if all of its entries are set.
        """
        extra = padding_mask.shape[1] % features.shape[1]
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = paddle.reshape(
            padding_mask, [padding_mask.shape[0], features.shape[1], -1])
        padding_mask = paddle.all(padding_mask, axis=-1)
        return padding_mask

    def forward(
            self,
            source: paddle.Tensor,
            target_list: Optional[List[paddle.Tensor]]=None,
            padding_mask: Optional[paddle.Tensor]=None,
            mask: bool=True,
            features_only: bool=False,
            output_layer: Optional[int]=None, ) -> Dict[str, paddle.Tensor]:
        """Run the model.

        Args:
            source: input passed to the feature extractor.
            target_list: optional label tensors, one per label set.
            padding_mask: optional padding mask aligned with ``source``.
            mask: whether to apply ``apply_mask`` to the features.
            features_only: return encoder outputs without computing logits.
            output_layer: 1-based index of the encoder layer to return;
                ``None`` uses the encoder's default.

        Returns:
            With ``features_only``: a dict with ``x``, ``padding_mask`` and
            ``features``.  Otherwise a dict with ``logit_m_list``,
            ``logit_u_list``, ``padding_mask`` and ``features_pen``.  The
            logit path dereferences ``target_list``, ``padding_mask`` and the
            mask indices, so it needs all of them (i.e. ``mask=True``).
        """
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.pow(2).mean()

        features = features.transpose([0, 2, 1])
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        # Result is not used below; the call is kept so the sequence of
        # dropout invocations stays the same as before.
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask,
                                              target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1, )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            # (closure over ``self``: it is called with three arguments, so
            # it must not declare its own ``self`` parameter)
            y = paddle.index_select(
                label_embs, index=target.astype('int64'), axis=0)
            negs = paddle.expand(
                label_embs.unsqueeze(1),
                [label_embs.shape[0], proj_x.shape[0], label_embs.shape[-1]])
            if self.target_glu is not None:
                y = self.target_glu(y)
                negs = self.target_glu(negs)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        def compute_logits(frame_indices):
            # Project the selected frames and score them per label set.
            proj_x = self.final_proj(x[frame_indices])
            if self.untie_final_proj:
                # Paddle's ``chunk`` takes ``axis`` (``dim`` is PyTorch's).
                proj_x_list = proj_x.chunk(len(target_list), axis=-1)
            else:
                proj_x_list = [proj_x for _ in range(len(target_list))]
            return [
                compute_pred(proj, tgt[frame_indices], label_embs_list[i])
                for i, (proj, tgt) in enumerate(zip(proj_x_list, target_list))
            ]

        if not self.skip_masked:
            masked_indices = paddle.logical_and(~padding_mask, mask_indices)
            logit_m_list = compute_logits(masked_indices)
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            nomask_indices = paddle.logical_and(~padding_mask, ~mask_indices)
            logit_u_list = compute_logits(nomask_indices)
        else:
            logit_u_list = [None for _ in target_list]

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }
        return result

    def extract_features(
            self,
            source: paddle.Tensor,
            padding_mask: Optional[paddle.Tensor]=None,
            mask: bool=False,
            ret_conv: bool=False,
            output_layer: Optional[int]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Return features and the frame-level padding mask.

        Calls ``forward`` with ``features_only=True``.  With ``ret_conv`` the
        pre-encoder ``features`` entry is returned, otherwise the encoder
        output ``x``.
        """
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer, )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        """Return the non-``None`` logits from ``net_output`` as float32.

        ``is_masked`` selects ``logit_m_list`` (True) or ``logit_u_list``.
        """
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [
            paddle.cast(x, 'float32') for x in logits_list if x is not None
        ]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        """Return all-zero int64 targets, one tensor per logits tensor.

        Zero is the target because ``compute_nce`` places the positive at
        column 0.
        """
        logits_list = self.get_logits(net_output, is_masked)
        # NOTE(review): each target has the full shape of its logits tensor,
        # not one entry per row -- confirm that the loss expects this.
        targets_list = [
            paddle.zeros_like(x, dtype='int64') for x in logits_list
        ]
        return targets_list

    def get_extra_losses(self, net_output):
        """Return ``(extra_losses, names)`` found in ``net_output``.

        Only ``features_pen`` is collected.
        """
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        """Drop ``target_glu`` and ``final_proj`` by setting them to None."""
        self.target_glu = None
        self.final_proj = None
|