|
|
@ -14,11 +14,15 @@
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
from paddle.io import DistributedBatchSampler
|
|
|
|
from paddle.io import DistributedBatchSampler
|
|
|
|
|
|
|
|
|
|
|
|
from paddleaudio.datasets.voxceleb import VoxCeleb1
|
|
|
|
from paddleaudio.datasets.voxceleb import VoxCeleb1
|
|
|
|
|
|
|
|
from paddleaudio.features.core import melspectrogram
|
|
|
|
|
|
|
|
from paddleaudio.utils.time import Timer
|
|
|
|
|
|
|
|
from paddlespeech.vector.datasets.batch import feature_normalize
|
|
|
|
from paddlespeech.vector.datasets.batch import waveform_collate_fn
|
|
|
|
from paddlespeech.vector.datasets.batch import waveform_collate_fn
|
|
|
|
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
|
|
|
|
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
|
|
|
|
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
|
|
|
|
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
|
|
|
@ -26,6 +30,13 @@ from paddlespeech.vector.layers.lr import CyclicLRScheduler
|
|
|
|
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
|
|
|
|
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
|
|
|
|
from paddlespeech.vector.training.sid_model import SpeakerIdetification
|
|
|
|
from paddlespeech.vector.training.sid_model import SpeakerIdetification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Feature configuration: keyword arguments passed to melspectrogram()
|
|
|
|
|
|
|
|
cpu_feat_conf = {
|
|
|
|
|
|
|
|
'n_mels': 80,
|
|
|
|
|
|
|
|
'window_size': 400,
|
|
|
|
|
|
|
|
'hop_length': 160,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(args):
|
|
|
|
def main(args):
|
|
|
|
# stage0: set the training device, cpu or gpu
|
|
|
|
# stage0: set the training device, cpu or gpu
|
|
|
@ -42,9 +53,10 @@ def main(args):
|
|
|
|
dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)
|
|
|
|
dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)
|
|
|
|
|
|
|
|
|
|
|
|
# stage3: build the dnn backbone model network
|
|
|
|
# stage3: build the dnn backbone model network
|
|
|
|
|
|
|
|
#"channels": [1024, 1024, 1024, 1024, 3072],
|
|
|
|
model_conf = {
|
|
|
|
model_conf = {
|
|
|
|
"input_size": 80,
|
|
|
|
"input_size": 80,
|
|
|
|
"channels": [1024, 1024, 1024, 1024, 3072],
|
|
|
|
"channels": [512, 512, 512, 512, 1536],
|
|
|
|
"kernel_sizes": [5, 3, 3, 3, 1],
|
|
|
|
"kernel_sizes": [5, 3, 3, 3, 1],
|
|
|
|
"dilations": [1, 2, 3, 4, 1],
|
|
|
|
"dilations": [1, 2, 3, 4, 1],
|
|
|
|
"attention_channels": 128,
|
|
|
|
"attention_channels": 128,
|
|
|
@ -105,6 +117,122 @@ def main(args):
|
|
|
|
return_list=True,
|
|
|
|
return_list=True,
|
|
|
|
use_buffer_reader=True, )
|
|
|
|
use_buffer_reader=True, )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# stage9: start to train
|
|
|
|
|
|
|
|
# we will comment the training process
|
|
|
|
|
|
|
|
steps_per_epoch = len(train_sampler)
|
|
|
|
|
|
|
|
timer = Timer(steps_per_epoch * args.epochs)
|
|
|
|
|
|
|
|
timer.start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for epoch in range(start_epoch + 1, args.epochs + 1):
|
|
|
|
|
|
|
|
# at the beginning of each epoch, the model must be set to train mode
|
|
|
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
avg_loss = 0
|
|
|
|
|
|
|
|
num_corrects = 0
|
|
|
|
|
|
|
|
num_samples = 0
|
|
|
|
|
|
|
|
for batch_idx, batch in enumerate(train_loader):
|
|
|
|
|
|
|
|
waveforms, labels = batch['waveforms'], batch['labels']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
feats = []
|
|
|
|
|
|
|
|
for waveform in waveforms.numpy():
|
|
|
|
|
|
|
|
feat = melspectrogram(x=waveform, **cpu_feat_conf)
|
|
|
|
|
|
|
|
feats.append(feat)
|
|
|
|
|
|
|
|
feats = paddle.to_tensor(np.asarray(feats))
|
|
|
|
|
|
|
|
feats = feature_normalize(
|
|
|
|
|
|
|
|
feats, mean_norm=True, std_norm=False) # Features normalization
|
|
|
|
|
|
|
|
logits = model(feats)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss = criterion(logits, labels)
|
|
|
|
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
optimizer.step()
|
|
|
|
|
|
|
|
if isinstance(optimizer._learning_rate,
|
|
|
|
|
|
|
|
paddle.optimizer.lr.LRScheduler):
|
|
|
|
|
|
|
|
optimizer._learning_rate.step()
|
|
|
|
|
|
|
|
optimizer.clear_grad()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Calculate loss
|
|
|
|
|
|
|
|
avg_loss += loss.numpy()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Calculate metrics
|
|
|
|
|
|
|
|
preds = paddle.argmax(logits, axis=1)
|
|
|
|
|
|
|
|
num_corrects += (preds == labels).numpy().sum()
|
|
|
|
|
|
|
|
num_samples += feats.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
timer.count()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
|
|
|
|
|
|
|
|
lr = optimizer.get_lr()
|
|
|
|
|
|
|
|
avg_loss /= args.log_freq
|
|
|
|
|
|
|
|
avg_acc = num_corrects / num_samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print_msg = 'Epoch={}/{}, Step={}/{}'.format(
|
|
|
|
|
|
|
|
epoch, args.epochs, batch_idx + 1, steps_per_epoch)
|
|
|
|
|
|
|
|
print_msg += ' loss={:.4f}'.format(avg_loss)
|
|
|
|
|
|
|
|
print_msg += ' acc={:.4f}'.format(avg_acc)
|
|
|
|
|
|
|
|
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
|
|
|
|
|
|
|
|
lr, timer.timing, timer.eta)
|
|
|
|
|
|
|
|
print(print_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
avg_loss = 0
|
|
|
|
|
|
|
|
num_corrects = 0
|
|
|
|
|
|
|
|
num_samples = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
|
|
|
|
|
|
|
|
if local_rank != 0:
|
|
|
|
|
|
|
|
paddle.distributed.barrier(
|
|
|
|
|
|
|
|
) # Wait for valid step in main process
|
|
|
|
|
|
|
|
continue # Resume training on the other processes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev_sampler = paddle.io.BatchSampler(
|
|
|
|
|
|
|
|
dev_ds,
|
|
|
|
|
|
|
|
batch_size=args.batch_size // 4,
|
|
|
|
|
|
|
|
shuffle=False,
|
|
|
|
|
|
|
|
drop_last=False)
|
|
|
|
|
|
|
|
dev_loader = paddle.io.DataLoader(
|
|
|
|
|
|
|
|
dev_ds,
|
|
|
|
|
|
|
|
batch_sampler=dev_sampler,
|
|
|
|
|
|
|
|
collate_fn=waveform_collate_fn,
|
|
|
|
|
|
|
|
num_workers=args.num_workers,
|
|
|
|
|
|
|
|
return_list=True, )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
num_corrects = 0
|
|
|
|
|
|
|
|
num_samples = 0
|
|
|
|
|
|
|
|
print('Evaluate on validation dataset')
|
|
|
|
|
|
|
|
with paddle.no_grad():
|
|
|
|
|
|
|
|
for batch_idx, batch in enumerate(dev_loader):
|
|
|
|
|
|
|
|
waveforms, labels = batch['waveforms'], batch['labels']
|
|
|
|
|
|
|
|
# feats = feature_extractor(waveforms)
|
|
|
|
|
|
|
|
feats = []
|
|
|
|
|
|
|
|
for waveform in waveforms.numpy():
|
|
|
|
|
|
|
|
feat = melspectrogram(x=waveform, **cpu_feat_conf)
|
|
|
|
|
|
|
|
feats.append(feat)
|
|
|
|
|
|
|
|
feats = paddle.to_tensor(np.asarray(feats))
|
|
|
|
|
|
|
|
feats = feature_normalize(
|
|
|
|
|
|
|
|
feats, mean_norm=True, std_norm=False)
|
|
|
|
|
|
|
|
logits = model(feats)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
preds = paddle.argmax(logits, axis=1)
|
|
|
|
|
|
|
|
num_corrects += (preds == labels).numpy().sum()
|
|
|
|
|
|
|
|
num_samples += feats.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print_msg = '[Evaluation result]'
|
|
|
|
|
|
|
|
print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(print_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Save model
|
|
|
|
|
|
|
|
save_dir = os.path.join(args.checkpoint_dir,
|
|
|
|
|
|
|
|
'epoch_{}'.format(epoch))
|
|
|
|
|
|
|
|
print('Saving model checkpoint to {}'.format(save_dir))
|
|
|
|
|
|
|
|
paddle.save(model.state_dict(),
|
|
|
|
|
|
|
|
os.path.join(save_dir, 'model.pdparams'))
|
|
|
|
|
|
|
|
paddle.save(optimizer.state_dict(),
|
|
|
|
|
|
|
|
os.path.join(save_dir, 'model.pdopt'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if nranks > 1:
|
|
|
|
|
|
|
|
paddle.distributed.barrier() # Main process
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
# yapf: disable
|
|
|
|
# yapf: disable
|
|
|
@ -117,21 +245,29 @@ if __name__ == "__main__":
|
|
|
|
default="./data/",
|
|
|
|
default="./data/",
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
help="data directory")
|
|
|
|
help="data directory")
|
|
|
|
parser.add_argument("--learning_rate",
|
|
|
|
parser.add_argument("--learning-rate",
|
|
|
|
type=float,
|
|
|
|
type=float,
|
|
|
|
default=1e-8,
|
|
|
|
default=1e-8,
|
|
|
|
help="Learning rate used to train with warmup.")
|
|
|
|
help="Learning rate used to train with warmup.")
|
|
|
|
parser.add_argument("--load_checkpoint",
|
|
|
|
parser.add_argument("--load-checkpoint",
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default=None,
|
|
|
|
default=None,
|
|
|
|
help="Directory to load model checkpoint to contiune trainning.")
|
|
|
|
help="Directory to load model checkpoint to contiune trainning.")
|
|
|
|
parser.add_argument("--batch_size",
|
|
|
|
parser.add_argument("--batch-size",
|
|
|
|
type=int, default=64,
|
|
|
|
type=int, default=64,
|
|
|
|
help="Total examples' number in batch for training.")
|
|
|
|
help="Total examples' number in batch for training.")
|
|
|
|
parser.add_argument("--num_workers",
|
|
|
|
parser.add_argument("--num-workers",
|
|
|
|
type=int,
|
|
|
|
type=int,
|
|
|
|
default=0,
|
|
|
|
default=0,
|
|
|
|
help="Number of workers in dataloader.")
|
|
|
|
help="Number of workers in dataloader.")
|
|
|
|
|
|
|
|
parser.add_argument("--epochs",
|
|
|
|
|
|
|
|
type=int,
|
|
|
|
|
|
|
|
default=50,
|
|
|
|
|
|
|
|
help="Number of epoches for fine-tuning.")
|
|
|
|
|
|
|
|
parser.add_argument("--log_freq",
|
|
|
|
|
|
|
|
type=int,
|
|
|
|
|
|
|
|
default=10,
|
|
|
|
|
|
|
|
help="Log the training infomation every n steps.")
|
|
|
|
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
# yapf: enable
|
|
|
|
# yapf: enable
|
|
|
|