From 57c4f4a68cf5b722bfaf6ee0f90c9f1768e7dded Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Thu, 3 Mar 2022 16:37:26 +0800
Subject: [PATCH] add sid learning rate and training model

---
 examples/voxceleb/sv0/local/train.py      | 34 ++++++++++++-
 paddlespeech/vector/layers/lr.py          | 45 +++++++++++++++++
 paddlespeech/vector/training/sid_model.py | 60 +++++++++++++++++++++++
 3 files changed, 138 insertions(+), 1 deletion(-)
 create mode 100644 paddlespeech/vector/layers/lr.py
 create mode 100644 paddlespeech/vector/training/sid_model.py

diff --git a/examples/voxceleb/sv0/local/train.py b/examples/voxceleb/sv0/local/train.py
index c0cb1e17..8dea5fff 100644
--- a/examples/voxceleb/sv0/local/train.py
+++ b/examples/voxceleb/sv0/local/train.py
@@ -15,10 +15,14 @@ import argparse
 
 import paddle
 
-from dataset.voxceleb.voxceleb1 import VoxCeleb1
+from paddleaudio.datasets.voxceleb import VoxCeleb1
+from paddlespeech.vector.layers.lr import CyclicLRScheduler
+from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
+from paddlespeech.vector.training.sid_model import SpeakerIdetification
 
 
 def main(args):
+    # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
 
     # stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
@@ -27,8 +31,32 @@ def main(args):
     local_rank = paddle.distributed.get_rank()
 
     # stage2: data prepare
+    # note: some commands must be run only on rank 0
     train_ds = VoxCeleb1('train', target_dir=args.data_dir)
 
+    # stage3: build the dnn backbone model network
+    model_conf = {
+        "input_size": 80,
+        "channels": [1024, 1024, 1024, 1024, 3072],
+        "kernel_sizes": [5, 3, 3, 3, 1],
+        "dilations": [1, 2, 3, 4, 1],
+        "attention_channels": 128,
+        "lin_neurons": 192,
+    }
+    ecapa_tdnn = EcapaTdnn(**model_conf)
+
+    # stage4: build the speaker verification training instance with the backbone model
+    model = SpeakerIdetification(
+        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
+
+    # stage5: build the optimizer; for now we only support the AdamW optimizer
+    lr_schedule = CyclicLRScheduler(
+        base_lr=args.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
+    optimizer = paddle.optimizer.AdamW(
+        learning_rate=lr_schedule, parameters=model.parameters())
+
+    # stage6: build the loss function; for now we only support LogSoftmaxWrapper
+
 
 if __name__ == "__main__":
     # yapf: disable
@@ -41,6 +69,10 @@ if __name__ == "__main__":
                         default="./data/",
                         type=str,
                         help="data directory")
+    parser.add_argument("--learning_rate",
+                        type=float,
+                        default=1e-8,
+                        help="Learning rate used to train with warmup.")
 
     args = parser.parse_args()
     # yapf: enable
diff --git a/paddlespeech/vector/layers/lr.py b/paddlespeech/vector/layers/lr.py
new file mode 100644
index 00000000..3dcac057
--- /dev/null
+++ b/paddlespeech/vector/layers/lr.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from paddle.optimizer.lr import LRScheduler
+
+
+class CyclicLRScheduler(LRScheduler):
+    def __init__(self,
+                 base_lr: float=1e-8,
+                 max_lr: float=1e-3,
+                 step_size: int=10000):
+
+        super(CyclicLRScheduler, self).__init__()
+
+        self.current_step = -1
+        self.base_lr = base_lr
+        self.max_lr = max_lr
+        self.step_size = step_size
+
+    def step(self):
+        if not hasattr(self, 'current_step'):
+            return
+
+        self.current_step += 1
+        if self.current_step >= 2 * self.step_size:
+            self.current_step %= 2 * self.step_size
+
+        self.last_lr = self.get_lr()
+
+    def get_lr(self):
+        p = self.current_step / (2 * self.step_size)  # Proportion in one cycle.
+        if p < 0.5:  # Increase
+            return self.base_lr + p / 0.5 * (self.max_lr - self.base_lr)
+        else:  # Decrease
+            return self.max_lr - (p / 0.5 - 1) * (self.max_lr - self.base_lr)
diff --git a/paddlespeech/vector/training/sid_model.py b/paddlespeech/vector/training/sid_model.py
new file mode 100644
index 00000000..8a46c3cd
--- /dev/null
+++ b/paddlespeech/vector/training/sid_model.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SpeakerIdetification(nn.Layer):
+    def __init__(
+            self,
+            backbone,
+            num_class,
+            lin_blocks=0,
+            lin_neurons=192,
+            dropout=0.1, ):
+
+        super(SpeakerIdetification, self).__init__()
+        self.backbone = backbone
+        if dropout > 0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+        input_size = self.backbone.emb_size
+        self.blocks = nn.LayerList()
+        for i in range(lin_blocks):
+            self.blocks.extend([
+                nn.BatchNorm1D(input_size),
+                nn.Linear(in_features=input_size, out_features=lin_neurons),
+            ])
+            input_size = lin_neurons
+
+        self.weight = paddle.create_parameter(
+            shape=(input_size, num_class),
+            dtype='float32',
+            attr=paddle.ParamAttr(initializer=nn.initializer.XavierUniform()), )
+
+    def forward(self, x, lengths=None):
+        # x.shape: (N, C, L)
+        x = self.backbone(x, lengths).squeeze(
+            -1)  # (N, emb_size, 1) -> (N, emb_size)
+        if self.dropout is not None:
+            x = self.dropout(x)
+
+        for fc in self.blocks:
+            x = fc(x)
+
+        logits = F.linear(F.normalize(x), F.normalize(self.weight, axis=0))
+
+        return logits
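
Below is a minimal, self-contained sketch (illustration only, not part of the patch) of how the two new pieces would typically be driven from a training loop. Everything outside the two imported classes is made up for the example: the backbone is a toy stand-in for EcapaTdnn, plain cross-entropy stands in for the LogSoftmaxWrapper loss referenced in the stage6 comment, the data is random, and shapes follow the (N, C, L) convention noted in sid_model.py. The hasattr guard in CyclicLRScheduler.step() appears to exist because paddle's base LRScheduler.__init__ invokes step() before current_step has been set, so the sketch steps the scheduler once after construction and then once per optimizer update (not per epoch).

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.training.sid_model import SpeakerIdetification


class ToyBackbone(nn.Layer):
    """Hypothetical stand-in for EcapaTdnn: maps (N, C, L) features to (N, emb_size, 1)."""

    def __init__(self, input_size=80, emb_size=192):
        super().__init__()
        self.emb_size = emb_size  # SpeakerIdetification reads this attribute
        self.proj = nn.Linear(input_size, emb_size)

    def forward(self, x, lengths=None):
        x = x.mean(axis=-1)                # (N, C, L) -> (N, C), pool over frames
        return self.proj(x).unsqueeze(-1)  # (N, emb_size, 1), as the backbone is expected to return


num_class = 10
model = SpeakerIdetification(backbone=ToyBackbone(), num_class=num_class)

lr_schedule = CyclicLRScheduler(base_lr=1e-8, max_lr=1e-3, step_size=1000)
lr_schedule.step()  # prime last_lr to base_lr; the implicit step() from
                    # LRScheduler.__init__ was skipped by the hasattr guard
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_schedule, parameters=model.parameters())

for step in range(5):
    feats = paddle.randn([8, 80, 300])       # dummy batch: (batch, mel bins, frames)
    labels = paddle.randint(0, num_class, [8])
    logits = model(feats)                    # cosine logits, shape (8, num_class)
    loss = F.cross_entropy(logits, labels)   # stand-in for the LogSoftmaxWrapper loss
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    lr_schedule.step()                       # advance the triangular cycle once per batch
    print(step, float(loss), optimizer.get_lr())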