diff --git a/examples/librispeech/asr3/conf/hubertASR.yaml b/examples/librispeech/asr3/conf/hubertASR.yaml
index efb549591..e147815a8 100644
--- a/examples/librispeech/asr3/conf/hubertASR.yaml
+++ b/examples/librispeech/asr3/conf/hubertASR.yaml
@@ -1,7 +1,7 @@
 ############################################
 # Network Architecture                     #
 ############################################
-freeze_hubert: True
+freeze_hubert: False
 normalize_wav: True
 output_norm: True
 init_type: kaiming_uniform # !Warning: need to convergence
@@ -14,11 +14,20 @@ ctc:
     enc_n_units: 1024
     blank_id: 0
     dropout_rate: 0.0
-hubert_params_path: "exp/hubert/pd_hubert.pdparams"
+hubert_params_path: "exp/hubert/pd_hubert_no_fintune.pdparams"
 
 task_cfg:
+  label_rate: 50.0
   sample_rate: 16000
+  normalize: True
+  enable_padding: False
+  max_keep_size: None
+  max_sample_size: 250000
+  min_sample_size: 32000
+  single_target: False
+  random_crop: True
+  pad_audio: False
 
 model_cfg:
   dropout_input: 0.0
@@ -37,7 +46,6 @@ model_cfg:
   mask_channel_selection: static
   mask_channel_other: 0.0
   no_mask_channel_overlap: False
-  freeze_finetune_updates: 10000
   feature_grad_mult: 0.0
   layerdrop: 0.1
   normalize: True
@@ -69,7 +77,7 @@ model_cfg:
 ###########################################
 # Data                                    #
 ###########################################
-train_manifest: data/manifest.train
+train_manifest: data/manifest.train-clean-100
 dev_manifest: data/manifest.dev
 test_manifest: data/manifest.test-clean
@@ -81,7 +89,7 @@ unit_type: char
 mean_std_filepath: ""
 preprocess_config: conf/preprocess.yaml
 sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for other epochs
-batch_size: 8 # Different batch_size may cause large differences in results
+batch_size: 2 # Different batch_size may cause large differences in results
 maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
 maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced
 minibatches: 0 # for debug
@@ -102,12 +110,13 @@ return_lens_rate: True
 ############################################
 audio_augment: # for raw audio
   sample_rate: 16000
+  speeds: [95, 100, 105]
 
 ###########################################
 # Training                                #
 ###########################################
-n_epoch: 1
-accum_grad: 1
+n_epoch: 3
+accum_grad: 8
 global_grad_clip: 5.0
 model_optim: adadelta
 model_optim_conf:
@@ -120,7 +129,7 @@ model_scheduler_conf:
   lr_decay: 1.0
 hubert_optim: adadelta
 hubert_optim_conf:
-  lr: 0.9
+  lr: 1.0
   epsilon: 1.0e-6
   rho: 0.95
 hubert_scheduler: constantlr
@@ -130,4 +139,4 @@ hubert_scheduler_conf:
 log_interval: 1
 checkpoint:
   kbest_n: 50
-  latest_n: 5
\ No newline at end of file
+  latest_n: 5
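Reading the two training changes together: dropping `batch_size` from 8 to 2 while raising `accum_grad` from 1 to 8 trades per-step memory for gradient accumulation, and the new `speeds` entry turns on speed perturbation. A quick sanity check of the arithmetic, assuming `accum_grad` counts micro-batches per optimizer update and the `speeds` values follow the usual Kaldi percent convention:

```python
# Per-device numbers from hubertASR.yaml above.
batch_size = 2        # utterances per forward/backward pass
accum_grad = 8        # micro-batches accumulated per optimizer step
effective_batch = batch_size * accum_grad   # 16 utterances per update
old_effective_batch = 8 * 1                 # 8 under the previous config

# Speed perturbation factors, assuming percentages of real-time speed.
speed_factors = [s / 100 for s in [95, 100, 105]]  # [0.95, 1.0, 1.05]
```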
diff --git a/examples/librispeech/asr3/run.sh b/examples/librispeech/asr3/run.sh
index 8ebab30d0..87a693c01 100755
--- a/examples/librispeech/asr3/run.sh
+++ b/examples/librispeech/asr3/run.sh
@@ -5,9 +5,9 @@ MODEL=hubert
 . ./path.sh ${MODEL} || exit 1;
 . ./cmd.sh || exit 1;
 
-gpus=2
+gpus=1,2
 stage=1
-stop_stage=3
+stop_stage=1
 conf_path=conf/${MODEL}ASR.yaml
 ips=            #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
@@ -20,7 +20,7 @@ audio_file=data/demo_002_en.wav
 avg_ckpt=avg_${avg_num}
 ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
-ckpt=test6
+ckpt=train_clean_test_new_3
 echo "checkpoint name ${ckpt}"
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
@@ -30,7 +30,7 @@ fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips} 
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -40,7 +40,7 @@ fi
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # greedy search decoder
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
diff --git a/paddlespeech/s2t/exps/hubert/model.py b/paddlespeech/s2t/exps/hubert/model.py
index 8e2ab5745..5e522d276 100644
--- a/paddlespeech/s2t/exps/hubert/model.py
+++ b/paddlespeech/s2t/exps/hubert/model.py
@@ -185,7 +185,7 @@ class HubertASRTrainer(Trainer):
         utt, wav, wavs_lens, target, target_lens = batch
         wavs_lens_rate = wavs_lens / wav.shape[1]
         wav = wav[:, :, 0]
-
+        logger.info('training utt ids: {}'.format(utt))
         if hasattr(train_conf, 'audio_augment'):
             wav = self.speech_augmentation(wav, wavs_lens_rate)
diff --git a/paddlespeech/s2t/models/hubert/hubert_ASR.py b/paddlespeech/s2t/models/hubert/hubert_ASR.py
index 2f45cd1ff..a3193b3c2 100644
--- a/paddlespeech/s2t/models/hubert/hubert_ASR.py
+++ b/paddlespeech/s2t/models/hubert/hubert_ASR.py
@@ -12,15 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import defaultdict
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Any, Optional
+from typing import Dict, List, Tuple, Any
 from dataclasses import dataclass, field, is_dataclass
 from copy import deepcopy
 
-from omegaconf import II, MISSING, open_dict
-
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -323,7 +318,7 @@ class HubertASR(nn.Layer):
 
 
 class HubertBase(nn.Layer):
-    """Wav2vec2 model"""
+    """Hubert model"""
 
     def __init__(self, config: dict):
         super().__init__()
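The next two file diffs swap Python's bare `logging.getLogger` for PaddleSpeech's `Log` wrapper. A minimal sketch of obtaining and using that logger, mirroring the `logger.info` call added to the trainer above (the utterance IDs are placeholders):

```python
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()   # returns a configured logging.Logger
logger.info('training utt ids: {}'.format(['utt_001', 'utt_002']))
```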
diff --git a/paddlespeech/s2t/models/hubert/modules/hubert_model.py b/paddlespeech/s2t/models/hubert/modules/hubert_model.py
index 8b7822699..b96becbeb 100644
--- a/paddlespeech/s2t/models/hubert/modules/hubert_model.py
+++ b/paddlespeech/s2t/models/hubert/modules/hubert_model.py
@@ -6,14 +6,14 @@
 # S3PRL Team has no contribution to this file
 # The file was copied from fairseq to remove the dependency on the entire fairseq package
 
-import logging
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import paddle
 import paddle.nn as nn
-
+from paddlespeech.s2t.modules.align import Linear
+from paddlespeech.s2t.modules.align import LayerNorm
 from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import (
     EXTRACTOR_MODE_CHOICES,
     LAYER_TYPE_CHOICES,
@@ -27,9 +27,9 @@ from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import (
     get_available_activation_fns,
     GLU,
 )
+from paddlespeech.s2t.utils.log import Log
 
-logger = logging.getLogger(__name__)
-
+logger = Log(__name__).getlog()
 
 @dataclass
 class HubertPretrainingConfig:
@@ -302,7 +302,7 @@ class HubertModel(nn.Layer):
         self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / \
             task_cfg.sample_rate
 
         self.post_extract_proj = (
-            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            Linear(self.embed, cfg.encoder_embed_dim)
             if self.embed != cfg.encoder_embed_dim else None
         )
@@ -334,7 +334,7 @@ class HubertModel(nn.Layer):
         self.mask_emb = paddle.create_parameter(
             shape=[cfg.encoder_embed_dim],
             dtype='float32',
-            default_initializer=paddle.nn.initializer.Uniform(),
+            default_initializer=paddle.nn.initializer.Uniform(low=0),
         )
 
         self.encoder = TransformerEncoder(cfg)
@@ -343,16 +343,16 @@ class HubertModel(nn.Layer):
         self.target_glu = None
         if cfg.target_glu:
             self.target_glu = nn.Sequential(
-                nn.Linear(final_dim, final_dim * 2), GLU()
+                Linear(final_dim, final_dim * 2), GLU()
             )
 
         self.untie_final_proj = cfg.untie_final_proj
         if self.untie_final_proj:
-            self.final_proj = nn.Linear(
+            self.final_proj = Linear(
                 cfg.encoder_embed_dim, final_dim * len(dictionaries)
             )
         else:
-            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+            self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)
 
         # modules below are not needed during fine-tuning
         if any([d is None for d in dictionaries]):
@@ -362,13 +362,8 @@ class HubertModel(nn.Layer):
         self.label_embs_concat = paddle.create_parameter(
             shape=[sum(self.num_classes), final_dim],
             dtype='float32',
-            default_initializer=paddle.nn.initializer.Uniform(),
+            default_initializer=paddle.nn.initializer.Uniform(low=0),
         )
 
-    def upgrade_state_dict_named(self, state_dict, name):
-        """Upgrade a (possibly old) state dict for new versions of fairseq."""
-
-        super().upgrade_state_dict_named(state_dict, name)
-        return state_dict
 
     @classmethod
     def build_model(cls, cfg: HubertConfig, task):
@@ -417,7 +412,7 @@ class HubertModel(nn.Layer):
 
         return x, mask_indices
 
-    def compute_nce(x, pos, negs):
+    def compute_nce(self, x, pos, negs):
         neg_is_pos = (pos == negs).all(-1)
         pos = pos.unsqueeze(0)
         targets = paddle.concat([pos, negs], axis=0)
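The `label_rate: 50.0` added to the YAML feeds the `feat2tar_ratio` context line above. A worked check of that formula, assuming the standard HuBERT conv feature extractor whose strides (5, 2, 2, 2, 2, 2, 2) downsample the waveform by 320:

```python
label_rate = 50.0        # task_cfg.label_rate from hubertASR.yaml
feature_ds_rate = 320    # product of the conv feature-extractor strides
sample_rate = 16000      # task_cfg.sample_rate

feat2tar_ratio = label_rate * feature_ds_rate / sample_rate
assert feat2tar_ratio == 1.0   # one 50 Hz label per extracted feature frame
```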
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/wav2vec2_model.py b/paddlespeech/s2t/models/wav2vec2/modules/wav2vec2_model.py
index c610b22d7..2a7887880 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/wav2vec2_model.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/wav2vec2_model.py
@@ -4,7 +4,6 @@
 # S3PRL has no contribution to this file
 # The file was copied from fairseq to remove the dependency on the entire fairseq package
-import logging
 import math
 import uuid
 from dataclasses import dataclass
@@ -16,15 +15,19 @@ from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
-
 import numpy as np
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import Tensor
+from paddlespeech.s2t.modules.align import Linear
+from paddlespeech.s2t.modules.align import LayerNorm
+from paddlespeech.s2t.modules.align import Conv1D
+from paddlespeech.s2t.modules.align import Conv2D
+from paddlespeech.s2t.modules.align import Embedding
+from paddlespeech.s2t.utils.log import Log
 
-logger = logging.getLogger(__name__)
-
+logger = Log(__name__).getlog()
 
 class GLU(nn.Layer):
     r"""Applies the gated linear unit function
@@ -153,15 +156,19 @@ def quant_noise(module, p, block_size):
         return module
 
     # supported modules
-    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2D))
+    assert isinstance(module, (Linear, Embedding, Conv2D))
 
     # test whether module.weight has the right sizes wrt block_size
     is_conv = len(module.weight.shape) == 4
 
     # 2D matrix
     if not is_conv:
+        if isinstance(module, Linear):
+            features_weight = module.weight.shape[0]
+        else:
+            features_weight = module.weight.shape[1]
         assert (
-            module.weight.shape[1] %
+            features_weight %
             block_size == 0), "Input features must be a multiple of block sizes"
 
     # 4D matrix
@@ -181,14 +188,20 @@ def quant_noise(module, p, block_size):
             if not is_conv:
                 # gather weight and sizes
                 weight = mod.weight
-                in_features = weight.shape[1]
-                out_features = weight.shape[0]
+                if isinstance(module, Linear):
+                    in_features = weight.shape[0]
+                    out_features = weight.shape[1]
+                else:
+                    in_features = weight.shape[1]
+                    out_features = weight.shape[0]
 
                 # split weight matrix into blocks and randomly drop selected blocks
                 mask = paddle.zeros(
                     [in_features // block_size * out_features],
                     dtype=paddle.bool)
-                mask.bernoulli_(p)
+                # emulate in-place bernoulli_(p): fill with probability p, then sample
+                mask = paddle.ones_like(mask, dtype='float32') * p
+                mask = paddle.bernoulli(mask)
                 mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                     [-1, in_features])
@@ -203,12 +216,18 @@ def quant_noise(module, p, block_size):
                 mask = paddle.zeros(
                     [in_channels // block_size * out_channels],
                     dtype=paddle.bool)
-                mask.bernoulli_(p)
+
+                # emulate in-place bernoulli_(p): fill with probability p, then sample
+                mask = paddle.ones_like(mask, dtype='float32') * p
+                mask = paddle.bernoulli(mask)
                 mask = mask.unsqueeze(1).tile([1, block_size]).reshape(
                     [-1, in_channels])
             else:
                 mask = paddle.zeros(weight.shape)
-                mask.bernoulli_(p)
+
+                # emulate in-place bernoulli_(p): fill with probability p, then sample
+                mask = paddle.ones_like(mask, dtype='float32') * p
+                mask = paddle.bernoulli(mask)
                 mask = mask.unsqueeze(1).tile([1, in_channels, 1, 1])
 
             # scale weights and apply mask
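The three hunks above all replace torch's in-place `Tensor.bernoulli_(p)`, which Paddle lacks, with an out-of-place equivalent. A standalone sketch of the pattern (the shape and probability are placeholders):

```python
import paddle

p = 0.2                                          # block-dropout probability
mask = paddle.zeros([4, 8], dtype=paddle.bool)   # stand-in for the block mask
probs = paddle.ones_like(mask, dtype='float32') * p
mask = paddle.bernoulli(probs)                   # ~p of the entries become 1.0
```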
- # ) - - # self.out_proj = quant_noise( - # nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else bias_attr), q_noise, qn_block_size - # ) - self.k_proj = nn.Linear(self.kdim, embed_dim) - - self.v_proj = nn.Linear(self.vdim, embed_dim) - - self.q_proj = nn.Linear(embed_dim, embed_dim) - - self.out_proj = nn.Linear(embed_dim, embed_dim) + + # Todo scaled initialization + # Empirically observed the convergence to be much better with + # the scaled initialization + weight_attr = nn.initializer.XavierUniform() + kv_proj_bias_attr = nn.initializer.XavierUniform() + out_proj_bias_attr = nn.initializer.Constant(0) + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else kv_proj_bias_attr), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias if not bias else out_proj_bias_attr), q_noise, qn_block_size + ) + + + # nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2)) + # nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2)) + # nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2)) + # else: + # self.k_proj.weight = paddle.ParamAttr() + # nn.initializer.XavierUniform(self.k_proj.weight) + # nn.initializer.XavierUniform(self.v_proj.weight) + # nn.initializer.XavierUniform(self.q_proj.weight) + + # nn.initializer.XavierUniform(self.out_proj.weight) + # if self.out_proj.bias is not None: + # nn.initializer.Constant(self.out_proj.bias) + # if self.bias_k is not None: + # nn.initializer.XavierNormal(self.bias_k) + # if self.bias_v is not None: + # nn.initializer.XavierNormal(self.bias_v) + + # self.k_proj = Linear(self.kdim, embed_dim) + + # self.v_proj = Linear(self.vdim, embed_dim) + + # self.q_proj = Linear(embed_dim, embed_dim) + + # self.out_proj = Linear(embed_dim, embed_dim) if add_bias_kv: self.bias_k = paddle.create_parameter( @@ -327,26 +370,26 @@ class MultiheadAttention(nn.Layer): def prepare_for_onnx_export_(self): self.onnx_trace = True - # def reset_parameters(self): - # if self.qkv_same_dim: - # # Empirically observed the convergence to be much better with - # # the scaled initialization - # nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2)) - # nn.initializer.XavierUniform(self.v_proj.weight, gain=1 / math.sqrt(2)) - # nn.initializer.XavierUniform(self.q_proj.weight, gain=1 / math.sqrt(2)) - # else: - # self.k_proj.weight = paddle.ParamAttr() - # nn.initializer.XavierUniform(self.k_proj.weight) - # nn.initializer.XavierUniform(self.v_proj.weight) - # nn.initializer.XavierUniform(self.q_proj.weight) - - # nn.initializer.XavierUniform(self.out_proj.weight) - # if self.out_proj.bias is not None: - # nn.initializer.Constant(self.out_proj.bias) - # if self.bias_k is not None: - # nn.initializer.XavierNormal(self.bias_k) - # if self.bias_v is not None: - # nn.initializer.XavierNormal(self.bias_v) + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.initializer.XavierUniform(self.k_proj.weight, gain=1 / math.sqrt(2)) + 
@@ -357,15 +400,15 @@ class MultiheadAttention(nn.Layer):
             start_idx = i * self.head_dim
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.k_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.k_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
             q_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.q_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.q_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
             v_proj_heads_norm.append(
-                paddle.sum(paddle.abs(self.v_proj.weight[start_idx:end_idx, ]))
+                paddle.sum(paddle.abs(self.v_proj.weight[:, start_idx:end_idx]))
                 .tolist() + paddle.sum(
                     paddle.abs(self.v_proj.bias[start_idx:end_idx])).tolist())
@@ -395,24 +438,24 @@ class MultiheadAttention(nn.Layer):
         for ele in reserve_head_index:
             start_idx, end_idx = ele
-            new_q_weight.append(self.q_proj.weight[start_idx:end_idx, ])
+            new_q_weight.append(self.q_proj.weight[:, start_idx:end_idx])
             new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
 
-            new_k_weight.append(self.k_proj.weight[start_idx:end_idx, ])
+            new_k_weight.append(self.k_proj.weight[:, start_idx:end_idx])
             new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
 
-            new_v_weight.append(self.v_proj.weight[start_idx:end_idx, ])
+            new_v_weight.append(self.v_proj.weight[:, start_idx:end_idx])
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
 
             new_out_proj_weight.append(
-                self.out_proj.weight[:, start_idx:end_idx])
+                self.out_proj.weight[start_idx:end_idx, ])
 
-        new_q_weight = paddle.concat(new_q_weight).detach()
-        new_k_weight = paddle.concat(new_k_weight).detach()
-        new_v_weight = paddle.concat(new_v_weight).detach()
+        new_q_weight = paddle.concat(new_q_weight, axis=-1).detach()
+        new_k_weight = paddle.concat(new_k_weight, axis=-1).detach()
+        new_v_weight = paddle.concat(new_v_weight, axis=-1).detach()
         new_out_proj_weight = paddle.concat(
-            new_out_proj_weight, axis=-1).detach()
+            new_out_proj_weight).detach()
         new_q_weight.stop_gradient = False
         new_k_weight.stop_gradient = False
         new_v_weight.stop_gradient = False
@@ -566,11 +609,11 @@ class MultiheadAttention(nn.Layer):
         assert (embed_dim == self.embed_dim
                 ), f"query dim {embed_dim} != {self.embed_dim}"
         assert list(query.shape) == [tgt_len, bsz, embed_dim]
-        # if key is not None:
-        #     src_len, key_bsz, _ = key.size()
-        #     if not torch.jit.is_scripting():
-        #         assert value is not None
-        #         assert src_len, key_bsz == value.shape[:2]
+        if key is not None:
+            src_len, key_bsz, _ = key.shape
+            # if not torch.jit.is_scripting():
+            #     assert value is not None
+            #     assert src_len, key_bsz == value.shape[:2]
 
         # if (
         #     not self.onnx_trace
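The slicing changes above follow from Paddle's `Linear` storing its weight as `[in_features, out_features]`, the transpose of torch's layout, so one attention head's parameters occupy a block of columns rather than rows. A small check of that assumption:

```python
import paddle.nn as nn

embed_dim, num_heads = 24, 3
head_dim = embed_dim // num_heads
k_proj = nn.Linear(embed_dim, embed_dim)      # weight shape [24, 24]

head_0 = k_proj.weight[:, 0 * head_dim:1 * head_dim]
assert head_0.shape == [embed_dim, head_dim]  # columns, not rows
```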
@@ -848,7 +891,7 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = paddle.concat([
                 paddle.cast(prev_key_padding_mask, 'float32'),
                 paddle.cast(key_padding_mask, 'float32')
-            ], axis == 1)
+            ], axis=1)
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None
@@ -859,7 +902,7 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = paddle.concat([
                 paddle.cast(prev_key_padding_mask, 'float32'),
                 paddle.cast(filler, 'float32')
-            ], axis == 1)
+            ], axis=1)
         else:
             new_key_padding_mask = prev_key_padding_mask
     elif key_padding_mask is not None:
@@ -869,7 +912,7 @@ class MultiheadAttention(nn.Layer):
             new_key_padding_mask = paddle.concat([
                 paddle.cast(filler, 'float32'),
                 paddle.cast(key_padding_mask, 'float32')
-            ], axis == 1)
+            ], axis=1)
         else:
             new_key_padding_mask = paddle.cast(key_padding_mask, 'float32')
     else:
@@ -1022,7 +1065,7 @@ class GumbelVectorQuantizer(nn.Layer):
 
             def block(input_dim, output_dim):
                 return nn.Sequential(
-                    nn.Linear(input_dim, output_dim), activation)
+                    Linear(input_dim, output_dim), activation)
 
             inner_dim = self.input_dim * weight_proj_factor
             self.weight_proj = nn.Sequential(
@@ -1030,11 +1073,9 @@ class GumbelVectorQuantizer(nn.Layer):
                     block(self.input_dim if i == 0 else inner_dim, inner_dim)
                     for i in range(weight_proj_depth - 1)
                 ],
-                nn.Linear(inner_dim, groups * num_vars), )
+                Linear(inner_dim, groups * num_vars), )
         else:
-            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
-            nn.initializer.Normal(mean=0, std=1)(self.weight_proj.weight)
-            nn.initializer.Zero()(self.weight_proj.bias)
+            self.weight_proj = Linear(self.input_dim, groups * num_vars, weight_attr=nn.initializer.Normal(mean=0, std=1), bias_attr=nn.initializer.Zero())
 
         if isinstance(temp, str):
             import ast
@@ -1125,7 +1166,7 @@ class GumbelVectorQuantizer(nn.Layer):
 
         if self.training:
             x = F.gumbel_softmax(
-                x.astype('float32'), tau=self.curr_temp,
+                x.astype('float32'), temperature=self.curr_temp,
                 hard=True).astype(x.dtype)
         else:
             x = hard_x
@@ -1192,22 +1233,11 @@ class TransposeLast(nn.Layer):
         trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
         return x.transpose(trans_dim)
 
-
-def LayerNorm(normalized_shape, eps=1e-5):
-    return nn.LayerNorm(
-        normalized_shape,
-        epsilon=eps,
-        weight_attr=paddle.ParamAttr(),
-        bias_attr=paddle.ParamAttr())
-
-
-class Fp32LayerNorm(nn.LayerNorm):
+class Fp32LayerNorm(LayerNorm):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def forward(self, input):
-        # import pdb
-        # pdb.set_trace()
         output = F.layer_norm(
             input.astype('float32'),
             self._normalized_shape,
@@ -1222,8 +1252,6 @@ class Fp32GroupNorm(nn.GroupNorm):
         super().__init__(*args, **kwargs)
 
     def forward(self, input):
-        # import pdb
-        # pdb.set_trace()
         output = F.group_norm(
             input.astype('float32'),
             self._num_groups,
@@ -1724,7 +1752,7 @@ class Wav2Vec2Model(nn.Layer):
             mode=cfg.extractor_mode,
             conv_bias=cfg.conv_bias, )
 
-        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
+        self.post_extract_proj = (Linear(self.embed, cfg.encoder_embed_dim)
                                   if self.embed != cfg.encoder_embed_dim and
                                   not cfg.quantize_input else None)
 
@@ -1774,9 +1802,9 @@ class Wav2Vec2Model(nn.Layer):
                 time_first=True,
                 weight_proj_depth=cfg.quantizer_depth,
                 weight_proj_factor=cfg.quantizer_factor, )
-            self.project_q = nn.Linear(vq_dim, final_dim)
+            self.project_q = Linear(vq_dim, final_dim)
         else:
-            self.project_q = nn.Linear(self.embed, final_dim)
+            self.project_q = Linear(self.embed, final_dim)
 
         if cfg.quantize_input:
             if cfg.same_quantizer and self.quantizer is not None:
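Among the hunks above, the `tau=` to `temperature=` rename reflects the keyword Paddle's `gumbel_softmax` actually accepts. A self-contained sketch of the call as renamed (the logits are random placeholders):

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([2, 10])
one_hot = F.gumbel_softmax(logits, temperature=2.0, hard=True)  # straight-through one-hot
```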
@@ -1794,7 +1822,7 @@ class Wav2Vec2Model(nn.Layer):
                 time_first=True,
                 weight_proj_depth=cfg.quantizer_depth,
                 weight_proj_factor=cfg.quantizer_factor, )
-            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)
+            self.project_inp = Linear(vq_dim, cfg.encoder_embed_dim)
 
         self.mask_emb = self.create_parameter(
             shape=[cfg.encoder_embed_dim],
@@ -1809,9 +1837,9 @@ class Wav2Vec2Model(nn.Layer):
         self.target_glu = None
         if cfg.target_glu:
             self.target_glu = nn.Sequential(
-                nn.Linear(final_dim, final_dim * 2), GLU())
+                Linear(final_dim, final_dim * 2), GLU())
 
-        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+        self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)
 
     def upgrade_state_dict_named(self, state_dict, name):
         super().upgrade_state_dict_named(state_dict, name)
@@ -2194,7 +2222,7 @@ class ConvFeatureExtractionModel(nn.Layer):
             is_group_norm=False,
             conv_bias=False, ):
         def make_conv():
-            conv = nn.Conv1D(
+            conv = Conv1D(
                 n_in,
                 n_out,
                 k,
@@ -2256,17 +2284,16 @@ class ConvFeatureExtractionModel(nn.Layer):
 
 
 def make_conv_pos(e, k, g):
-    pos_conv = nn.Conv1D(
+    dropout = 0
+    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
+    pos_conv = Conv1D(
         e,
         e,
         kernel_size=k,
         padding=k // 2,
-        groups=g, )
-    dropout = 0
-    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
-    nn.initializer.Normal(mean=0, std=std)(pos_conv.weight)
-    nn.initializer.Constant(0)(pos_conv.bias)
-
+        groups=g,
+        weight_attr=nn.initializer.Normal(mean=0, std=std),
+        bias_attr=nn.initializer.Constant(0))
     pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
     pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
 
@@ -2301,7 +2328,7 @@ class TransformerEncoder(nn.Layer):
         def make_conv_block(e, k, g, l):
             return nn.Sequential(*[
                 nn.Sequential(
-                    nn.Conv1D(
+                    Conv1D(
                         e,
                         e,
                         kernel_size=k,
@@ -2454,8 +2481,8 @@ class TransformerSentenceEncoderLayer(nn.Layer):
 
         # layer norm associated with the self attention layer
         self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
-        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
-        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+        self.fc1 = Linear(self.embedding_dim, ffn_embedding_dim)
+        self.fc2 = Linear(ffn_embedding_dim, self.embedding_dim)
 
         # layer norm associated with the position wise feed-forward NN
         self.final_layer_norm = LayerNorm(self.embedding_dim)
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index 14e6c1459..59a67a1e5 100755
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -58,8 +58,6 @@ class Wav2vec2ASR(nn.Layer):
             reduction='mean')
 
     def forward(self, wav, wavs_lens_rate, target, target_lens):
-        # import pdb
-        # pdb.set_trace()
         if self.normalize_wav:
             wav = F.layer_norm(wav, wav.shape)
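The `make_conv_pos` hunk above applies the same construction-time initializer pattern to a convolution, then weight-normalizes over the kernel axis. A condensed sketch with wav2vec 2.0-style sizes (768/128/16 are illustrative values, not taken from this diff):

```python
import math
import paddle.nn as nn

e, k, g = 768, 128, 16                  # channels, kernel size, groups
std = math.sqrt(4 / (k * e))            # dropout = 0, as in the diff
pos_conv = nn.Conv1D(
    e, e, kernel_size=k, padding=k // 2, groups=g,
    weight_attr=nn.initializer.Normal(mean=0, std=std),
    bias_attr=nn.initializer.Constant(0))
pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
```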