add sid dataloader for training, test=doc

pull/1523/head
xiongxinlei 3 years ago
parent 6af2bc3d5b
commit 7668f61422

@ -11,16 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import os
import paddle
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.vector.datasets.batch import waveform_collate_fn
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification
from paddlespeech.vector.layers.loss import AdditiveAngularMargin, LogSoftmaxWrapper
def main(args):
# stage0: set the training device, cpu or gpu
@ -61,7 +66,6 @@ def main(args):
criterion = LogSoftmaxWrapper(
loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
# stage7: confirm training start epoch
# if pre-trained model exists, start epoch confirmed by the pre-trained model
start_epoch = 0
@ -90,6 +94,18 @@ def main(args):
except ValueError:
pass
# stage8: we build the batch sampler for paddle.DataLoader
train_sampler = DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
train_loader = DataLoader(
train_ds,
batch_sampler=train_sampler,
num_workers=args.num_workers,
collate_fn=waveform_collate_fn,
return_list=True,
use_buffer_reader=True, )
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
@ -109,6 +125,13 @@ if __name__ == "__main__":
type=str,
default=None,
help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--batch_size",
type=int, default=64,
help="Total examples' number in batch for training.")
parser.add_argument("--num_workers",
type=int,
default=0,
help="Number of workers in dataloader.")
args = parser.parse_args()
# yapf: enable

@ -0,0 +1,20 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def waveform_collate_fn(batch):
waveforms = np.stack([item['feat'] for item in batch])
labels = np.stack([item['label'] for item in batch])
return {'waveforms': waveforms, 'labels': labels}
Loading…
Cancel
Save