@@ -23,13 +23,13 @@ from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode

from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import batch_pad_right
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.io.dataset import VoxCelebDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
@@ -37,6 +37,7 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer
# from paddleaudio.datasets.voxceleb import VoxCeleb

logger = Log(__name__).getlog()
@@ -54,8 +55,14 @@ def main(args, config):
    # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
    # note: some cmd must do in rank==0, so wo will refactor the data prepare code
    train_dataset = VoxCeleb('train', target_dir=args.data_dir)
    dev_dataset = VoxCeleb('dev', target_dir=args.data_dir)
    train_dataset = VoxCelebDataset(
        csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
        spk_id2label_path=os.path.join(args.data_dir,
                                       "vox/meta/spk_id2label.txt"))
    dev_dataset = VoxCelebDataset(
        csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
        spk_id2label_path=os.path.join(args.data_dir,
                                       "vox/meta/spk_id2label.txt"))

    if config.augment:
        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
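After this hunk, the script can feed the CSV-backed VoxCelebDataset into a distributed loader. Below is a minimal sketch of that wiring, assuming the vox/csv and vox/meta layout shown in the diff and assuming waveform_collate_fn can be passed directly as the DataLoader collate_fn (the patch only shows its import); data_dir, batch_size, and num_workers are hypothetical placeholders, not values taken from this patch.

import os

from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler

from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.io.dataset import VoxCelebDataset

data_dir = "./data"  # hypothetical stand-in for args.data_dir

# CSV manifest and speaker-id -> label map produced by the data-prepare stage
train_dataset = VoxCelebDataset(
    csv_path=os.path.join(data_dir, "vox/csv/train.csv"),
    spk_id2label_path=os.path.join(data_dir, "vox/meta/spk_id2label.txt"))

# shard batches across ranks and collate raw waveforms for training
train_sampler = DistributedBatchSampler(
    train_dataset, batch_size=32, shuffle=True, drop_last=False)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    collate_fn=waveform_collate_fn,
    num_workers=2,
    return_list=True)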