|
|
@ -13,9 +13,9 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import time
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle.io import BatchSampler
|
|
|
|
from paddle.io import BatchSampler
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
from paddle.io import DataLoader
|
|
|
@ -27,6 +27,7 @@ from paddleaudio.datasets.voxceleb import VoxCeleb
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
from paddlespeech.vector.io.augment import build_augment_pipeline
|
|
|
|
from paddlespeech.vector.io.augment import build_augment_pipeline
|
|
|
|
from paddlespeech.vector.io.augment import waveform_augment
|
|
|
|
from paddlespeech.vector.io.augment import waveform_augment
|
|
|
|
|
|
|
|
from paddlespeech.vector.io.batch import batch_pad_right
|
|
|
|
from paddlespeech.vector.io.batch import feature_normalize
|
|
|
|
from paddlespeech.vector.io.batch import feature_normalize
|
|
|
|
from paddlespeech.vector.io.batch import waveform_collate_fn
|
|
|
|
from paddlespeech.vector.io.batch import waveform_collate_fn
|
|
|
|
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
|
|
|
|
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
|
|
|
@ -36,7 +37,6 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
|
|
|
|
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
|
|
|
|
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
|
|
|
|
from paddlespeech.vector.training.seeding import seed_everything
|
|
|
|
from paddlespeech.vector.training.seeding import seed_everything
|
|
|
|
from paddlespeech.vector.utils.time import Timer
|
|
|
|
from paddlespeech.vector.utils.time import Timer
|
|
|
|
from paddlespeech.vector.io.batch import batch_pad_right
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
|
|
|
|
|
@ -165,7 +165,8 @@ def main(args, config):
|
|
|
|
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
|
|
|
|
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
|
|
|
|
feats = []
|
|
|
|
feats = []
|
|
|
|
for waveform in waveforms.numpy():
|
|
|
|
for waveform in waveforms.numpy():
|
|
|
|
feat = melspectrogram(x=waveform,
|
|
|
|
feat = melspectrogram(
|
|
|
|
|
|
|
|
x=waveform,
|
|
|
|
sr=config.sr,
|
|
|
|
sr=config.sr,
|
|
|
|
n_mels=config.n_mels,
|
|
|
|
n_mels=config.n_mels,
|
|
|
|
window_size=config.window_size,
|
|
|
|
window_size=config.window_size,
|
|
|
@ -213,9 +214,12 @@ def main(args, config):
|
|
|
|
epoch, config.epochs, batch_idx + 1, steps_per_epoch)
|
|
|
|
epoch, config.epochs, batch_idx + 1, steps_per_epoch)
|
|
|
|
print_msg += ' loss={:.4f}'.format(avg_loss)
|
|
|
|
print_msg += ' loss={:.4f}'.format(avg_loss)
|
|
|
|
print_msg += ' acc={:.4f}'.format(avg_acc)
|
|
|
|
print_msg += ' acc={:.4f}'.format(avg_acc)
|
|
|
|
print_msg += ' avg_reader_cost: {:.5f} sec,'.format(train_reader_cost / config.log_interval)
|
|
|
|
print_msg += ' avg_reader_cost: {:.5f} sec,'.format(
|
|
|
|
print_msg += ' avg_feat_cost: {:.5f} sec,'.format(train_feat_cost / config.log_interval)
|
|
|
|
train_reader_cost / config.log_interval)
|
|
|
|
print_msg += ' avg_train_cost: {:.5f} sec,'.format(train_run_cost / config.log_interval)
|
|
|
|
print_msg += ' avg_feat_cost: {:.5f} sec,'.format(
|
|
|
|
|
|
|
|
train_feat_cost / config.log_interval)
|
|
|
|
|
|
|
|
print_msg += ' avg_train_cost: {:.5f} sec,'.format(
|
|
|
|
|
|
|
|
train_run_cost / config.log_interval)
|
|
|
|
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
|
|
|
|
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
|
|
|
|
lr, timer.timing, timer.eta)
|
|
|
|
lr, timer.timing, timer.eta)
|
|
|
|
logger.info(print_msg)
|
|
|
|
logger.info(print_msg)
|
|
|
@ -262,7 +266,8 @@ def main(args, config):
|
|
|
|
|
|
|
|
|
|
|
|
feats = []
|
|
|
|
feats = []
|
|
|
|
for waveform in waveforms.numpy():
|
|
|
|
for waveform in waveforms.numpy():
|
|
|
|
feat = melspectrogram(x=waveform,
|
|
|
|
feat = melspectrogram(
|
|
|
|
|
|
|
|
x=waveform,
|
|
|
|
sr=config.sr,
|
|
|
|
sr=config.sr,
|
|
|
|
n_mels=config.n_mels,
|
|
|
|
n_mels=config.n_mels,
|
|
|
|
window_size=config.window_size,
|
|
|
|
window_size=config.window_size,
|
|
|
@ -285,7 +290,8 @@ def main(args, config):
|
|
|
|
# stage 9-14: Save model parameters
|
|
|
|
# stage 9-14: Save model parameters
|
|
|
|
save_dir = os.path.join(args.checkpoint_dir,
|
|
|
|
save_dir = os.path.join(args.checkpoint_dir,
|
|
|
|
'epoch_{}'.format(epoch))
|
|
|
|
'epoch_{}'.format(epoch))
|
|
|
|
last_saved_epoch = os.path.join('epoch_{}'.format(epoch), "model.pdparams")
|
|
|
|
last_saved_epoch = os.path.join('epoch_{}'.format(epoch),
|
|
|
|
|
|
|
|
"model.pdparams")
|
|
|
|
logger.info('Saving model checkpoint to {}'.format(save_dir))
|
|
|
|
logger.info('Saving model checkpoint to {}'.format(save_dir))
|
|
|
|
paddle.save(model.state_dict(),
|
|
|
|
paddle.save(model.state_dict(),
|
|
|
|
os.path.join(save_dir, 'model.pdparams'))
|
|
|
|
os.path.join(save_dir, 'model.pdparams'))
|
|
|
@ -300,10 +306,13 @@ def main(args, config):
|
|
|
|
final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
|
|
|
|
final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
|
|
|
|
logger.info(f"we will create the final model: {final_model}")
|
|
|
|
logger.info(f"we will create the final model: {final_model}")
|
|
|
|
if os.path.islink(final_model):
|
|
|
|
if os.path.islink(final_model):
|
|
|
|
logger.info(f"An {final_model} already exists, we will rm is and create it again")
|
|
|
|
logger.info(
|
|
|
|
|
|
|
|
f"An {final_model} already exists, we will rm is and create it again"
|
|
|
|
|
|
|
|
)
|
|
|
|
os.unlink(final_model)
|
|
|
|
os.unlink(final_model)
|
|
|
|
os.symlink(last_saved_epoch, final_model)
|
|
|
|
os.symlink(last_saved_epoch, final_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
# yapf: disable
|
|
|
|
# yapf: disable
|
|
|
|
parser = argparse.ArgumentParser(__doc__)
|
|
|
|
parser = argparse.ArgumentParser(__doc__)
|
|
|
|