repair the code according to the part comment, test=doc

pull/1523/head
xiongxinlei 4 years ago
parent 97ec01260b
commit 016ed6d69c

@@ -11,21 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
 import ast
 import os

 import numpy as np
 import paddle
+import paddle.nn.functional as F
 from paddle.io import BatchSampler
 from paddle.io import DataLoader
-import paddle.nn.functional as F
-from paddlespeech.vector.training.metrics import compute_eer
+from tqdm import tqdm

 from paddleaudio.datasets.voxceleb import VoxCeleb1
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.training.sid_model import SpeakerIdetification
-from tqdm import tqdm
+from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.training.metrics import compute_eer


 def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
@@ -44,7 +44,7 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
     return np.pad(x, pad_width, mode=mode, **kwargs)


-def feature_normalize(batch, mean_norm: bool = True, std_norm: bool = True):
+def feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
     ids = [item['id'] for item in batch]
     lengths = np.asarray([item['feat'].shape[1] for item in batch])
     feats = list(
@@ -58,8 +58,8 @@ def feature_normalize(batch, mean_norm: bool = True, std_norm: bool = True):
         mean = feat.mean(axis=-1, keepdims=True) if mean_norm else 0
         std = feat.std(axis=-1, keepdims=True) if std_norm else 1
         feats[i][:, :lengths[i]] = (feat - mean) / std
-        assert feats[i][:, lengths[i]:].sum(
-        ) == 0  # Padding valus should all be 0.
+        assert feats[i][:, lengths[
+            i]:].sum() == 0  # Padding valus should all be 0.

     # Converts into ratios.
     lengths = (lengths / lengths.max()).astype(np.float32)
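As a side note, here is a minimal NumPy sketch (not part of the commit) of the normalization idea in this hunk: statistics are computed only over the valid frames of each padded feature, so the padded tail stays zero. The helper name and the toy shapes below are made up for illustration.

import numpy as np

def normalize_padded_feats(feats, lengths, mean_norm=True, std_norm=True):
    # feats: list of [n_mels, max_len] arrays, zero-padded on the time axis
    # lengths: number of valid frames per utterance
    for i, length in enumerate(lengths):
        valid = feats[i][:, :length]
        mean = valid.mean(axis=-1, keepdims=True) if mean_norm else 0
        std = valid.std(axis=-1, keepdims=True) if std_norm else 1
        feats[i][:, :length] = (valid - mean) / std
        assert feats[i][:, length:].sum() == 0  # padding must stay zero
    return np.stack(feats), (lengths / lengths.max()).astype(np.float32)

# toy batch: two 80-mel features with 50 and 30 valid frames, padded to 50
batch = [np.random.rand(80, 50), np.pad(np.random.rand(80, 30), ((0, 0), (0, 20)))]
feats, ratios = normalize_padded_feats(batch, np.array([50, 30]))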
@@ -98,16 +98,16 @@ def main(args):
     print(f'Checkpoint loaded from {args.load_checkpoint}')

     # stage4: construct the enroll and test dataloader
-    enrol_ds = VoxCeleb1(subset='enrol',
-                         feat_type='melspectrogram',
-                         random_chunk=False,
-                         n_mels=80,
-                         window_size=400,
-                         hop_length=160)
+    enrol_ds = VoxCeleb1(
+        subset='enrol',
+        feat_type='melspectrogram',
+        random_chunk=False,
+        n_mels=80,
+        window_size=400,
+        hop_length=160)
     enrol_sampler = BatchSampler(
-        enrol_ds,
-        batch_size=args.batch_size,
+        enrol_ds, batch_size=args.batch_size,
         shuffle=True)  # Shuffle to make embedding normalization more robust.
     enrol_loader = DataLoader(enrol_ds,
                               batch_sampler=enrol_sampler,
                               collate_fn=lambda x: feature_normalize(
@@ -115,16 +115,16 @@ def main(args):
                               num_workers=args.num_workers,
                               return_list=True,)

-    test_ds = VoxCeleb1(subset='test',
-                        feat_type='melspectrogram',
-                        random_chunk=False,
-                        n_mels=80,
-                        window_size=400,
-                        hop_length=160)
+    test_ds = VoxCeleb1(
+        subset='test',
+        feat_type='melspectrogram',
+        random_chunk=False,
+        n_mels=80,
+        window_size=400,
+        hop_length=160)
-    test_sampler = BatchSampler(test_ds,
-                                batch_size=args.batch_size,
-                                shuffle=True)
+    test_sampler = BatchSampler(
+        test_ds, batch_size=args.batch_size, shuffle=True)
     test_loader = DataLoader(test_ds,
                              batch_sampler=test_sampler,
                              collate_fn=lambda x: feature_normalize(
@@ -169,12 +169,13 @@ def main(args):
                 embedding_mean, embedding_std = mean, std
             else:
                 weight = 1 / batch_count  # Weight decay by batches.
-                embedding_mean = (
-                    1 - weight) * embedding_mean + weight * mean
-                embedding_std = (
-                    1 - weight) * embedding_std + weight * std
+                embedding_mean = (1 - weight
+                                  ) * embedding_mean + weight * mean
+                embedding_std = (1 - weight
+                                 ) * embedding_std + weight * std

             # Apply global embedding normalization.
-            embeddings = (embeddings - embedding_mean) / embedding_std
+            embeddings = (
+                embeddings - embedding_mean) / embedding_std

         # Update embedding dict.
         id2embedding.update(dict(zip(ids, embeddings)))
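For intuition, a small standalone sketch (not from the commit) of the running-statistics update above: each new batch is blended in with weight 1 / batch_count, so later batches shift the global mean and std less and less. The embedding dimension of 192 is only an illustrative choice.

import numpy as np

embedding_mean, embedding_std, batch_count = None, None, 0
for _ in range(3):
    batch_count += 1
    embeddings = np.random.rand(16, 192)  # one batch of speaker embeddings
    mean, std = embeddings.mean(axis=0), embeddings.std(axis=0)
    if embedding_mean is None:
        embedding_mean, embedding_std = mean, std
    else:
        weight = 1 / batch_count  # weight decay by batches
        embedding_mean = (1 - weight) * embedding_mean + weight * mean
        embedding_std = (1 - weight) * embedding_std + weight * std
    normalized = (embeddings - embedding_mean) / embedding_std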
@@ -201,38 +202,39 @@ def main(args):
         f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
     )


 if __name__ == "__main__":
     # yapf: disable
     parser = argparse.ArgumentParser(__doc__)
     parser.add_argument('--device',
                         choices=['cpu', 'gpu'],
                         default="gpu",
                         help="Select which device to train model, defaults to gpu.")
     parser.add_argument("--batch-size",
                         type=int,
                         default=16,
                         help="Total examples' number in batch for training.")
     parser.add_argument("--num-workers",
                         type=int,
                         default=0,
                         help="Number of workers in dataloader.")
     parser.add_argument("--load-checkpoint",
                         type=str,
                         default='',
                         help="Directory to load model checkpoint to contiune trainning.")
     parser.add_argument("--global-embedding-norm",
                         type=bool,
                         default=True,
                         help="Apply global normalization on speaker embeddings.")
     parser.add_argument("--embedding-mean-norm",
                         type=bool,
                         default=True,
                         help="Apply mean normalization on speaker embeddings.")
     parser.add_argument("--embedding-std-norm",
                         type=bool,
                         default=False,
                         help="Apply std normalization on speaker embeddings.")
     args = parser.parse_args()
     # yapf: enable

     main(args)

@@ -22,22 +22,23 @@ from paddle.io import DistributedBatchSampler

 from paddleaudio.datasets.voxceleb import VoxCeleb1
 from paddleaudio.features.core import melspectrogram
-from paddlespeech.vector.training.time import Timer
-from paddlespeech.vector.datasets.batch import feature_normalize
-from paddlespeech.vector.datasets.batch import waveform_collate_fn
-from paddlespeech.vector.layers.loss import AdditiveAngularMargin
-from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
-from paddlespeech.vector.layers.lr import CyclicLRScheduler
+from paddlespeech.vector.io.batch import feature_normalize
+from paddlespeech.vector.io.batch import waveform_collate_fn
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.training.sid_model import SpeakerIdetification
+from paddlespeech.vector.modules.loss import AdditiveAngularMargin
+from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
+from paddlespeech.vector.modules.lr import CyclicLRScheduler
+from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.utils.time import Timer

 # feat configuration
 cpu_feat_conf = {
     'n_mels': 80,
-    'window_size': 400,
-    'hop_length': 160,
+    'window_size': 400,  #ms
+    'hop_length': 160,  #ms
 }


 def main(args):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
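Side note: the same three keys appear as keyword arguments of VoxCeleb1 in the evaluation script above, so the config can plausibly be unpacked straight into the dataset constructor. A hedged sketch, with an illustrative subset and feat_type:

from paddleaudio.datasets.voxceleb import VoxCeleb1

cpu_feat_conf = {'n_mels': 80, 'window_size': 400, 'hop_length': 160}
# Illustrative only: mirrors the enrol/test dataset construction in the evaluation script.
dev_ds = VoxCeleb1(subset='dev', feat_type='melspectrogram', **cpu_feat_conf)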

@@ -76,6 +76,9 @@ class VoxCeleb1(Dataset):
         'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
     base_path = os.path.join(DATA_HOME, 'vox1')
     wav_path = os.path.join(base_path, 'wav')
+    meta_path = os.path.join(base_path, 'meta')
+    veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
+    csv_path = os.path.join(base_path, 'csv')
     subsets = ['train', 'dev', 'enrol', 'test']

     def __init__(

@@ -22,30 +22,22 @@ from .log import logger
 download.logger = logger

-__all__ = [
-    'decompress',
-    'download_and_decompress',
-    'load_state_dict_from_url',
-]
-

-def decompress(file: str, path: str=os.PathLike):
+def decompress(file: str):
     """
-    Extracts all files from a compressed file to specific path.
+    Extracts all files from a compressed file.
     """
     assert os.path.isfile(file), "File: {} not exists.".format(file)
-
-    if path is None:
-        print("decompress the data: {}".format(file))
-        download._decompress(file)
-    else:
-        print("decompress the data: {} to {}".format(file, path))
-        if not os.path.isdir(path):
-            os.makedirs(path)
-
-        tmp_file = os.path.join(path, os.path.basename(file))
-        os.rename(file, tmp_file)
-        download._decompress(tmp_file)
-        os.rename(tmp_file, file)
+    download._decompress(file)


-def download_and_decompress(archives: List[Dict[str, str]],
-                            path: str,
-                            decompress: bool=True):
+def download_and_decompress(archives: List[Dict[str, str]], path: str):
     """
     Download archieves and decompress to specific path.
     """
@@ -55,8 +47,8 @@ def download_and_decompress(archives: List[Dict[str, str]],
     for archive in archives:
         assert 'url' in archive and 'md5' in archive, \
             'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}'
-        download.get_path_from_url(
-            archive['url'], path, archive['md5'], decompress=decompress)
+        download.get_path_from_url(archive['url'], path, archive['md5'])


 def load_state_dict_from_url(url: str, path: str, md5: str=None):
@@ -67,4 +59,4 @@ def load_state_dict_from_url(url: str, path: str, md5: str=None):
         os.makedirs(path)

     download.get_path_from_url(url, path, md5)
     return load_state_dict(os.path.join(path, os.path.basename(url)))

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math

 import paddle
@@ -67,4 +66,4 @@ class LogSoftmaxWrapper(nn.Layer):
         predictions = F.log_softmax(predictions, axis=1)
         loss = self.criterion(predictions, targets) / targets.sum()
         return loss

@@ -0,0 +1,28 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List
+
+import numpy as np
+from sklearn.metrics import roc_curve
+
+
+def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
+    '''
+    Compute EER and return score threshold.
+    '''
+    fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
+    fnr = 1 - tpr
+    eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
+    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
+    return eer, eer_threshold
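A quick usage sketch for compute_eer (the labels and scores below are made-up toy values, not data from the recipe):

import numpy as np

labels = np.array([1, 1, 0, 1, 0, 0])              # 1 = same speaker, 0 = different
scores = np.array([0.9, 0.7, 0.4, 0.6, 0.5, 0.2])  # e.g. cosine similarity scores
eer, threshold = compute_eer(labels, scores)
print(f'EER: {eer * 100:.4f}%, score threshold: {threshold:.5f}')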

@@ -0,0 +1,72 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Dict
+from typing import List
+
+from paddle.framework import load as load_state_dict
+from paddle.utils import download
+
+__all__ = [
+    'decompress',
+    'download_and_decompress',
+    'load_state_dict_from_url',
+]
+
+
+def decompress(file: str, path: str=os.PathLike):
+    """
+    Extracts all files from a compressed file to specific path.
+    """
+    assert os.path.isfile(file), "File: {} not exists.".format(file)
+
+    if path is None:
+        print("decompress the data: {}".format(file))
+        download._decompress(file)
+    else:
+        print("decompress the data: {} to {}".format(file, path))
+        if not os.path.isdir(path):
+            os.makedirs(path)
+
+        tmp_file = os.path.join(path, os.path.basename(file))
+        os.rename(file, tmp_file)
+        download._decompress(tmp_file)
+        os.rename(tmp_file, file)
+
+
+def download_and_decompress(archives: List[Dict[str, str]],
+                            path: str,
+                            decompress: bool=True):
+    """
+    Download archieves and decompress to specific path.
+    """
+    if not os.path.isdir(path):
+        os.makedirs(path)
+
+    for archive in archives:
+        assert 'url' in archive and 'md5' in archive, \
+            'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}'
+
+        download.get_path_from_url(
+            archive['url'], path, archive['md5'], decompress=decompress)
+
+
+def load_state_dict_from_url(url: str, path: str, md5: str=None):
+    """
+    Download and load a state dict from url
+    """
+    if not os.path.isdir(path):
+        os.makedirs(path)
+
+    download.get_path_from_url(url, path, md5)
+    return load_state_dict(os.path.join(path, os.path.basename(url)))
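Finally, a hedged usage sketch for the helpers in this new module (the URLs, md5 and paths are placeholders, not real resources):

archives = [{
    'url': 'https://example.com/datasets/sample.tar.gz',  # placeholder URL
    'md5': '0123456789abcdef0123456789abcdef',            # placeholder md5
}]
# Download into ./data and extract in place; pass decompress=False to skip extraction.
download_and_decompress(archives, path='./data', decompress=True)

# Fetch a checkpoint and load it as a state dict (placeholder URL again).
state_dict = load_state_dict_from_url(
    'https://example.com/models/ecapa_tdnn.pdparams', path='./pretrained')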