Merge pull request #1719 from qingen/cluster

[vec][loss] add NCE Loss from RNNLM
pull/1725/head
qingen 2 years ago committed by GitHub
commit 9382ad8a16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -91,3 +91,137 @@ class LogSoftmaxWrapper(nn.Layer):
predictions = F.log_softmax(predictions, axis=1)
loss = self.criterion(predictions, targets) / targets.sum()
return loss
class NCELoss(nn.Layer):
"""Noise Contrastive Estimation loss funtion
Noise Contrastive Estimation (NCE) is an approximation method that is used to
work around the huge computational cost of large softmax layer.
The basic idea is to convert the prediction problem into classification problem
at training stage. It has been proved that these two criterions converges to
the same minimal point as long as noise distribution is close enough to real one.
NCE bridges the gap between generative models and discriminative models,
rather than simply speedup the softmax layer.
With NCE, you can turn almost anything into posterior with less effort (I think).
Refs:
NCEhttp://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py
Examples:
Q = Q_from_tokens(output_dim)
NCELoss(Q)
"""
def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
"""Noise Contrastive Estimation loss funtion
Args:
Q (tensor): prior model, uniform or guassian
noise_ratio (int, optional): noise sampling times. Defaults to 100.
Z_offset (float, optional): scale of post processing the score. Defaults to 9.5.
"""
super(NCELoss, self).__init__()
assert type(noise_ratio) is int
self.Q = paddle.to_tensor(Q, stop_gradient=False)
self.N = self.Q.shape[0]
self.K = noise_ratio
self.Z_offset = Z_offset
def forward(self, output, target):
"""Forward inference
Args:
output (tensor): the model output, which is the input of loss function
"""
output = paddle.reshape(output, [-1, self.N])
B = output.shape[0]
noise_idx = self.get_noise(B)
idx = self.get_combined_idx(target, noise_idx)
P_target, P_noise = self.get_prob(idx, output, sep_target=True)
Q_target, Q_noise = self.get_Q(idx)
loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
return loss.mean()
def get_Q(self, idx, sep_target=True):
"""Get prior model of batchsize data
"""
idx_size = idx.size
prob_model = paddle.to_tensor(
self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()])
prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
if sep_target:
return prob_model[:, 0], prob_model[:, 1:]
else:
return prob_model
def get_prob(self, idx, scores, sep_target=True):
"""Post processing the score of post model(output of nn) of batchsize data
"""
scores = self.get_scores(idx, scores)
scale = paddle.to_tensor([self.Z_offset], dtype='float32')
scores = paddle.add(scores, -scale)
prob = paddle.exp(scores)
if sep_target:
return prob[:, 0], prob[:, 1:]
else:
return prob
def get_scores(self, idx, scores):
"""Get the score of post model(output of nn) of batchsize data
"""
B, N = scores.shape
K = idx.shape[1]
idx_increment = paddle.to_tensor(
N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
dtype="int64",
stop_gradient=False)
new_idx = idx_increment + idx
new_scores = paddle.index_select(
paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))
return paddle.reshape(new_scores, [B, K])
def get_noise(self, batch_size, uniform=True):
"""Select noise sample
"""
if uniform:
noise = np.random.randint(self.N, size=self.K * batch_size)
else:
noise = np.random.choice(
self.N, self.K * batch_size, replace=True, p=self.Q.data)
noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=False)
noise_idx = paddle.reshape(noise, [batch_size, self.K])
return noise_idx
def get_combined_idx(self, target_idx, noise_idx):
"""Combined target and noise
"""
target_idx = paddle.reshape(target_idx, [-1, 1])
return paddle.concat((target_idx, noise_idx), 1)
def nce_loss(self, prob_model, prob_noise_in_model, prob_noise,
prob_target_in_noise):
"""Combined the loss of target and noise
"""
def safe_log(tensor):
"""Safe log
"""
EPSILON = 1e-10
return paddle.log(EPSILON + tensor)
model_loss = safe_log(prob_model /
(prob_model + self.K * prob_target_in_noise))
model_loss = paddle.reshape(model_loss, [-1])
noise_loss = paddle.sum(
safe_log((self.K * prob_noise) /
(prob_noise_in_model + self.K * prob_noise)), -1)
noise_loss = paddle.reshape(noise_loss, [-1])
loss = -(model_loss + noise_loss)
return loss

@ -30,3 +30,11 @@ def get_chunks(seg_dur, audio_id, audio_duration):
for i in range(num_chunks)
]
return chunk_lst
def Q_from_tokens(token_num):
"""Get prior model, data from uniform, would support others(guassian) in future
"""
freq = [1] * token_num
Q = paddle.to_tensor(freq, dtype='float64')
return Q / Q.sum()

Loading…
Cancel
Save