exec pre-commit in paddlespeech vector, test=doc

pull/1523/head
xiongxinlei 3 years ago
parent 9874fb7d75
commit d85d1deef5

@ -13,9 +13,8 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os import os
import time import time
import numpy as np
import paddle import paddle
from yacs.config import CfgNode from yacs.config import CfgNode
@ -40,7 +39,8 @@ def extract_audio_embedding(args, config):
ecapa_tdnn = EcapaTdnn(**config.model) ecapa_tdnn = EcapaTdnn(**config.model)
# stage4: build the speaker verification train instance with backbone model # stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=config.num_speakers) model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=config.num_speakers)
# stage 2: load the pre-trained model # stage 2: load the pre-trained model
args.load_checkpoint = os.path.abspath( args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint)) os.path.expanduser(args.load_checkpoint))
@ -62,17 +62,17 @@ def extract_audio_embedding(args, config):
# we need convert the audio feat to one-batch shape [batch, dim, time], where the batch is one # we need convert the audio feat to one-batch shape [batch, dim, time], where the batch is one
# so the final shape is [1, dim, time] # so the final shape is [1, dim, time]
start_time = time.time() start_time = time.time()
feat = melspectrogram(x=waveform, feat = melspectrogram(
sr=config.sr, x=waveform,
n_mels=config.n_mels, sr=config.sr,
window_size=config.window_size, n_mels=config.n_mels,
hop_length=config.hop_size) window_size=config.window_size,
hop_length=config.hop_size)
feat = paddle.to_tensor(feat).unsqueeze(0) feat = paddle.to_tensor(feat).unsqueeze(0)
# in inference period, the lengths is all one without padding # in inference period, the lengths is all one without padding
lengths = paddle.ones([1]) lengths = paddle.ones([1])
feat = feature_normalize( feat = feature_normalize(feat, mean_norm=True, std_norm=False)
feat, mean_norm=True, std_norm=False)
# model backbone network forward the feats and get the embedding # model backbone network forward the feats and get the embedding
embedding = model.backbone( embedding = model.backbone(
@ -80,7 +80,6 @@ def extract_audio_embedding(args, config):
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
audio_length = waveform.shape[0] / sr audio_length = waveform.shape[0] / sr
# stage 5: do global norm with external mean and std # stage 5: do global norm with external mean and std
rtf = elapsed_time / audio_length rtf = elapsed_time / audio_length
logger.info(f"{args.device} rft={rtf}") logger.info(f"{args.device} rft={rtf}")

@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os import os
import time
import numpy as np import numpy as np
import time
import paddle import paddle
from paddle.io import BatchSampler from paddle.io import BatchSampler
from paddle.io import DataLoader from paddle.io import DataLoader
@ -27,6 +27,7 @@ from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import batch_pad_right
from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
@ -36,7 +37,6 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer from paddlespeech.vector.utils.time import Timer
from paddlespeech.vector.io.batch import batch_pad_right
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -165,11 +165,12 @@ def main(args, config):
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram # stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats = [] feats = []
for waveform in waveforms.numpy(): for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform, feat = melspectrogram(
sr=config.sr, x=waveform,
n_mels=config.n_mels, sr=config.sr,
window_size=config.window_size, n_mels=config.n_mels,
hop_length=config.hop_size) window_size=config.window_size,
hop_length=config.hop_size)
feats.append(feat) feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats)) feats = paddle.to_tensor(np.asarray(feats))
@ -213,9 +214,12 @@ def main(args, config):
epoch, config.epochs, batch_idx + 1, steps_per_epoch) epoch, config.epochs, batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss) print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc) print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' avg_reader_cost: {:.5f} sec,'.format(train_reader_cost / config.log_interval) print_msg += ' avg_reader_cost: {:.5f} sec,'.format(
print_msg += ' avg_feat_cost: {:.5f} sec,'.format(train_feat_cost / config.log_interval) train_reader_cost / config.log_interval)
print_msg += ' avg_train_cost: {:.5f} sec,'.format(train_run_cost / config.log_interval) print_msg += ' avg_feat_cost: {:.5f} sec,'.format(
train_feat_cost / config.log_interval)
print_msg += ' avg_train_cost: {:.5f} sec,'.format(
train_run_cost / config.log_interval)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format( print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta) lr, timer.timing, timer.eta)
logger.info(print_msg) logger.info(print_msg)
@ -262,11 +266,12 @@ def main(args, config):
feats = [] feats = []
for waveform in waveforms.numpy(): for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform, feat = melspectrogram(
sr=config.sr, x=waveform,
n_mels=config.n_mels, sr=config.sr,
window_size=config.window_size, n_mels=config.n_mels,
hop_length=config.hop_size) window_size=config.window_size,
hop_length=config.hop_size)
feats.append(feat) feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats)) feats = paddle.to_tensor(np.asarray(feats))
@ -285,7 +290,8 @@ def main(args, config):
# stage 9-14: Save model parameters # stage 9-14: Save model parameters
save_dir = os.path.join(args.checkpoint_dir, save_dir = os.path.join(args.checkpoint_dir,
'epoch_{}'.format(epoch)) 'epoch_{}'.format(epoch))
last_saved_epoch = os.path.join('epoch_{}'.format(epoch), "model.pdparams") last_saved_epoch = os.path.join('epoch_{}'.format(epoch),
"model.pdparams")
logger.info('Saving model checkpoint to {}'.format(save_dir)) logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(), paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams')) os.path.join(save_dir, 'model.pdparams'))
@ -300,10 +306,13 @@ def main(args, config):
final_model = os.path.join(args.checkpoint_dir, "model.pdparams") final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
logger.info(f"we will create the final model: {final_model}") logger.info(f"we will create the final model: {final_model}")
if os.path.islink(final_model): if os.path.islink(final_model):
logger.info(f"An {final_model} already exists, we will rm is and create it again") logger.info(
f"An {final_model} already exists, we will rm is and create it again"
)
os.unlink(final_model) os.unlink(final_model)
os.symlink(last_saved_epoch, final_model) os.symlink(last_saved_epoch, final_model)
if __name__ == "__main__": if __name__ == "__main__":
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)

Loading…
Cancel
Save