train process add new voxceleb and rirs dataset, test=doc

pull/1630/head
xiongxinlei 3 years ago
parent 965f486dd5
commit 5b05300e53

@ -136,10 +136,3 @@ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
exit 1 exit 1
fi fi
fi fi

@ -23,13 +23,13 @@ from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode from yacs.config import CfgNode
from paddleaudio.compliance.librosa import melspectrogram from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import batch_pad_right from paddlespeech.vector.io.batch import batch_pad_right
from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.io.dataset import VoxCelebDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
@ -37,6 +37,7 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer from paddlespeech.vector.utils.time import Timer
# from paddleaudio.datasets.voxceleb import VoxCeleb
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -54,8 +55,14 @@ def main(args, config):
# stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code # note: some cmd must do in rank==0, so wo will refactor the data prepare code
train_dataset = VoxCeleb('train', target_dir=args.data_dir) train_dataset = VoxCelebDataset(
dev_dataset = VoxCeleb('dev', target_dir=args.data_dir) csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
spk_id2label_path=os.path.join(args.data_dir,
"vox/meta/spk_id2label.txt"))
dev_dataset = VoxCelebDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
spk_id2label_path=os.path.join(args.data_dir,
"vox/meta/spk_id2label.txt"))
if config.augment: if config.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)

@ -21,13 +21,14 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddleaudio.datasets.rirs_noises import OpenRIRNoise
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.dataset import RIRSNoiseDataset
from paddlespeech.vector.io.signal_processing import compute_amplitude from paddlespeech.vector.io.signal_processing import compute_amplitude
from paddlespeech.vector.io.signal_processing import convolve1d from paddlespeech.vector.io.signal_processing import convolve1d
from paddlespeech.vector.io.signal_processing import dB_to_amplitude from paddlespeech.vector.io.signal_processing import dB_to_amplitude
from paddlespeech.vector.io.signal_processing import notch_filter from paddlespeech.vector.io.signal_processing import notch_filter
from paddlespeech.vector.io.signal_processing import reverberate from paddlespeech.vector.io.signal_processing import reverberate
# from paddleaudio.datasets.rirs_noises import OpenRIRNoise
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -839,8 +840,10 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
List[paddle.nn.Layer]: all augment process List[paddle.nn.Layer]: all augment process
""" """
logger.info("start to build the augment pipeline") logger.info("start to build the augment pipeline")
noise_dataset = OpenRIRNoise('noise', target_dir=target_dir) noise_dataset = RIRSNoiseDataset(csv_path=os.path.join(
rir_dataset = OpenRIRNoise('rir', target_dir=target_dir) target_dir, "rir_noise/csv/noise.csv"))
rir_dataset = OpenRIRNoise(csv_path=os.path.join(target_dir,
"rir_noise/csv/rir.csv"))
wavedrop = TimeDomainSpecAugment( wavedrop = TimeDomainSpecAugment(
sample_rate=16000, sample_rate=16000,

Loading…
Cancel
Save