add sid dataloader for training, test=doc

pull/1523/head
xiongxinlei 3 years ago
parent 6af2bc3d5b
commit 7668f61422

@ -11,16 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import argparse import argparse
import os
import paddle import paddle
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddleaudio.datasets.voxceleb import VoxCeleb1 from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.vector.datasets.batch import waveform_collate_fn
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
from paddlespeech.vector.layers.lr import CyclicLRScheduler from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification from paddlespeech.vector.training.sid_model import SpeakerIdetification
from paddlespeech.vector.layers.loss import AdditiveAngularMargin, LogSoftmaxWrapper
def main(args): def main(args):
# stage0: set the training device, cpu or gpu # stage0: set the training device, cpu or gpu
@ -61,7 +66,6 @@ def main(args):
criterion = LogSoftmaxWrapper( criterion = LogSoftmaxWrapper(
loss_fn=AdditiveAngularMargin(margin=0.2, scale=30)) loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
# stage7: confirm training start epoch # stage7: confirm training start epoch
# if pre-trained model exists, start epoch confirmed by the pre-trained model # if pre-trained model exists, start epoch confirmed by the pre-trained model
start_epoch = 0 start_epoch = 0
@ -89,7 +93,19 @@ def main(args):
print(f'Restore training from epoch {start_epoch}.') print(f'Restore training from epoch {start_epoch}.')
except ValueError: except ValueError:
pass pass
# stage8: we build the batch sampler for paddle.DataLoader
train_sampler = DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
train_loader = DataLoader(
train_ds,
batch_sampler=train_sampler,
num_workers=args.num_workers,
collate_fn=waveform_collate_fn,
return_list=True,
use_buffer_reader=True, )
if __name__ == "__main__": if __name__ == "__main__":
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
@ -105,10 +121,17 @@ if __name__ == "__main__":
type=float, type=float,
default=1e-8, default=1e-8,
help="Learning rate used to train with warmup.") help="Learning rate used to train with warmup.")
parser.add_argument("--load_checkpoint", parser.add_argument("--load_checkpoint",
type=str, type=str,
default=None, default=None,
help="Directory to load model checkpoint to contiune trainning.") help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--batch_size",
type=int, default=64,
help="Total examples' number in batch for training.")
parser.add_argument("--num_workers",
type=int,
default=0,
help="Number of workers in dataloader.")
args = parser.parse_args() args = parser.parse_args()
# yapf: enable # yapf: enable

@ -0,0 +1,20 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def waveform_collate_fn(batch):
waveforms = np.stack([item['feat'] for item in batch])
labels = np.stack([item['label'] for item in batch])
return {'waveforms': waveforms, 'labels': labels}
Loading…
Cancel
Save