add some comments to the code

pull/1523/head
xiongxinlei 4 years ago
parent 8ed5c287a3
commit 311fa87a11

@ -1,3 +1,12 @@
###########################################
# Data #
###########################################
batch_size: 32
num_workers: 2
num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
random_chunk: True
########################################################### ###########################################################
# FEATURE EXTRACTION SETTING # # FEATURE EXTRACTION SETTING #
########################################################### ###########################################################
@ -7,7 +16,6 @@ feature:
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400 window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_length: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 hop_length: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
########################################################### ###########################################################
# MODEL SETTING # # MODEL SETTING #
########################################################### ###########################################################
@ -15,9 +23,8 @@ feature:
# if we want use another model, please choose another configuration yaml file # if we want use another model, please choose another configuration yaml file
model: model:
input_size: 80 input_size: 80
##"channels": [1024, 1024, 1024, 1024, 3072],
# "channels": [512, 512, 512, 512, 1536], # "channels": [512, 512, 512, 512, 1536],
channels: [512, 512, 512, 512, 1536] channels: [1024, 1024, 1024, 1024, 3072]
kernel_sizes: [5, 3, 3, 3, 1] kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1] dilations: [1, 2, 3, 4, 1]
attention_channels: 128 attention_channels: 128
@ -26,10 +33,9 @@ model:
########################################### ###########################################
# Training # # Training #
########################################### ###########################################
seed: 0 seed: 1986 # according from speechbrain configuration
epochs: 10 epochs: 10
batch_size: 32
num_workers: 2
save_freq: 10 save_freq: 10
log_freq: 10 log_interval: 10
learning_rate: 1e-8 learning_rate: 1e-8

@ -47,7 +47,8 @@ mkdir -p ${exp_dir}
if [ $stage -le 0 ]; then if [ $stage -le 0 ]; then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
python3 local/data_prepare.py \ python3 local/data_prepare.py \
--data-dir ${dir} --augment --vox2-base-path ${vox2_base_path} --data-dir ${dir} --augment --vox2-base-path ${vox2_base_path} \
--config conf/ecapa_tdnn.yaml
fi fi
if [ $stage -le 1 ]; then if [ $stage -le 1 ]; then

@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .dtw import dtw_distance from .dtw import dtw_distance
from .mcd import mcd_distance
from .eer import compute_eer from .eer import compute_eer
from .eer import compute_minDCF
from .mcd import mcd_distance

@ -14,6 +14,7 @@
from typing import List from typing import List
import numpy as np import numpy as np
import paddle
from sklearn.metrics import roc_curve from sklearn.metrics import roc_curve
@ -26,3 +27,68 @@ def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
return eer, eer_threshold return eer, eer_threshold
def compute_minDCF(positive_scores,
negative_scores,
c_miss=1.0,
c_fa=1.0,
p_target=0.01):
"""
This is modified from SpeechBrain
https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509
Computes the minDCF metric normally used to evaluate speaker verification
systems. The min_DCF is the minimum of the following C_det function computed
within the defined threshold range:
C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 -p_target)
where p_miss is the missing probability and p_fa is the probability of having
a false alarm.
Args:
positive_scores (Paddle.Tensor): The scores from entries of the same class.
negative_scores (Paddle.Tensor): The scores from entries of different classes.
c_miss (float, optional): Cost assigned to a missing error (default 1.0).
c_fa (float, optional): Cost assigned to a false alarm (default 1.0).
p_target (float, optional): Prior probability of having a target (default 0.01).
Returns:
_type_: min dcf
"""
# Computing candidate thresholds
if len(positive_scores.shape) > 1:
positive_scores = positive_scores.squeeze()
if len(negative_scores.shape) > 1:
negative_scores = negative_scores.squeeze()
thresholds = paddle.sort(paddle.concat([positive_scores, negative_scores]))
thresholds = paddle.unique(thresholds)
# Adding intermediate thresholds
interm_thresholds = (thresholds[0:-1] + thresholds[1:]) / 2
thresholds = paddle.sort(paddle.concat([thresholds, interm_thresholds]))
# Computing False Rejection Rate (miss detection)
positive_scores = paddle.concat(
len(thresholds) * [positive_scores.unsqueeze(0)])
pos_scores_threshold = positive_scores.transpose(perm=[1, 0]) <= thresholds
p_miss = (pos_scores_threshold.sum(0)
).astype("float32") / positive_scores.shape[1]
del positive_scores
del pos_scores_threshold
# Computing False Acceptance Rate (false alarm)
negative_scores = paddle.concat(
len(thresholds) * [negative_scores.unsqueeze(0)])
neg_scores_threshold = negative_scores.transpose(perm=[1, 0]) > thresholds
p_fa = (neg_scores_threshold.sum(0)
).astype("float32") / negative_scores.shape[1]
del negative_scores
del neg_scores_threshold
c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
c_min = paddle.min(c_det, axis=0)
min_index = paddle.argmin(c_det, axis=0)
return float(c_min), float(thresholds[min_index])

