add training process for sid, test=doc

pull/1523/head
xiongxinlei 2 years ago
parent 7668f61422
commit 4648059b5f

@@ -14,11 +14,15 @@
import argparse
import os
import numpy as np
import paddle
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddleaudio.features.core import melspectrogram
from paddleaudio.utils.time import Timer
from paddlespeech.vector.datasets.batch import feature_normalize
from paddlespeech.vector.datasets.batch import waveform_collate_fn
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
@@ -26,6 +30,13 @@ from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification
# feat configuration
cpu_feat_conf = {
'n_mels': 80,
'window_size': 400,
'hop_length': 160,
}
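# At the 16 kHz sample rate VoxCeleb uses, this is a 25 ms analysis window
# (400 samples) with a 10 ms hop (160 samples), i.e. roughly 100 frames per
# second. A minimal sketch of how the conf is consumed below (the 1-second
# random waveform is just a placeholder):
#
#     waveform = np.random.randn(16000).astype('float32')
#     feat = melspectrogram(x=waveform, **cpu_feat_conf)  # (n_mels, frames)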
def main(args):
# stage0: set the training device, cpu or gpu
@@ -42,9 +53,10 @@ def main(args):
dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)
# stage3: build the dnn backbone model network
#"channels": [1024, 1024, 1024, 1024, 3072],
model_conf = {
"input_size": 80,
"channels": [1024, 1024, 1024, 1024, 3072],
"channels": [512, 512, 512, 512, 1536],
"kernel_sizes": [5, 3, 3, 3, 1],
"dilations": [1, 2, 3, 4, 1],
"attention_channels": 128,
@@ -105,6 +117,122 @@ def main(args):
return_list=True,
use_buffer_reader=True, )
# stage9: start to train
# the training process below is annotated with comments step by step
steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * args.epochs)
timer.start()
for epoch in range(start_epoch + 1, args.epochs + 1):
# at the beginning of each epoch, the model must be set to train mode
model.train()
avg_loss = 0
num_corrects = 0
num_samples = 0
for batch_idx, batch in enumerate(train_loader):
waveforms, labels = batch['waveforms'], batch['labels']
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform, **cpu_feat_conf)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
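# feats now has shape (batch_size, n_mels, num_frames); extraction runs on
# the CPU one utterance at a time, hence the name cpu_feat_conf above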
feats = feature_normalize(
feats, mean_norm=True, std_norm=False) # Features normalization
logits = model(feats)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
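# the lr scheduler is stepped once per batch rather than per epoch, since
# CyclicLRScheduler (imported above) varies the rate across individual steps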
if isinstance(optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()
# Calculate loss
avg_loss += loss.numpy()[0]
# Calculate metrics
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
timer.count()
if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
lr = optimizer.get_lr()
avg_loss /= args.log_freq
avg_acc = num_corrects / num_samples
print_msg = 'Epoch={}/{}, Step={}/{}'.format(
epoch, args.epochs, batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
print(print_msg)
avg_loss = 0
num_corrects = 0
num_samples = 0
if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
if local_rank != 0:
paddle.distributed.barrier()  # Wait for the validation step in the main process
continue  # Resume training on the other processes
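# only the main process (rank 0) runs the evaluation and checkpointing
# below; the other ranks block on the barrier above and resume once rank 0
# reaches the matching barrier at the end of this block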
dev_sampler = paddle.io.BatchSampler(
dev_ds,
batch_size=args.batch_size // 4,
shuffle=False,
drop_last=False)
dev_loader = paddle.io.DataLoader(
dev_ds,
batch_sampler=dev_sampler,
collate_fn=waveform_collate_fn,
num_workers=args.num_workers,
return_list=True, )
model.eval()
num_corrects = 0
num_samples = 0
print('Evaluate on validation dataset')
with paddle.no_grad():
for batch_idx, batch in enumerate(dev_loader):
waveforms, labels = batch['waveforms'], batch['labels']
# feats = feature_extractor(waveforms)
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform, **cpu_feat_conf)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
feats = feature_normalize(
feats, mean_norm=True, std_norm=False)
logits = model(feats)
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
print_msg = '[Evaluation result]'
print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
print(print_msg)
# Save model
save_dir = os.path.join(args.checkpoint_dir,
'epoch_{}'.format(epoch))
print('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(save_dir, 'model.pdopt'))
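# model.pdparams holds the network weights and model.pdopt the optimizer
# state; both are needed to resume training via --load-checkpoint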
if nranks > 1:
paddle.distributed.barrier()  # main process reaches the barrier so the waiting ranks resume
if __name__ == "__main__":
# yapf: disable
@@ -117,21 +245,29 @@ if __name__ == "__main__":
default="./data/",
type=str,
help="data directory")
parser.add_argument("--learning_rate",
parser.add_argument("--learning-rate",
type=float,
default=1e-8,
help="Learning rate used to train with warmup.")
parser.add_argument("--load_checkpoint",
parser.add_argument("--load-checkpoint",
type=str,
default=None,
help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--batch_size",
parser.add_argument("--batch-size",
type=int, default=64,
help="Total examples' number in batch for training.")
parser.add_argument("--num_workers",
parser.add_argument("--num-workers",
type=int,
default=0,
help="Number of workers in dataloader.")
parser.add_argument("--epochs",
type=int,
default=50,
help="Number of epoches for fine-tuning.")
parser.add_argument("--log_freq",
type=int,
default=10,
help="Log the training infomation every n steps.")
args = parser.parse_args()
# yapf: enable
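With the flags renamed above, a typical invocation would look something like this (the script name train.py is an assumption; only flags visible in the parser are shown):

python train.py --learning-rate 1e-8 --batch-size 64 --num-workers 2 --epochs 50 --log_freq 10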

@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
def waveform_collate_fn(batch):
@@ -18,3 +20,14 @@ def waveform_collate_fn(batch):
labels = np.stack([item['label'] for item in batch])
return {'waveforms': waveforms, 'labels': labels}
def feature_normalize(feats: paddle.Tensor,
mean_norm: bool=True,
std_norm: bool=True):
# Features normalization if needed
mean = feats.mean(axis=-1, keepdim=True) if mean_norm else 0
std = feats.std(axis=-1, keepdim=True) if std_norm else 1
feats = (feats - mean) / std
return feats
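A quick sketch of what feature_normalize does to a batch of spectrograms (the shapes are illustrative):

import paddle
feats = paddle.randn([4, 80, 101])  # (batch, n_mels, num_frames)
normed = feature_normalize(feats, mean_norm=True, std_norm=False)
# subtracts the per-utterance, per-bin mean along the time axis, so
# normed.mean(axis=-1) is ~0 while the scale is left untouched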
