diff --git a/paddlespeech/vector/modules/loss.py b/paddlespeech/vector/modules/loss.py
index 1c80dda4..015c0dfe 100644
--- a/paddlespeech/vector/modules/loss.py
+++ b/paddlespeech/vector/modules/loss.py
@@ -91,3 +91,139 @@ class LogSoftmaxWrapper(nn.Layer):
         predictions = F.log_softmax(predictions, axis=1)
         loss = self.criterion(predictions, targets) / targets.sum()
         return loss
+
+
+class NCELoss(nn.Layer):
+    """Noise Contrastive Estimation loss function
+
+    Noise Contrastive Estimation (NCE) is an approximation method used to
+    work around the huge computational cost of a large softmax layer.
+    The basic idea is to convert the multi-class prediction problem into a
+    binary classification problem (data vs. noise) at the training stage. It
+    has been proved that the two criteria converge to the same minimum as
+    long as the noise distribution is close enough to the real one.
+
+    NCE bridges the gap between generative models and discriminative models,
+    rather than simply speeding up the softmax layer: with NCE, almost any
+    scoring model can be turned into a posterior with little extra effort.
+
+    Refs:
+        NCE: http://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
+        Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py
+
+    Examples:
+        Q = Q_from_tokens(output_dim)
+        NCELoss(Q)
+    """
+
+    def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
+        """Noise Contrastive Estimation loss function
+
+        Args:
+            Q (tensor): prior distribution over tokens, e.g. uniform or gaussian
+            noise_ratio (int, optional): number of noise samples per target. Defaults to 100.
+            Z_offset (float, optional): offset (roughly log Z) subtracted from scores
+                before exponentiation in get_prob. Defaults to 9.5.
+        """
+        super(NCELoss, self).__init__()
+        assert isinstance(noise_ratio, int)
+        self.Q = paddle.to_tensor(Q, stop_gradient=False)
+        self.N = self.Q.shape[0]
+        self.K = noise_ratio
+        self.Z_offset = Z_offset
+
+    def forward(self, output, target):
+        """Compute the NCE loss for a batch of scores and targets.
+        """
+        output = paddle.reshape(output, [-1, self.N])
+        B = output.shape[0]
+        noise_idx = self.get_noise(B)
+        # [B, K + 1] indices per row: the target first, then the K noise samples
+        idx = self.get_combined_idx(target, noise_idx)
+        P_target, P_noise = self.get_prob(idx, output, sep_target=True)
+        Q_target, Q_noise = self.get_Q(idx)
+        loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
+        return loss.mean()
+
+    def get_Q(self, idx, sep_target=True):
+        """Look up the prior probabilities Q for a batch of indices.
+        """
+        prob_model = paddle.to_tensor(
+            self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()])
+        prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
+        if sep_target:
+            return prob_model[:, 0], prob_model[:, 1:]
+        else:
+            return prob_model
+
+    def get_prob(self, idx, scores, sep_target=True):
+        """Turn network scores into unnormalized probabilities, exp(score - Z_offset).
+        """
+        scores = self.get_scores(idx, scores)
+        scale = paddle.to_tensor([self.Z_offset], dtype='float32')
+        scores = paddle.add(scores, -scale)
+        prob = paddle.exp(scores)
+        if sep_target:
+            return prob[:, 0], prob[:, 1:]
+        else:
+            return prob
+
+    def get_scores(self, idx, scores):
+        """Gather the network scores at the given indices for each batch item.
+        """
+        B, N = scores.shape
+        K = idx.shape[1]  # self.K noise samples plus the target
+        idx_increment = paddle.to_tensor(
+            N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
+            dtype="int64",
+            stop_gradient=True)
+        # shift each row's indices by row * N to index the flattened score matrix
+        new_idx = idx_increment + idx
+        new_scores = paddle.index_select(
+            paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))
+
+        return paddle.reshape(new_scores, [B, K])
+
+    def get_noise(self, batch_size, uniform=True):
+        """Sample noise indices, uniformly or from the prior Q.
+        """
+        if uniform:
+            noise = np.random.randint(self.N, size=self.K * batch_size)
+        else:
+            noise = np.random.choice(
+                self.N, self.K * batch_size, replace=True, p=self.Q.numpy())
+        noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=True)
+        noise_idx = paddle.reshape(noise, [batch_size, self.K])
+        return noise_idx
+
+    def get_combined_idx(self, target_idx, noise_idx):
+        """Concatenate target indices with noise indices along the last axis.
+        """
+        target_idx = paddle.reshape(target_idx, [-1, 1])
+        return paddle.concat((target_idx, noise_idx), 1)
+
+    def nce_loss(self, prob_model, prob_noise_in_model, prob_noise,
+                 prob_target_in_noise):
+        """Combine the losses of the target samples and the noise samples.
+        """
+
+        def safe_log(tensor):
+            """Numerically safe log.
+            """
+            EPSILON = 1e-10
+            return paddle.log(EPSILON + tensor)
+
+        # log posterior that the target came from the model: p / (p + K * q)
+        model_loss = safe_log(prob_model /
+                              (prob_model + self.K * prob_target_in_noise))
+        model_loss = paddle.reshape(model_loss, [-1])
+
+        # log posterior that each noise sample came from the noise: K*q / (p + K*q)
+        noise_loss = paddle.sum(
+            safe_log((self.K * prob_noise) /
+                     (prob_noise_in_model + self.K * prob_noise)), -1)
+        noise_loss = paddle.reshape(noise_loss, [-1])
+
+        loss = -(model_loss + noise_loss)
+
+        return loss
diff --git a/paddlespeech/vector/utils/vector_utils.py b/paddlespeech/vector/utils/vector_utils.py
index 46de7ffa..dcf0f1aa 100644
--- a/paddlespeech/vector/utils/vector_utils.py
+++ b/paddlespeech/vector/utils/vector_utils.py
@@ -30,3 +30,11 @@ def get_chunks(seg_dur, audio_id, audio_duration):
         for i in range(num_chunks)
     ]
     return chunk_lst
+
+
+def Q_from_tokens(token_num):
+    """Get a uniform prior model; other priors (e.g. gaussian) may be added later.
+    """
+    freq = [1] * token_num
+    Q = paddle.to_tensor(freq, dtype='float64')
+    return Q / Q.sum()
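For reference, nce_loss above computes the standard NCE objective from Gutmann & Hyvarinen (2010): a sample with model probability p and noise probability q is classified as data with posterior p / (p + K*q) and as noise with posterior K*q / (p + K*q), and the loss is the negative log of the data term for the target plus the noise terms for the K noise samples. A scalar sanity check, with made-up probabilities (every number here is illustrative, not taken from the patch):

import math

K = 100                                 # noise_ratio
p_target, q_target = 2e-2, 1.0 / 5000   # model / prior prob of the target token
p_noise, q_noise = 1e-4, 1.0 / 5000     # model / prior prob of one noise token

# log posterior that the target came from the model: p / (p + K*q)
model_term = math.log(p_target / (p_target + K * q_target))
# log posterior that a noise draw came from the noise: K*q / (p + K*q),
# counted K times, pretending all K draws hit the same token for simplicity
noise_term = K * math.log(K * q_noise / (p_noise + K * q_noise))

print(-(model_term + noise_term))  # ~1.19; shrinks as the model sharpens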
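The one subtle tensor manipulation in the patch is the gather in get_scores: the score matrix is flattened and each row's indices are shifted by row * N, so a single index_select can pick out the target-plus-noise scores for the whole batch. A tiny standalone illustration of the same trick, with made-up shapes:

import paddle

B, N = 2, 5                                    # batch of 2, vocabulary of 5
scores = paddle.arange(B * N, dtype='float32').reshape([B, N])
idx = paddle.to_tensor([[4, 0], [1, 3]])       # per-row token ids

offset = N * paddle.arange(B).reshape([B, 1])  # row offsets: [[0], [5]]
flat_idx = paddle.reshape(offset + idx, [-1])
picked = paddle.index_select(paddle.reshape(scores, [-1]), flat_idx)
print(picked.reshape([B, 2]))                  # [[4., 0.], [6., 8.]]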
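Finally, a minimal end-to-end sketch of how the two new pieces fit together, with a toy linear layer standing in for a real speaker network. The dimensions, batch size, and optimizer below are illustrative only; the sketch assumes paddle and numpy are importable in the patched modules (the new code calls np.random and paddle directly) and that the installed Paddle version accepts the mixed float32/float64 math the patched code performs:

import paddle
import paddle.nn as nn
from paddlespeech.vector.modules.loss import NCELoss
from paddlespeech.vector.utils.vector_utils import Q_from_tokens

output_dim, feature_dim = 5000, 256         # e.g. number of training speakers
model = nn.Linear(feature_dim, output_dim)  # toy stand-in for the real network
criterion = NCELoss(Q_from_tokens(output_dim), noise_ratio=100)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

feats = paddle.randn([8, feature_dim])      # a batch of 8 utterance embeddings
target = paddle.randint(0, output_dim, [8], dtype='int64')

loss = criterion(model(feats), target)      # scalar NCE loss
loss.backward()
optimizer.step()
optimizer.clear_grad()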