pull/3088/head
th.zhang 2 years ago
parent 0a8d95c6b6
commit c85d61cccc

@@ -1,7 +1,7 @@
 ############################################
 # Network Architecture                     #
 ############################################
-freeze_hubert: True
+freeze_hubert: False
 normalize_wav: True
 output_norm: True
 init_type: kaiming_uniform # !Warning: need to convergence
@@ -14,11 +14,20 @@ ctc:
   enc_n_units: 1024
   blank_id: 0
   dropout_rate: 0.0
-hubert_params_path: "exp/hubert/pd_hubert.pdparams"
+hubert_params_path: "exp/hubert/pd_hubert_no_fintune.pdparams"
 task_cfg:
+  label_rate: 50.0
   sample_rate: 16000
+  normalize: True
+  enable_padding: False
+  max_keep_size: None
+  max_sample_size: 250000
+  min_sample_size: 32000
+  single_target: False
+  random_crop: True
+  pad_audio: False
 model_cfg:
   dropout_input: 0.0
@@ -37,7 +46,6 @@ model_cfg:
   mask_channel_selection: static
   mask_channel_other: 0.0
   no_mask_channel_overlap: False
-  freeze_finetune_updates: 10000
   feature_grad_mult: 0.0
   layerdrop: 0.1
   normalize: True
@@ -69,7 +77,7 @@ model_cfg:
 ###########################################
 # Data                                    #
 ###########################################
-train_manifest: data/manifest.train
+train_manifest: data/manifest.train-clean-100
 dev_manifest: data/manifest.dev
 test_manifest: data/manifest.test-clean
@@ -81,7 +89,7 @@ unit_type: char
 mean_std_filepath: ""
 preprocess_config: conf/preprocess.yaml
 sortagrad: -1 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
-batch_size: 8 # Different batch_size may cause large differences in results
+batch_size: 2 # Different batch_size may cause large differences in results
 maxlen_in: 51200000000 # if input length > maxlen-in, batchsize is automatically reduced
 maxlen_out: 1500000 # if output length > maxlen-out, batchsize is automatically reduced
 minibatches: 0 # for debug
@@ -102,12 +110,13 @@ return_lens_rate: True
 ############################################
 audio_augment: # for raw audio
   sample_rate: 16000
+  speeds: [95, 100, 105]

 ###########################################
 # Training                                #
 ###########################################
-n_epoch: 1
-accum_grad: 1
+n_epoch: 3
+accum_grad: 8
 global_grad_clip: 5.0
 model_optim: adadelta
 model_optim_conf:

@@ -120,7 +129,7 @@ model_scheduler_conf:
   lr_decay: 1.0
 hubert_optim: adadelta
 hubert_optim_conf:
-  lr: 0.9
+  lr: 1.0
   epsilon: 1.0e-6
   rho: 0.95
 hubert_scheduler: constantlr
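Note on the hyperparameter changes above: batch_size drops from 8 to 2 while accum_grad rises from 1 to 8, and run.sh (below) moves from one GPU (gpus=2) to two (gpus=1,2). Assuming the launcher data-parallelizes over the listed devices, the effective batch per optimizer step grows from 8 to 32; a minimal sketch of that arithmetic (the helper is hypothetical, not a PaddleSpeech API):

```python
# Effective samples per optimizer update under gradient accumulation
# and data parallelism; `gpus` is the device list string from run.sh.
def effective_batch(batch_size: int, accum_grad: int, gpus: str) -> int:
    n_gpu = len(gpus.split(","))  # "1,2" -> 2 data-parallel workers
    return batch_size * accum_grad * n_gpu

print(effective_batch(8, 1, "2"))    # old setting: 8
print(effective_batch(2, 8, "1,2"))  # new setting: 32
```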

@@ -5,9 +5,9 @@ MODEL=hubert
 . ./path.sh ${MODEL} || exit 1;
 . ./cmd.sh || exit 1;
-gpus=2
+gpus=1,2
 stage=1
-stop_stage=3
+stop_stage=1
 conf_path=conf/${MODEL}ASR.yaml
 ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml

@@ -20,7 +20,7 @@ audio_file=data/demo_002_en.wav
 avg_ckpt=avg_${avg_num}
 ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
-ckpt=test6
+ckpt=train_clean_test_new_3
 echo "checkpoint name ${ckpt}"

 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then

@@ -40,7 +40,7 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # greedy search decoder
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then

@@ -185,7 +185,7 @@ class HubertASRTrainer(Trainer):
         utt, wav, wavs_lens, target, target_lens = batch
         wavs_lens_rate = wavs_lens / wav.shape[1]
         wav = wav[:, :, 0]
+        logger.info('training utt ids: {}'.format(utt))

         if hasattr(train_conf, 'audio_augment'):
             wav = self.speech_augmentation(wav, wavs_lens_rate)

@@ -12,15 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import defaultdict
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Any, Optional
+from typing import Dict, List, Tuple, Any
 from dataclasses import dataclass, field, is_dataclass
 from copy import deepcopy
-from omegaconf import II, MISSING, open_dict

 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F

@@ -323,7 +318,7 @@ class HubertASR(nn.Layer):
 class HubertBase(nn.Layer):
-    """Wav2vec2 model"""
+    """Hubert model"""

     def __init__(self, config: dict):
         super().__init__()

@@ -6,14 +6,14 @@
 # S3PRL Team has no contribution to this file
 # The file was copied from fairseq to remove the dependency on the entire fairseq package
-import logging
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import paddle
 import paddle.nn as nn
+from paddlespeech.s2t.modules.align import Linear
+from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import (
     EXTRACTOR_MODE_CHOICES,
     LAYER_TYPE_CHOICES,

@@ -27,9 +27,9 @@ from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import (
     get_available_activation_fns,
     GLU,
 )
+from paddlespeech.s2t.utils.log import Log

-logger = logging.getLogger(__name__)
+logger = Log(__name__).getlog()

 @dataclass
 class HubertPretrainingConfig:

@@ -302,7 +302,7 @@ class HubertModel(nn.Layer):
         self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
         self.post_extract_proj = (
-            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            Linear(self.embed, cfg.encoder_embed_dim)
             if self.embed != cfg.encoder_embed_dim
             else None
         )

@@ -334,7 +334,7 @@ class HubertModel(nn.Layer):
         self.mask_emb = paddle.create_parameter(
             shape=[cfg.encoder_embed_dim],
             dtype='float32',
-            default_initializer=paddle.nn.initializer.Uniform(),
+            default_initializer=paddle.nn.initializer.Uniform(low=0),
         )

         self.encoder = TransformerEncoder(cfg)
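The `Uniform(low=0)` change above (repeated for `label_embs_concat` below) matters because `paddle.nn.initializer.Uniform()` defaults to U(-1, 1), while fairseq initializes these embeddings with torch's `uniform_()`, i.e. U(0, 1). A minimal check of the narrowed range:

```python
import paddle

# paddle.nn.initializer.Uniform() defaults to low=-1.0, high=1.0;
# Uniform(low=0) keeps high=1.0, matching torch.Tensor.uniform_()'s [0, 1).
emb = paddle.create_parameter(
    shape=[768], dtype='float32',
    default_initializer=paddle.nn.initializer.Uniform(low=0))
assert float(emb.min()) >= 0.0 and float(emb.max()) <= 1.0
```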
@@ -343,16 +343,16 @@ class HubertModel(nn.Layer):
         self.target_glu = None
         if cfg.target_glu:
             self.target_glu = nn.Sequential(
-                nn.Linear(final_dim, final_dim * 2), GLU()
+                Linear(final_dim, final_dim * 2), GLU()
             )

         self.untie_final_proj = cfg.untie_final_proj
         if self.untie_final_proj:
-            self.final_proj = nn.Linear(
+            self.final_proj = Linear(
                 cfg.encoder_embed_dim, final_dim * len(dictionaries)
             )
         else:
-            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+            self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)

         # modules below are not needed during fine-tuning
         if any([d is None for d in dictionaries]):

@@ -362,13 +362,8 @@ class HubertModel(nn.Layer):
         self.label_embs_concat = paddle.create_parameter(
             shape=[sum(self.num_classes), final_dim],
             dtype='float32',
-            default_initializer=paddle.nn.initializer.Uniform(),
+            default_initializer=paddle.nn.initializer.Uniform(low=0),
         )

-    def upgrade_state_dict_named(self, state_dict, name):
-        """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        super().upgrade_state_dict_named(state_dict, name)
-        return state_dict
-
     @classmethod
     def build_model(cls, cfg: HubertConfig, task):

@@ -417,7 +412,7 @@ class HubertModel(nn.Layer):
         return x, mask_indices

-    def compute_nce(x, pos, negs):
+    def compute_nce(self, x, pos, negs):
         neg_is_pos = (pos == negs).all(-1)
         pos = pos.unsqueeze(0)
         targets = paddle.concat([pos, negs], axis=0)

@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 # S3PRL has no contribution to this file
 # The file was copied from fairseq to remove the dependency on the entire fairseq package
-import logging
 import math
 import uuid
 from dataclasses import dataclass

@@ -16,15 +15,19 @@ from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple

 import numpy as np
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import Tensor
+from paddlespeech.s2t.modules.align import Linear
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Conv1D
+from paddlespeech.s2t.modules.align import Conv2D
+from paddlespeech.s2t.modules.align import Embedding
+from paddlespeech.s2t.utils.log import Log

-logger = logging.getLogger(__name__)
+logger = Log(__name__).getlog()

 class GLU(nn.Layer):
     r"""Applies the gated linear unit function
@@ -153,15 +156,19 @@ def quant_noise(module, p, block_size):
         return module

     # supported modules
-    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2D))
+    assert isinstance(module, (Linear, Embedding, Conv2D))

     # test whether module.weight has the right sizes wrt block_size
     is_conv = len(module.weight.shape) == 4

     # 2D matrix
     if not is_conv:
+        if isinstance(module, Linear):
+            features_weight = module.weight.shape[0]
+        else:
+            features_weight = module.weight.shape[1]
         assert (
-            module.weight.shape[1] %
+            features_weight %
             block_size == 0), "Input features must be a multiple of block sizes"

     # 4D matrix

@@ -181,6 +188,10 @@ def quant_noise(module, p, block_size):
         if not is_conv:
             # gather weight and sizes
             weight = mod.weight
+            if isinstance(module, Linear):
+                in_features = weight.shape[0]
+                out_features = weight.shape[1]
+            else:
                 in_features = weight.shape[1]
                 out_features = weight.shape[0]
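The branches added above exist because Paddle and PyTorch store `Linear` weights transposed relative to each other, and the aligned `Linear` keeps Paddle's layout. A quick sketch of the shape conventions this code relies on:

```python
import paddle.nn as nn

# paddle.nn.Linear stores weight as [in_features, out_features], the
# transpose of torch.nn.Linear's [out_features, in_features], while
# Conv2D keeps [out_channels, in_channels, kH, kW] as in torch.
fc = nn.Linear(16, 32)
conv = nn.Conv2D(16, 32, kernel_size=3)
print(fc.weight.shape)    # [16, 32]       -> input features on axis 0
print(conv.weight.shape)  # [32, 16, 3, 3] -> input channels on axis 1
```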
@@ -188,7 +199,9 @@ def quant_noise(module, p, block_size):
             mask = paddle.zeros(
                 [in_features // block_size * out_features],
                 dtype=paddle.bool)
-            mask.bernoulli_(p)
+            # the implementation of bernoulli_, p=0.5
+            mask = paddle.ones_like(mask) * 0.5
+            mask = paddle.bernoulli(mask)

             mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                 [-1, in_features])

@@ -203,12 +216,18 @@ def quant_noise(module, p, block_size):
             mask = paddle.zeros(
                 [in_channels // block_size * out_channels],
                 dtype=paddle.bool)
-            mask.bernoulli_(p)
+            # the implementation of bernoulli_, p=0.5
+            mask = paddle.ones_like(mask) * 0.5
+            mask = paddle.bernoulli(mask)

             mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                 [-1, in_channels])
         else:
             mask = paddle.zeros(weight.shape)
-            mask.bernoulli_(p)
+            # the implementation of bernoulli_, p=0.5
+            mask = paddle.ones_like(mask) * 0.5
+            mask = paddle.bernoulli(mask)
             mask = mask.unsqueeze(1).tile([1, in_channels, 1, 1])

     # scale weights and apply mask
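The three replacements above work around Paddle's lack of an in-place `Tensor.bernoulli_()`. Note that they hardcode a probability of 0.5; a sketch of a parameter-faithful equivalent that preserves the original `p` would be:

```python
import paddle

# paddle.bernoulli(x) samples each element as 1 with probability x, so
# torch's mask.bernoulli_(p) maps to sampling from a tensor filled with
# p -- not the fixed 0.5 used in the hunks above.
def bernoulli_like(mask, p):
    return paddle.bernoulli(paddle.full(mask.shape, p, dtype='float32'))

mask = bernoulli_like(paddle.zeros([4, 8]), p=0.25)
```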
@@ -282,28 +301,52 @@ class MultiheadAttention(nn.Layer):
                 "Self-attention requires query, key and "
                 "value to be of the same size")

-        weight_attr = paddle.ParamAttr(initializer=nn.initializer.XavierUniform)
-        bias_attr = nn.initializer.Constant(0)
-        # self.k_proj = quant_noise(
-        #     nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
-        # )
-        # self.v_proj = quant_noise(
-        #     nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
-        # )
-        # self.q_proj = quant_noise(
-        #     nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size
-        # )
-        # self.out_proj = quant_noise(
-        #     nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else bias_attr), q_noise, qn_block_size
-        # )
-        self.k_proj = nn.Linear(self.kdim, embed_dim)
-        self.v_proj = nn.Linear(self.vdim, embed_dim)
-        self.q_proj = nn.Linear(embed_dim, embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        # Todo: scaled initialization
+        # Empirically observed the convergence to be much better with
+        # the scaled initialization
+        weight_attr = nn.initializer.XavierUniform()
+        kv_proj_bias_attr = nn.initializer.XavierUniform()
+        out_proj_bias_attr = nn.initializer.Constant(0)
+
+        self.k_proj = quant_noise(
+            nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size)
+        self.v_proj = quant_noise(
+            nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size)
+        self.q_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size)
+        self.out_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else out_proj_bias_attr), q_noise, qn_block_size)
+        # nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
+        # nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
+        # nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
+        # else:
+        #     self.k_proj.weight = paddle.ParamAttr()
+        #     nn.initializer.XavierUniform(self.k_proj.weight)
+        #     nn.initializer.XavierUniform(self.v_proj.weight)
+        #     nn.initializer.XavierUniform(self.q_proj.weight)
+        # nn.initializer.XavierUniform(self.out_proj.weight)
+        # if self.out_proj.bias is not None:
+        #     nn.initializer.Constant(self.out_proj.bias)
+        # if self.bias_k is not None:
+        #     nn.initializer.XavierNormal(self.bias_k)
+        # if self.bias_v is not None:
+        #     nn.initializer.XavierNormal(self.bias_v)
+        # self.k_proj = Linear(self.kdim, embed_dim)
+        # self.v_proj = Linear(self.vdim, embed_dim)
+        # self.q_proj = Linear(embed_dim, embed_dim)
+        # self.out_proj = Linear(embed_dim, embed_dim)

         if add_bias_kv:
             self.bias_k = paddle.create_parameter(
@@ -327,26 +370,26 @@ class MultiheadAttention(nn.Layer):
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

-    # def reset_parameters(self):
-    #     if self.qkv_same_dim:
-    #         # Empirically observed the convergence to be much better with
-    #         # the scaled initialization
-    #         nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
-    #         nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
-    #         nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
-    #     else:
-    #         self.k_proj.weight = paddle.ParamAttr()
-    #         nn.initializer.XavierUniform(self.k_proj.weight)
-    #         nn.initializer.XavierUniform(self.v_proj.weight)
-    #         nn.initializer.XavierUniform(self.q_proj.weight)
-    #     nn.initializer.XavierUniform(self.out_proj.weight)
-    #     if self.out_proj.bias is not None:
-    #         nn.initializer.Constant(self.out_proj.bias)
-    #     if self.bias_k is not None:
-    #         nn.initializer.XavierNormal(self.bias_k)
-    #     if self.bias_v is not None:
-    #         nn.initializer.XavierNormal(self.bias_v)
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2))
+            nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2))
+            nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2))
+        else:
+            self.k_proj.weight = paddle.ParamAttr()
+            nn.initializer.XavierUniform(self.k_proj.weight)
+            nn.initializer.XavierUniform(self.v_proj.weight)
+            nn.initializer.XavierUniform(self.q_proj.weight)
+        nn.initializer.XavierUniform(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.initializer.Constant(self.out_proj.bias)
+        if self.bias_k is not None:
+            nn.initializer.XavierNormal(self.bias_k)
+        if self.bias_v is not None:
+            nn.initializer.XavierNormal(self.bias_v)

     def _get_reserve_head_index(self, num_heads_to_keep: int):
         k_proj_heads_norm = []
@@ -357,15 +400,15 @@ class MultiheadAttention(nn.Layer):
             start_idx = i * self.head_dim
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.k_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.k_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
             q_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.q_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.q_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
             v_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.v_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.v_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.v_proj.bias[start_idx:end_idx])).tolist())
@@ -395,24 +438,24 @@ class MultiheadAttention(nn.Layer):
         for ele in reserve_head_index:
             start_idx, end_idx = ele
-            new_q_weight.append(self.q_proj.weight[start_idx:end_idx, ])
+            new_q_weight.append(self.q_proj.weight[:, start_idx:end_idx])
             new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
-            new_k_weight.append(self.k_proj.weight[start_idx:end_idx, ])
+            new_k_weight.append(self.k_proj.weight[:, start_idx:end_idx])
             new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
-            new_v_weight.append(self.v_proj.weight[start_idx:end_idx, ])
+            new_v_weight.append(self.v_proj.weight[:, start_idx:end_idx])
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
             new_out_proj_weight.append(
-                self.out_proj.weight[:, start_idx:end_idx])
+                self.out_proj.weight[start_idx:end_idx, ])

-        new_q_weight = paddle.concat(new_q_weight).detach()
-        new_k_weight = paddle.concat(new_k_weight).detach()
-        new_v_weight = paddle.concat(new_v_weight).detach()
-        new_out_proj_weight = paddle.concat(
-            new_out_proj_weight, axis=-1).detach()
+        new_q_weight = paddle.concat(new_q_weight, axis=-1).detach()
+        new_k_weight = paddle.concat(new_k_weight, axis=-1).detach()
+        new_v_weight = paddle.concat(new_v_weight, axis=-1).detach()
+        new_out_proj_weight = paddle.concat(
+            new_out_proj_weight).detach()
         new_q_weight.stop_gradient = False
         new_k_weight.stop_gradient = False
         new_v_weight.stop_gradient = False
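The slicing flips in the two hunks above follow from the same [in_features, out_features] weight layout: a head's parameters occupy output *columns* of the q/k/v projections but input *rows* of out_proj. A small sketch under that assumption:

```python
import paddle.nn as nn

# With Paddle's [in, out] layout, head i of q/k/v lives in weight columns
# [i*head_dim:(i+1)*head_dim]; out_proj consumes those activations, so the
# same head occupies its weight rows and is concatenated on axis 0.
embed_dim, head_dim = 8, 4
q_proj = nn.Linear(embed_dim, embed_dim)
out_proj = nn.Linear(embed_dim, embed_dim)
head0_q = q_proj.weight[:, 0:head_dim]      # shape [8, 4]: one head's outputs
head0_out = out_proj.weight[0:head_dim, :]  # shape [4, 8]: same head's inputs
```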
@@ -566,8 +609,8 @@ class MultiheadAttention(nn.Layer):
         assert (embed_dim == self.embed_dim
                 ), f"query dim {embed_dim} != {self.embed_dim}"
         assert list(query.shape) == [tgt_len, bsz, embed_dim]
-        # if key is not None:
-        #     src_len, key_bsz, _ = key.size()
+        if key is not None:
+            src_len, key_bsz, _ = key.shape
             # if not torch.jit.is_scripting():
             #     assert value is not None
             #     assert src_len, key_bsz == value.shape[:2]
@@ -848,7 +891,7 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = paddle.concat([
                 paddle.cast(prev_key_padding_mask, 'float32'),
                 paddle.cast(key_padding_mask, 'float32')
-            ], axis == 1)
+            ], axis=1)
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None

@@ -859,7 +902,7 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = paddle.concat([
                 paddle.cast(prev_key_padding_mask, 'float32'),
                 paddle.cast(filler, 'float32')
-            ], axis == 1)
+            ], axis=1)
         else:
             new_key_padding_mask = prev_key_padding_mask
     elif key_padding_mask is not None:

@@ -869,7 +912,7 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = paddle.concat([
                 paddle.cast(filler, 'float32'),
                 paddle.cast(key_padding_mask, 'float32')
-            ], axis == 1)
+            ], axis=1)
         else:
             new_key_padding_mask = paddle.cast(key_padding_mask, 'float32')
     else:
@@ -1022,7 +1065,7 @@ class GumbelVectorQuantizer(nn.Layer):
             def block(input_dim, output_dim):
                 return nn.Sequential(
-                    nn.Linear(input_dim, output_dim), activation)
+                    Linear(input_dim, output_dim), activation)

             inner_dim = self.input_dim * weight_proj_factor
             self.weight_proj = nn.Sequential(

@@ -1030,11 +1073,9 @@ class GumbelVectorQuantizer(nn.Layer):
                 block(self.input_dim if i == 0 else inner_dim, inner_dim)
                 for i in range(weight_proj_depth - 1)
                 ],
-                nn.Linear(inner_dim, groups * num_vars), )
+                Linear(inner_dim, groups * num_vars), )
         else:
-            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
-            nn.initializer.Normal(mean=0, std=1)(self.weight_proj.weight)
-            nn.initializer.Zero()(self.weight_proj.bias)
+            self.weight_proj = Linear(self.input_dim, groups * num_vars, weight_attr=nn.initializer.Normal(mean=0, std=1), bias_attr=nn.initializer.Zero())

         if isinstance(temp, str):
             import ast
@@ -1125,7 +1166,7 @@ class GumbelVectorQuantizer(nn.Layer):
         if self.training:
             x = F.gumbel_softmax(
-                x.astype('float32'), tau=self.curr_temp,
+                x.astype('float32'), temperature=self.curr_temp,
                 hard=True).astype(x.dtype)
         else:
             x = hard_x
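The rename above reflects a one-word API difference: Paddle's `gumbel_softmax` takes `temperature` where torch's takes `tau`. Minimal usage under Paddle's signature:

```python
import paddle
import paddle.nn.functional as F

# paddle.nn.functional.gumbel_softmax(x, temperature=1.0, hard=False, axis=-1)
logits = paddle.randn([2, 10])
y = F.gumbel_softmax(logits, temperature=0.5, hard=True)
print(y.sum(axis=-1))  # each row is one-hot in the hard (straight-through) case
```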
@@ -1192,22 +1233,11 @@ class TransposeLast(nn.Layer):
         trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
         return x.transpose(trans_dim)

-def LayerNorm(normalized_shape, eps=1e-5):
-    return nn.LayerNorm(
-        normalized_shape,
-        epsilon=eps,
-        weight_attr=paddle.ParamAttr(),
-        bias_attr=paddle.ParamAttr())
-
-class Fp32LayerNorm(nn.LayerNorm):
+class Fp32LayerNorm(LayerNorm):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def forward(self, input):
-        # import pdb
-        # pdb.set_trace()
         output = F.layer_norm(
             input.astype('float32'),
             self._normalized_shape,
@@ -1222,8 +1252,6 @@ class Fp32GroupNorm(nn.GroupNorm):
         super().__init__(*args, **kwargs)

     def forward(self, input):
-        # import pdb
-        # pdb.set_trace()
         output = F.group_norm(
             input.astype('float32'),
             self._num_groups,
@@ -1724,7 +1752,7 @@ class Wav2Vec2Model(nn.Layer):
             mode=cfg.extractor_mode,
             conv_bias=cfg.conv_bias, )

-        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
+        self.post_extract_proj = (Linear(self.embed, cfg.encoder_embed_dim)
                                   if self.embed != cfg.encoder_embed_dim and
                                   not cfg.quantize_input else None)
@@ -1774,9 +1802,9 @@ class Wav2Vec2Model(nn.Layer):
                 time_first=True,
                 weight_proj_depth=cfg.quantizer_depth,
                 weight_proj_factor=cfg.quantizer_factor, )
-            self.project_q = nn.Linear(vq_dim, final_dim)
+            self.project_q = Linear(vq_dim, final_dim)
         else:
-            self.project_q = nn.Linear(self.embed, final_dim)
+            self.project_q = Linear(self.embed, final_dim)

         if cfg.quantize_input:
             if cfg.same_quantizer and self.quantizer is not None:

@@ -1794,7 +1822,7 @@ class Wav2Vec2Model(nn.Layer):
                     time_first=True,
                     weight_proj_depth=cfg.quantizer_depth,
                     weight_proj_factor=cfg.quantizer_factor, )
-            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)
+            self.project_inp = Linear(vq_dim, cfg.encoder_embed_dim)

         self.mask_emb = self.create_parameter(
             shape=[cfg.encoder_embed_dim],

@@ -1809,9 +1837,9 @@ class Wav2Vec2Model(nn.Layer):
         self.target_glu = None
         if cfg.target_glu:
             self.target_glu = nn.Sequential(
-                nn.Linear(final_dim, final_dim * 2), GLU())
+                Linear(final_dim, final_dim * 2), GLU())

-        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+        self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)

     def upgrade_state_dict_named(self, state_dict, name):
         super().upgrade_state_dict_named(state_dict, name)
@@ -2194,7 +2222,7 @@ class ConvFeatureExtractionModel(nn.Layer):
             is_group_norm=False,
             conv_bias=False, ):
         def make_conv():
-            conv = nn.Conv1D(
+            conv = Conv1D(
                 n_in,
                 n_out,
                 k,
@@ -2256,17 +2284,16 @@
 def make_conv_pos(e, k, g):
-    pos_conv = nn.Conv1D(
+    dropout = 0
+    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
+    pos_conv = Conv1D(
         e,
         e,
         kernel_size=k,
         padding=k // 2,
-        groups=g, )
-    dropout = 0
-    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
-    nn.initializer.Normal(mean=0, std=std)(pos_conv.weight)
-    nn.initializer.Constant(0)(pos_conv.bias)
+        groups=g,
+        weight_attr=nn.initializer.Normal(mean=0, std=std),
+        bias_attr=nn.initializer.Constant(0))

     pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
     pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
@@ -2301,7 +2328,7 @@ class TransformerEncoder(nn.Layer):
     def make_conv_block(e, k, g, l):
         return nn.Sequential(*[
             nn.Sequential(
-                nn.Conv1D(
+                Conv1D(
                     e,
                     e,
                     kernel_size=k,
@@ -2454,8 +2481,8 @@ class TransformerSentenceEncoderLayer(nn.Layer):
         # layer norm associated with the self attention layer
         self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
-        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
-        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+        self.fc1 = Linear(self.embedding_dim, ffn_embedding_dim)
+        self.fc2 = Linear(ffn_embedding_dim, self.embedding_dim)

         # layer norm associated with the position wise feed-forward NN
         self.final_layer_norm = LayerNorm(self.embedding_dim)

@@ -58,8 +58,6 @@ class Wav2vec2ASR(nn.Layer):
             reduction='mean')

     def forward(self, wav, wavs_lens_rate, target, target_lens):
-        # import pdb
-        # pdb.set_trace()
         if self.normalize_wav:
             wav = F.layer_norm(wav, wav.shape)
