pull/3088/head
th.zhang 2 years ago
parent 0a8d95c6b6
commit c85d61cccc

@ -1,7 +1,7 @@
############################################
# Network Architecture #
############################################
freeze_hubert: True
freeze_hubert: False
normalize_wav: True
output_norm: True
init_type: kaiming_uniform # !Warning: needed for convergence
@ -14,11 +14,20 @@ ctc:
enc_n_units: 1024
blank_id: 0
dropout_rate: 0.0
hubert_params_path: "exp/hubert/pd_hubert.pdparams"
hubert_params_path: "exp/hubert/pd_hubert_no_fintune.pdparams"
task_cfg:
label_rate: 50.0
sample_rate: 16000
normalize: True
enable_padding: False
max_keep_size: None
max_sample_size: 250000
min_sample_size: 32000
single_target: False
random_crop: True
pad_audio: False
model_cfg:
dropout_input: 0.0
@ -37,7 +46,6 @@ model_cfg:
mask_channel_selection: static
mask_channel_other: 0.0
no_mask_channel_overlap: False
freeze_finetune_updates: 10000
feature_grad_mult: 0.0
layerdrop: 0.1
normalize: True
@ -69,7 +77,7 @@ model_cfg:
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
train_manifest: data/manifest.train-clean-100
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
@ -81,7 +89,7 @@ unit_type: char
mean_std_filepath: ""
preprocess_config: conf/preprocess.yaml
sortagrad: -1 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for that number of epochs
batch_size: 8 # Different batch_size may cause large differences in results
batch_size: 2 # Different batch_size may cause large differences in results
maxlen_in: 51200000000 # if input length > maxlen_in, batch size is automatically reduced
maxlen_out: 1500000 # if output length > maxlen_out, batch size is automatically reduced
minibatches: 0 # for debug
@ -102,12 +110,13 @@ return_lens_rate: True
############################################
audio_augment: # for raw audio
sample_rate: 16000
speeds: [95, 100, 105]
###########################################
# Training #
###########################################
n_epoch: 1
accum_grad: 1
n_epoch: 3
accum_grad: 8
global_grad_clip: 5.0
model_optim: adadelta
model_optim_conf:
@ -120,7 +129,7 @@ model_scheduler_conf:
lr_decay: 1.0
hubert_optim: adadelta
hubert_optim_conf:
lr: 0.9
lr: 1.0
epsilon: 1.0e-6
rho: 0.95
hubert_scheduler: constantlr
@ -130,4 +139,4 @@ hubert_scheduler_conf:
log_interval: 1
checkpoint:
kbest_n: 50
latest_n: 5
latest_n: 5
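A quick way to read the batch_size change together with the new accum_grad and the two GPUs used in run.sh below is the effective number of utterances per optimizer step (a rough sketch only; the maxlen_in/maxlen_out limits above can still shrink individual batches):

    # rough effective-batch estimate implied by this config (illustrative only)
    batch_size = 2   # per-GPU batch size above
    accum_grad = 8   # gradient accumulation steps above
    num_gpus = 2     # gpus=1,2 in run.sh below
    effective_batch = batch_size * accum_grad * num_gpus
    print(effective_batch)  # 32 utterances contribute to each optimizer update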

@ -5,9 +5,9 @@ MODEL=hubert
. ./path.sh ${MODEL} || exit 1;
. ./cmd.sh || exit 1;
gpus=2
gpus=1,2
stage=1
stop_stage=3
stop_stage=1
conf_path=conf/${MODEL}ASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml
@ -20,7 +20,7 @@ audio_file=data/demo_002_en.wav
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
ckpt=test6
ckpt=train_clean_test_new_3
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
@ -30,7 +30,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model; all `ckpt` outputs go under the `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@ -40,7 +40,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# greedy search decoder
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then

@ -185,7 +185,7 @@ class HubertASRTrainer(Trainer):
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
wav = wav[:, :, 0]
logger.info('training utt ids: {}'.format(utt))
if hasattr(train_conf, 'audio_augment'):
wav = self.speech_augmentation(wav, wavs_lens_rate)
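For orientation, a minimal sketch of the shapes train_batch handles here, assuming the loader pads raw audio to [batch, time, channel] while the model consumes [batch, time] (the values are made up):

    import paddle

    wav = paddle.randn([2, 16000, 1])                    # two padded utterances, mono
    wavs_lens = paddle.to_tensor([16000., 12000.])       # true lengths before padding
    wavs_lens_rate = wavs_lens / wav.shape[1]            # lengths as a fraction of the padded length
    wav = wav[:, :, 0]                                   # drop the channel axis -> [2, 16000]
    print(wavs_lens_rate.numpy(), wav.shape)             # [1.0, 0.75] [2, 16000]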

@ -12,15 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Tuple
from typing import Any, Optional
from typing import Dict, List, Tuple, Any
from dataclasses import dataclass, field, is_dataclass
from copy import deepcopy
from omegaconf import II, MISSING, open_dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@ -323,7 +318,7 @@ class HubertASR(nn.Layer):
class HubertBase(nn.Layer):
"""Wav2vec2 model"""
"""Hubert model"""
def __init__(self, config: dict):
super().__init__()

@ -6,14 +6,14 @@
# The S3PRL Team did not contribute to this file
# The file was copied from fairseq to remove the dependency on the entire fairseq package
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import paddle
import paddle.nn as nn
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import (
EXTRACTOR_MODE_CHOICES,
LAYER_TYPE_CHOICES,
@ -27,9 +27,9 @@ from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import (
get_available_activation_fns,
GLU,
)
from paddlespeech.s2t.utils.log import Log
logger = logging.getLogger(__name__)
logger = Log(__name__).getlog()
@dataclass
class HubertPretrainingConfig:
@ -302,7 +302,7 @@ class HubertModel(nn.Layer):
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
@ -334,7 +334,7 @@ class HubertModel(nn.Layer):
self.mask_emb = paddle.create_parameter(
shape=[cfg.encoder_embed_dim],
dtype='float32',
default_initializer=paddle.nn.initializer.Uniform(),
default_initializer=paddle.nn.initializer.Uniform(low=0),
)
self.encoder = TransformerEncoder(cfg)
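The low=0 above is not cosmetic: fairseq initializes mask_emb with torch's Tensor.uniform_(), which samples from [0, 1), while Paddle's Uniform initializer defaults to [-1.0, 1.0]. A small check (the dimension is arbitrary):

    import paddle

    mask_emb = paddle.create_parameter(
        shape=[768],
        dtype='float32',
        default_initializer=paddle.nn.initializer.Uniform(low=0))
    print(float(mask_emb.min()) >= 0.0)  # True, matching fairseq's U(0, 1) init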
@ -343,16 +343,16 @@ class HubertModel(nn.Layer):
self.target_glu = None
if cfg.target_glu:
self.target_glu = nn.Sequential(
nn.Linear(final_dim, final_dim * 2), GLU()
Linear(final_dim, final_dim * 2), GLU()
)
self.untie_final_proj = cfg.untie_final_proj
if self.untie_final_proj:
self.final_proj = nn.Linear(
self.final_proj = Linear(
cfg.encoder_embed_dim, final_dim * len(dictionaries)
)
else:
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)
# modules below are not needed during fine-tuning
if any([d is None for d in dictionaries]):
@ -362,13 +362,8 @@ class HubertModel(nn.Layer):
self.label_embs_concat = paddle.create_parameter(
shape=[sum(self.num_classes), final_dim],
dtype='float32',
default_initializer=paddle.nn.initializer.Uniform(),
default_initializer=paddle.nn.initializer.Uniform(low=0),
)
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: HubertConfig, task):
@ -417,7 +412,7 @@ class HubertModel(nn.Layer):
return x, mask_indices
def compute_nce(x, pos, negs):
def compute_nce(self, x, pos, negs):
neg_is_pos = (pos == negs).all(-1)
pos = pos.unsqueeze(0)
targets = paddle.concat([pos, negs], axis=0)
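The added self fixes compute_nce being defined as an instance method. For orientation, a minimal standalone sketch of the NCE logits it computes, assuming the fairseq-style formulation (cosine similarity of each frame against one positive and N negative targets, scaled by a temperature; names and shapes are illustrative):

    import paddle
    import paddle.nn.functional as F

    def nce_logits_sketch(x, pos, negs, logit_temp=0.1):
        # x: [T, D] features, pos: [T, D] positive targets, negs: [N, T, D] negatives
        neg_is_pos = (pos == negs).all(-1)                            # [N, T] negatives equal to the positive
        targets = paddle.concat([pos.unsqueeze(0), negs], axis=0)     # [N+1, T, D]
        logits = F.cosine_similarity(
            x.astype('float32'), targets.astype('float32'), axis=-1)  # [N+1, T]
        logits = logits / logit_temp
        if neg_is_pos.any():
            # a negative identical to the positive must never be selected
            neg_part = paddle.where(neg_is_pos,
                                    paddle.full_like(logits[1:], float('-inf')),
                                    logits[1:])
            logits = paddle.concat([logits[:1], neg_part], axis=0)
        return logits.transpose([1, 0])                               # [T, N+1]; index 0 is the positive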

@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
# S3PRL did not contribute to this file
# The file was copied from fairseq to remove the dependency on the entire fairseq package
import logging
import math
import uuid
from dataclasses import dataclass
@ -16,15 +15,19 @@ from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Conv1D
from paddlespeech.s2t.modules.align import Conv2D
from paddlespeech.s2t.modules.align import Embedding
from paddlespeech.s2t.utils.log import Log
logger = logging.getLogger(__name__)
logger = Log(__name__).getlog()
class GLU(nn.Layer):
r"""Applies the gated linear unit function
@ -153,15 +156,19 @@ def quant_noise(module, p, block_size):
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2D))
assert isinstance(module, (Linear, Embedding, Conv2D))
# test whether module.weight has the right sizes wrt block_size
is_conv = len(module.weight.shape) == 4
# 2D matrix
if not is_conv:
if isinstance(module, Linear):
features_weight = module.weight.shape[0]
else:
features_weight = module.weight.shape[1]
assert (
module.weight.shape[1] %
features_weight %
block_size == 0), "Input features must be a multiple of block sizes"
# 4D matrix
@ -181,14 +188,20 @@ def quant_noise(module, p, block_size):
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.shape[1]
out_features = weight.shape[0]
if isinstance(module, Linear):
in_features = weight.shape[0]
out_features = weight.shape[1]
else:
in_features = weight.shape[1]
out_features = weight.shape[0]
# split weight matrix into blocks and randomly drop selected blocks
mask = paddle.zeros(
[in_features // block_size * out_features],
dtype=paddle.bool)
mask.bernoulli_(p)
# replacement for torch's in-place bernoulli_; p is hard-coded to 0.5 here
mask = paddle.ones_like(mask) * 0.5
mask = paddle.bernoulli(mask)
mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
[-1, in_features])
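The new branches exist because Paddle stores Linear weights as [in_features, out_features], the transpose of torch's [out_features, in_features], while Conv2D keeps the torch-like [out_channels, in_channels, kH, kW] layout. A quick check with arbitrary sizes:

    import paddle.nn as nn

    lin = nn.Linear(16, 32)
    conv = nn.Conv2D(3, 8, kernel_size=3)
    print(lin.weight.shape)   # [16, 32] -> in_features is shape[0], out_features is shape[1]
    print(conv.weight.shape)  # [8, 3, 3, 3]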
@ -203,12 +216,18 @@ def quant_noise(module, p, block_size):
mask = paddle.zeros(
[in_channels // block_size * out_channels],
dtype=paddle.bool)
mask.bernoulli_(p)
# replacement for torch's in-place bernoulli_; p is hard-coded to 0.5 here
mask = paddle.ones_like(mask) * 0.5
mask = paddle.bernoulli(mask)
mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
[-1, in_channels])
else:
mask = paddle.zeros(weight.shape)
mask.bernoulli_(p)
# replacement for torch's in-place bernoulli_; p is hard-coded to 0.5 here
mask = paddle.ones_like(mask) * 0.5
mask = paddle.bernoulli(mask)
mask = mask.unsqueeze(1).tile([1, in_channels, 1, 1])
# scale weights and apply mask
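A hedged sketch of a faithful drop-in for torch's in-place mask.bernoulli_(p): paddle.bernoulli takes a tensor of probabilities, so the mask can be built from a tensor filled with the actual quant-noise probability p rather than the 0.5 hard-coded above:

    import paddle

    def bernoulli_mask(shape, p):
        # True with probability p, matching torch's zeros(shape).bernoulli_(p)
        probs = paddle.full(shape, p, dtype='float32')
        return paddle.bernoulli(probs).astype('bool')

    mask = bernoulli_mask([4, 8], p=0.25)  # True where a block would be dropped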
@ -281,29 +300,53 @@ class MultiheadAttention(nn.Layer):
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and "
"value to be of the same size")
weight_attr = paddle.ParamAttr(initializer=nn.initializer.XavierUniform)
bias_attr = nn.initializer.Constant(0)
# self.k_proj = quant_noise(
# nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
# )
# self.v_proj = quant_noise(
# nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
# )
# self.q_proj = quant_noise(
# nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
# )
# self.out_proj = quant_noise(
# nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else bias_attr), q_noise, qn_block_size
# )
self.k_proj = nn.Linear(self.kdim, embed_dim)
self.v_proj = nn.Linear(self.vdim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
# TODO: scaled initialization
# Empirically observed the convergence to be much better with
# the scaled initialization
weight_attr = nn.initializer.XavierUniform()
kv_proj_bias_attr = nn.initializer.XavierUniform()
out_proj_bias_attr = nn.initializer.Constant(0)
self.k_proj = quant_noise(
nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else out_proj_bias_attr), q_noise, qn_block_size
)
# nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
# nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
# nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
# else:
# self.k_proj.weight = paddle.ParamAttr()
# nn.initializer.XavierUniform(self.k_proj.weight)
# nn.initializer.XavierUniform(self.v_proj.weight)
# nn.initializer.XavierUniform(self.q_proj.weight)
# nn.initializer.XavierUniform(self.out_proj.weight)
# if self.out_proj.bias is not None:
# nn.initializer.Constant(self.out_proj.bias)
# if self.bias_k is not None:
# nn.initializer.XavierNormal(self.bias_k)
# if self.bias_v is not None:
# nn.initializer.XavierNormal(self.bias_v)
# self.k_proj = Linear(self.kdim, embed_dim)
# self.v_proj = Linear(self.vdim, embed_dim)
# self.q_proj = Linear(embed_dim, embed_dim)
# self.out_proj = Linear(embed_dim, embed_dim)
if add_bias_kv:
self.bias_k = paddle.create_parameter(
@ -327,26 +370,26 @@ class MultiheadAttention(nn.Layer):
def prepare_for_onnx_export_(self):
self.onnx_trace = True
# def reset_parameters(self):
# if self.qkv_same_dim:
# # Empirically observed the convergence to be much better with
# # the scaled initialization
# nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
# nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
# nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
# else:
# self.k_proj.weight = paddle.ParamAttr()
# nn.initializer.XavierUniform(self.k_proj.weight)
# nn.initializer.XavierUniform(self.v_proj.weight)
# nn.initializer.XavierUniform(self.q_proj.weight)
# nn.initializer.XavierUniform(self.out_proj.weight)
# if self.out_proj.bias is not None:
# nn.initializer.Constant(self.out_proj.bias)
# if self.bias_k is not None:
# nn.initializer.XavierNormal(self.bias_k)
# if self.bias_v is not None:
# nn.initializer.XavierNormal(self.bias_v)
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
self.k_proj.weight = paddle.ParamAttr()
nn.initializer.XavierUniform(self.k_proj.weight)
nn.initializer.XavierUniform(self.v_proj.weight)
nn.initializer.XavierUniform(self.q_proj.weight)
nn.initializer.XavierUniform(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.initializer.Constant(self.out_proj.bias)
if self.bias_k is not None:
nn.initializer.XavierNormal(self.bias_k)
if self.bias_v is not None:
nn.initializer.XavierNormal(self.bias_v)
def _get_reserve_head_index(self, num_heads_to_keep: int):
k_proj_heads_norm = []
@ -357,15 +400,15 @@ class MultiheadAttention(nn.Layer):
start_idx = i * self.head_dim
end_idx = (i + 1) * self.head_dim
k_proj_heads_norm.append(
paddle.sum(paddle.abs(self.k_proj.weight[start_idx:end_idx, ]))
paddle.sum(paddle.abs(self.k_proj.weight[:, start_idx:end_idx]))
.tolist() + paddle.sum(
paddle.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
q_proj_heads_norm.append(
paddle.sum(paddle.abs(self.q_proj.weight[start_idx:end_idx, ]))
paddle.sum(paddle.abs(self.q_proj.weight[:, start_idx:end_idx]))
.tolist() + paddle.sum(
paddle.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
v_proj_heads_norm.append(
paddle.sum(paddle.abs(self.v_proj.weight[start_idx:end_idx, ]))
paddle.sum(paddle.abs(self.v_proj.weight[:, start_idx:end_idx]))
.tolist() + paddle.sum(
paddle.abs(self.v_proj.bias[start_idx:end_idx])).tolist())
@ -395,24 +438,24 @@ class MultiheadAttention(nn.Layer):
for ele in reserve_head_index:
start_idx, end_idx = ele
new_q_weight.append(self.q_proj.weight[start_idx:end_idx, ])
new_q_weight.append(self.q_proj.weight[:, start_idx:end_idx])
new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
new_k_weight.append(self.k_proj.weight[start_idx:end_idx, ])
new_k_weight.append(self.k_proj.weight[:, start_idx:end_idx])
new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
new_v_weight.append(self.v_proj.weight[start_idx:end_idx, ])
new_v_weight.append(self.v_proj.weight[:, start_idx:end_idx])
new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
new_out_proj_weight.append(
self.out_proj.weight[:, start_idx:end_idx])
self.out_proj.weight[start_idx:end_idx, ])
new_q_weight = paddle.concat(new_q_weight).detach()
new_k_weight = paddle.concat(new_k_weight).detach()
new_v_weight = paddle.concat(new_v_weight).detach()
new_q_weight = paddle.concat(new_q_weight, axis=-1).detach()
new_k_weight = paddle.concat(new_k_weight, axis=-1).detach()
new_v_weight = paddle.concat(new_v_weight, axis=-1).detach()
new_out_proj_weight = paddle.concat(
new_out_proj_weight, axis=-1).detach()
new_out_proj_weight).detach()
new_q_weight.stop_gradient = False
new_k_weight.stop_gradient = False
new_v_weight.stop_gradient = False
@ -566,11 +609,11 @@ class MultiheadAttention(nn.Layer):
assert (embed_dim == self.embed_dim
), f"query dim {embed_dim} != {self.embed_dim}"
assert list(query.shape) == [tgt_len, bsz, embed_dim]
# if key is not None:
# src_len, key_bsz, _ = key.size()
# if not torch.jit.is_scripting():
# assert value is not None
# assert src_len, key_bsz == value.shape[:2]
if key is not None:
src_len, key_bsz, _ = key.shape
# if not torch.jit.is_scripting():
# assert value is not None
# assert src_len, key_bsz == value.shape[:2]
# if (
# not self.onnx_trace
@ -848,7 +891,7 @@ class MultiheadAttention(nn.Layer):
new_key_padding_mask = paddle.concat([
paddle.cast(prev_key_padding_mask, 'float32'),
paddle.cast(key_padding_mask, 'float32')
], axis == 1)
], axis = 1)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
@ -859,7 +902,7 @@ class MultiheadAttention(nn.Layer):
new_key_padding_mask = paddle.concat([
paddle.cast(prev_key_padding_mask, 'float32'),
paddle.cast(filler, 'float32')
], axis == 1)
], axis = 1)
else:
new_key_padding_mask = prev_key_padding_mask
elif key_padding_mask is not None:
@ -869,7 +912,7 @@ class MultiheadAttention(nn.Layer):
new_key_padding_mask = paddle.concat([
paddle.cast(filler, 'float32'),
paddle.cast(key_padding_mask, 'float32')
], axis == 1)
], axis = 1)
else:
new_key_padding_mask = paddle.cast(key_padding_mask, 'float32')
else:
@ -1022,7 +1065,7 @@ class GumbelVectorQuantizer(nn.Layer):
def block(input_dim, output_dim):
return nn.Sequential(
nn.Linear(input_dim, output_dim), activation)
Linear(input_dim, output_dim), activation)
inner_dim = self.input_dim * weight_proj_factor
self.weight_proj = nn.Sequential(
@ -1030,11 +1073,9 @@ class GumbelVectorQuantizer(nn.Layer):
block(self.input_dim if i == 0 else inner_dim, inner_dim)
for i in range(weight_proj_depth - 1)
],
nn.Linear(inner_dim, groups * num_vars), )
Linear(inner_dim, groups * num_vars), )
else:
self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
nn.initializer.Normal(mean=0, std=1)(self.weight_proj.weight)
nn.initializer.Zero()(self.weight_proj.bias)
self.weight_proj = Linear(self.input_dim, groups * num_vars, weight_attr=nn.initializer.Normal(mean=0, std=1), bias_attr=nn.initializer.Zero())
if isinstance(temp, str):
import ast
@ -1125,7 +1166,7 @@ class GumbelVectorQuantizer(nn.Layer):
if self.training:
x = F.gumbel_softmax(
x.astype('float32'), tau=self.curr_temp,
x.astype('float32'), temperature=self.curr_temp,
hard=True).astype(x.dtype)
else:
x = hard_x
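The rename reflects an API difference rather than a behavior change: torch's F.gumbel_softmax takes tau, while paddle.nn.functional.gumbel_softmax takes temperature. A small smoke test with arbitrary shapes:

    import paddle
    import paddle.nn.functional as F

    logits = paddle.randn([4, 320])
    y = F.gumbel_softmax(logits, temperature=2.0, hard=True)
    print(y.sum(axis=-1))  # hard=True yields one-hot rows, so every sum is 1.0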
@ -1192,22 +1233,11 @@ class TransposeLast(nn.Layer):
trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
return x.transpose(trans_dim)
def LayerNorm(normalized_shape, eps=1e-5):
return nn.LayerNorm(
normalized_shape,
epsilon=eps,
weight_attr=paddle.ParamAttr(),
bias_attr=paddle.ParamAttr())
class Fp32LayerNorm(nn.LayerNorm):
class Fp32LayerNorm(LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
# import pdb
# pdb.set_trace()
output = F.layer_norm(
input.astype('float32'),
self._normalized_shape,
@ -1222,8 +1252,6 @@ class Fp32GroupNorm(nn.GroupNorm):
super().__init__(*args, **kwargs)
def forward(self, input):
# import pdb
# pdb.set_trace()
output = F.group_norm(
input.astype('float32'),
self._num_groups,
@ -1724,7 +1752,7 @@ class Wav2Vec2Model(nn.Layer):
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias, )
self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
self.post_extract_proj = (Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim and
not cfg.quantize_input else None)
@ -1774,9 +1802,9 @@ class Wav2Vec2Model(nn.Layer):
time_first=True,
weight_proj_depth=cfg.quantizer_depth,
weight_proj_factor=cfg.quantizer_factor, )
self.project_q = nn.Linear(vq_dim, final_dim)
self.project_q = Linear(vq_dim, final_dim)
else:
self.project_q = nn.Linear(self.embed, final_dim)
self.project_q = Linear(self.embed, final_dim)
if cfg.quantize_input:
if cfg.same_quantizer and self.quantizer is not None:
@ -1794,7 +1822,7 @@ class Wav2Vec2Model(nn.Layer):
time_first=True,
weight_proj_depth=cfg.quantizer_depth,
weight_proj_factor=cfg.quantizer_factor, )
self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)
self.project_inp = Linear(vq_dim, cfg.encoder_embed_dim)
self.mask_emb = self.create_parameter(
shape=[cfg.encoder_embed_dim],
@ -1809,9 +1837,9 @@ class Wav2Vec2Model(nn.Layer):
self.target_glu = None
if cfg.target_glu:
self.target_glu = nn.Sequential(
nn.Linear(final_dim, final_dim * 2), GLU())
Linear(final_dim, final_dim * 2), GLU())
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
@ -2194,7 +2222,7 @@ class ConvFeatureExtractionModel(nn.Layer):
is_group_norm=False,
conv_bias=False, ):
def make_conv():
conv = nn.Conv1D(
conv = Conv1D(
n_in,
n_out,
k,
@ -2256,17 +2284,16 @@ class ConvFeatureExtractionModel(nn.Layer):
def make_conv_pos(e, k, g):
pos_conv = nn.Conv1D(
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
pos_conv = Conv1D(
e,
e,
kernel_size=k,
padding=k // 2,
groups=g, )
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
nn.initializer.Normal(mean=0, std=std)(pos_conv.weight)
nn.initializer.Constant(0)(pos_conv.bias)
groups=g,
weight_attr=nn.initializer.Normal(mean=0, std=std),
bias_attr=nn.initializer.Constant(0))
pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
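The restructuring moves the std computation ahead of construction so the initializers can be passed through weight_attr/bias_attr instead of being applied to the weights afterwards. A minimal sketch of the same pattern with plain paddle.nn.Conv1D (dimensions are made up; the real code uses the aligned Conv1D wrapper):

    import math
    import paddle.nn as nn

    e, k, g = 768, 128, 16                      # embed dim, kernel size, conv groups
    std = math.sqrt((4 * (1.0 - 0)) / (k * e))  # dropout = 0, as in make_conv_pos
    pos_conv = nn.Conv1D(
        e, e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
        weight_attr=nn.initializer.Normal(mean=0, std=std),
        bias_attr=nn.initializer.Constant(0))
    pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)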
@ -2301,7 +2328,7 @@ class TransformerEncoder(nn.Layer):
def make_conv_block(e, k, g, l):
return nn.Sequential(*[
nn.Sequential(
nn.Conv1D(
Conv1D(
e,
e,
kernel_size=k,
@ -2454,8 +2481,8 @@ class TransformerSentenceEncoderLayer(nn.Layer):
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
self.fc1 = Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)

@ -58,8 +58,6 @@ class Wav2vec2ASR(nn.Layer):
reduction='mean')
def forward(self, wav, wavs_lens_rate, target, target_lens):
# import pdb
# pdb.set_trace()
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape)
