check extract embedding result, test=doc

pull/1523/head
xiongxinlei 2 years ago
parent 386ef3f161
commit 14efbf5b15

@@ -22,11 +22,11 @@ from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
+from paddleaudio.backends import load as load_audio
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddleaudio.features.core import melspectrogram
-from paddleaudio.backends import load as load_audio
-from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.metrics import compute_eer
@@ -41,6 +41,7 @@ cpu_feat_conf = {
'hop_length': 160, # 160 samples = 10 ms at 16 kHz
}
def extract_audio_embedding(args):
# stage 0: set the training device, cpu or gpu
paddle.set_device(args.device)
@@ -59,6 +60,8 @@ def extract_audio_embedding(args):
}
ecapa_tdnn = EcapaTdnn(**model_conf)
+# stage 4: build the speaker verification training instance with the backbone model
+model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1211)
# stage 2: load the pre-trained model
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
@@ -71,18 +74,29 @@ def extract_audio_embedding(args):
# stage 3: we must set the model to eval mode
model.eval()
# stage 4: read the audio data and extract the embedding
# waveform is a one-dimensional numpy array
waveform, sr = load_audio(args.audio_path)
# the feat is a numpy array whose shape is [dim, time]
# we need to convert it to the one-batch shape [batch, dim, time], where batch is 1
# so the final shape is [1, dim, time]
feat = melspectrogram(x=waveform, **cpu_feat_conf)
feat = paddle.to_tensor(feat).unsqueeze(0)
-lengths = paddle.ones([1]) # in paddle inference model, the lengths is all one without padding
-feat = feature_normalize(feat, mean_norm=True, std_norm=False)
-embedding = ecapa_tdnn(feat, lengths
-).squeeze().numpy() # (1, emb_size, 1) -> (emb_size)
+# at inference time the lengths are all ones, since there is no padding
+lengths = paddle.ones([1])
+feat = feature_normalize(
+    feat, mean_norm=True, std_norm=False, convert_to_numpy=True)
+# forward the feats through the model backbone to get the embedding
+embedding = model.backbone(
+    feat, lengths).squeeze().numpy() # (1, emb_size, 1) -> (emb_size)
# stage 5: do global norm with external mean and std
# todo
# np.save("audio-embedding", embedding)
return embedding

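The embedding returned above is consumed downstream by cosine scoring over enroll/test trial pairs, as local/speaker_verification_cosine.py does. A minimal numpy-only sketch of that comparison (the threshold value is illustrative; in practice it is tuned from the EER):

    import numpy as np

    def cosine_score(emb1: np.ndarray, emb2: np.ndarray) -> float:
        # cosine similarity between two (emb_size,) speaker embeddings
        return float(np.dot(emb1, emb2) /
                     (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-12))

    # emb_a and emb_b would come from extract_audio_embedding on two utterances
    # same_speaker = cosine_score(emb_a, emb_b) > 0.6  # illustrative threshold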
@@ -120,7 +120,7 @@ def main(args):
**cpu_feat_conf)
enrol_sampler = BatchSampler(
enrol_ds, batch_size=args.batch_size,
-shuffle=False) # Shuffle to make embedding normalization more robust.
+shuffle=True) # Shuffle to make embedding normalization more robust.
enrol_loader = DataLoader(enrol_ds,
batch_sampler=enrol_sampler,
collate_fn=lambda x: feature_normalize(
@@ -136,7 +136,7 @@ def main(args):
**cpu_feat_conf)
test_sampler = BatchSampler(
-test_ds, batch_size=args.batch_size, shuffle=False)
+test_ds, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_ds,
batch_sampler=test_sampler,
collate_fn=lambda x: feature_normalize(

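With the trial scores collected from the loaders above, the EER comes from the compute_eer imported in the first hunk. A hedged sketch (the (labels, scores) argument order is assumed; the arrays are illustrative):

    import numpy as np
    from paddlespeech.vector.training.metrics import compute_eer

    labels = np.array([1, 0, 1, 0])               # 1 = same-speaker trial
    scores = np.array([0.82, 0.31, 0.74, 0.45])   # cosine score per trial
    eer, threshold = compute_eer(labels, scores)  # equal error rate and its threshold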
@@ -56,10 +56,10 @@ def main(args):
# set the random seed, it is a must for multiprocess training
seed_everything(args.seed)
-# stage 2: data prepare, such as vox1 and vox2 data, and the augment data and pipeline
+# stage 2: data prepare, such as vox1 and vox2 data, and the augment noise data and pipeline
# note: some commands must run on rank==0, so we will refactor the data prepare code
-train_ds = VoxCeleb1('train', target_dir=args.data_dir)
-dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)
+train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
+dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)
if args.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
@@ -123,9 +123,9 @@ def main(args):
# stage8: we build the batch sampler for paddle.DataLoader
train_sampler = DistributedBatchSampler(
-train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
+train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
train_loader = DataLoader(
-train_ds,
+train_dataset,
batch_sampler=train_sampler,
num_workers=args.num_workers,
collate_fn=waveform_collate_fn,
@@ -216,12 +216,12 @@ def main(args):
# stage 9-12: construct the valid dataset dataloader
dev_sampler = BatchSampler(
-dev_ds,
+dev_dataset,
batch_size=args.batch_size // 4,
shuffle=False,
drop_last=False)
dev_loader = DataLoader(
-dev_ds,
+dev_dataset,
batch_sampler=dev_sampler,
collate_fn=waveform_collate_fn,
num_workers=args.num_workers,

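Both loaders pass waveform_collate_fn, which stacks the per-utterance waveforms and labels into batch arrays. A hedged sketch of what such a collate function does (the actual implementation in paddlespeech.vector.io.batch may differ in detail):

    import numpy as np

    def waveform_collate_fn_sketch(batch):
        # batch: list of {'feat': 1-D numpy waveform, 'label': int} dicts
        waveforms = np.stack([item['feat'] for item in batch])  # fixed-length chunks
        labels = np.array([item['label'] for item in batch])
        return {'waveforms': waveforms, 'labels': labels}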
@@ -3,6 +3,8 @@
set -e
#######################################################################
# stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv
+# voxceleb2 data is in m4a format, so you need to convert the m4a files to wav yourself, as described in README.md
# stage 1: train the speaker identification model
# stage 2: test speaker verification with cosine scoring
# stage 3: extract the training embeddings to train the LDA and PLDA
@@ -12,23 +14,42 @@ set -e
# by default the dataset is stored under ~/.paddleaudio/
# export PPAUDIO_HOME=
-stage=2
-dir=data/ # data directory
-exp_dir=exp/ecapa-tdnn/ # experiment directory
+stage=0
+dir=data.bak/ # data directory
+exp_dir=exp/ecapa-tdnn/ # experiment directory
mkdir -p ${dir}
mkdir -p ${exp_dir}
# if [ $stage -le 0 ]; then
# # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# # todo
# fi
if [ $stage -le 1 ]; then
# stage 1: train the speaker identification model
python3 \
-m paddle.distributed.launch --gpus=0,1,2,3 \
-local/train.py --device "gpu" --checkpoint-dir ${exp_dir} \
---save-freq 10 --data-dir ${dir} --batch-size 256 --epochs 60
+local/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
+--save-freq 10 --data-dir ${dir} --batch-size 64 --epochs 100
fi
if [ $stage -le 2 ]; then
# stage 2: test the speaker verification model with cosine scoring
# you can set the variable PPAUDIO_HOME to specify where the vox1 and vox2 datasets are stored
python3 \
local/speaker_verification_cosine.py \
--batch-size 4 --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/
fi
if [ $stage -le 3 ]; then
# stage 3: extract a speaker embedding from a single audio file
# you can set the variable PPAUDIO_HOME to specify where the vox1 and vox2 datasets are stored
python3 \
-local/speaker_verification_cosine.py \
---load-checkpoint ${exp_dir}/epoch_40/
+local/extract_speaker_embedding.py \
+--audio-path "demo/csv/00001.wav" --load-checkpoint ${exp_dir}/epoch_60/
fi
# if [ $stage -le 3 ]; then
# # stage 3: extract the training embeddings to train the LDA and PLDA
# # todo: extract the training embedding
# fi

@@ -28,7 +28,7 @@ from paddleaudio.backends import load as load_audio
from paddleaudio.datasets.dataset import feat_funcs
from paddleaudio.utils import DATA_HOME
from paddleaudio.utils import decompress
-from paddleaudio.utils import download_and_decompress
+from paddlespeech.vector.utils.download import download_and_decompress
from paddlespeech.s2t.utils.log import Log
from utils.utility import download
from utils.utility import unpack
@@ -106,13 +106,14 @@ class VoxCeleb1(Dataset):
self.chunk_duration = chunk_duration
self.split_ratio = split_ratio
self.target_dir = target_dir if target_dir else self.base_path
-self.csv_path = os.path.join(
+VoxCeleb1.csv_path = os.path.join(
target_dir, 'csv') if target_dir else os.path.join(self.base_path,
'csv')
-self.meta_path = os.path.join(
+VoxCeleb1.meta_path = os.path.join(
target_dir, 'meta') if target_dir else os.path.join(self.base_path,
'meta')
-self.veri_test_file = os.path.join(self.meta_path, 'veri_test2.txt')
+VoxCeleb1.veri_test_file = os.path.join(self.meta_path,
+'veri_test2.txt')
# self._data = self._get_data()[:1000] # KP: Small dataset test.
self._data = self._get_data()
super(VoxCeleb1, self).__init__()

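The self.* to VoxCeleb1.* change turns these paths into class attributes, so code that holds only the class (for example, data-preparation helpers that run before any instance exists) still sees the configured locations. A minimal sketch of the distinction (hypothetical names, not the PaddleSpeech API):

    class Demo:
        csv_path = None                          # class attribute, shared by all instances

        def __init__(self, target_dir: str):
            Demo.csv_path = target_dir + '/csv'  # rebinding on the class, not the instance

    Demo('/data/vox1')
    print(Demo.csv_path)                         # '/data/vox1/csv', readable without an instance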
@@ -24,10 +24,19 @@ def waveform_collate_fn(batch):
def feature_normalize(feats: paddle.Tensor,
mean_norm: bool=True,
-std_norm: bool=True):
+std_norm: bool=True,
+convert_to_numpy: bool=False):
# Features normalization if needed
-mean = feats.mean(axis=-1, keepdim=True) if mean_norm else 0
-std = feats.std(axis=-1, keepdim=True) if std_norm else 1
-feats = (feats - mean) / std
+# numpy.mean differs slightly from paddle.mean, by about 1e-6
+if convert_to_numpy:
+    feats_np = feats.numpy()
+    mean = feats_np.mean(axis=-1, keepdims=True) if mean_norm else 0
+    std = feats_np.std(axis=-1, keepdims=True) if std_norm else 1
+    feats_np = (feats_np - mean) / std
+    feats = paddle.to_tensor(feats_np, dtype=feats.dtype)
+else:
+    mean = feats.mean(axis=-1, keepdim=True) if mean_norm else 0
+    std = feats.std(axis=-1, keepdim=True) if std_norm else 1
+    feats = (feats - mean) / std
return feats

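A quick way to observe the ~1e-6 numpy/paddle discrepancy the comment mentions is to run both branches of feature_normalize on the same tensor (random float32 features; the magnitude is typical for float32, not guaranteed):

    import numpy as np
    import paddle
    from paddlespeech.vector.io.batch import feature_normalize

    feats = paddle.to_tensor(np.random.randn(1, 80, 200).astype('float32'))
    out_paddle = feature_normalize(feats, mean_norm=True, std_norm=False)
    out_numpy = feature_normalize(
        feats, mean_norm=True, std_norm=False, convert_to_numpy=True)
    print(float(paddle.abs(out_paddle - out_numpy).max()))  # typically ~1e-6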