@ -45,7 +45,7 @@ def main(args, config):
# stage2: build the speaker verification eval instance with backbone model # stage2: build the speaker verification eval instance with backbone model
model = SpeakerIdetification( model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers) backbone=ecapa_tdnn, num_class=config.num_speakers)
# stage3: load the pre-trained model # stage3: load the pre-trained model
args.load_checkpoint = os.path.abspath( args.load_checkpoint = os.path.abspath(
@ -93,6 +93,7 @@ def main(args, config):
model.eval() model.eval()
# stage7: global embedding norm to imporve the performance # stage7: global embedding norm to imporve the performance
print("global embedding norm: {}".format(args.global_embedding_norm))
if args.global_embedding_norm: if args.global_embedding_norm:
global_embedding_mean = None global_embedding_mean = None
global_embedding_std = None global_embedding_std = None
@ -118,6 +119,8 @@ def main(args, config):
-1).numpy() # (N, emb_size, 1) -> (N, emb_size) -1).numpy() # (N, emb_size, 1) -> (N, emb_size)
# Global embedding normalization. # Global embedding normalization.
# if we use the global embedding norm
# eer can reduece about relative 10%
if args.global_embedding_norm: if args.global_embedding_norm:
batch_count += 1 batch_count += 1
current_mean = embeddings.mean( current_mean = embeddings.mean(
@ -150,8 +153,8 @@ def main(args, config):
for line in f.readlines(): for line in f.readlines():
label, enrol_id, test_id = line.strip().split(' ') label, enrol_id, test_id = line.strip().split(' ')
labels.append(int(label)) labels.append(int(label))
enrol_ids.append(enrol_id.split('.')[0].replace('/', '-')) enrol_ids.append(enrol_id.split('.')[0].replace('/', '--'))
test_ids.append(test_id.split('.')[0].replace('/', '-')) test_ids.append(test_id.split('.')[0].replace('/', '--'))
cos_sim_func = paddle.nn.CosineSimilarity(axis=1) cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor( enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(
@ -185,11 +188,10 @@ if __name__ == "__main__":
default='', default='',
help="Directory to load model checkpoint to contiune trainning.") help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--global-embedding-norm", parser.add_argument("--global-embedding-norm",
type=bool, default=False,
default=True, action="store_true",
help="Apply global normalization on speaker embeddings.") help="Apply global normalization on speaker embeddings.")
parser.add_argument("--embedding-mean-norm", parser.add_argument("--embedding-mean-norm",
type=bool,
default=True, default=True,
help="Apply mean normalization on speaker embeddings.") help="Apply mean normalization on speaker embeddings.")
parser.add_argument("--embedding-std-norm", parser.add_argument("--embedding-std-norm",

@ -178,9 +178,9 @@ def main(args, config):
timer.count() # step plus one in timer timer.count() # step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs # stage 9-10: print the log information only on 0-rank per log-freq batchs
if (batch_idx + 1) % config.log_freq == 0 and local_rank == 0: if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
lr = optimizer.get_lr() lr = optimizer.get_lr()
avg_loss /= config.log_freq avg_loss /= config.log_interval
avg_acc = num_corrects / num_samples avg_acc = num_corrects / num_samples
print_msg = 'Train Epoch={}/{}, Step={}/{}'.format( print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
@ -196,7 +196,7 @@ def main(args, config):
num_samples = 0 num_samples = 0
# stage 9-11: save the model parameters only on 0-rank per save-freq batchs # stage 9-11: save the model parameters only on 0-rank per save-freq batchs
if epoch % config.save_freq == 0 and batch_idx + 1 == steps_per_epoch: if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
if local_rank != 0: if local_rank != 0:
paddle.distributed.barrier( paddle.distributed.barrier(
) # Wait for valid step in main process ) # Wait for valid step in main process

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# this is modified from https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py # this is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
import math import math
import os import os
from typing import List from typing import List

@ -75,6 +75,9 @@ def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
i]:].sum() == 0 # Padding valus should all be 0. i]:].sum() == 0 # Padding valus should all be 0.
# Converts into ratios. # Converts into ratios.
# the utterance of the max length doesn't need to padding
# the remaining utterances need to padding and all of them will be padded to max length
# we convert the original length of each utterance to the ratio of the max length
lengths = (lengths / lengths.max()).astype(np.float32) lengths = (lengths / lengths.max()).astype(np.float32)
return {'ids': ids, 'feats': feats, 'lengths': lengths} return {'ids': ids, 'feats': feats, 'lengths': lengths}
Loading…
Cancel
Save