From 6a7245657fa088db7ee1c275ae953695e8d77d39 Mon Sep 17 00:00:00 2001
From: qingen
Date: Wed, 20 Apr 2022 11:33:25 +0800
Subject: [PATCH] [vec][loss] add FocalLoss to deal with class imbalances,
 test=doc

fix #1721
---
 paddlespeech/vector/modules/loss.py       | 66 ++++++++++++++++++++++-
 paddlespeech/vector/utils/vector_utils.py |  1 +
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/paddlespeech/vector/modules/loss.py b/paddlespeech/vector/modules/loss.py
index af38dd01..9a7530c1 100644
--- a/paddlespeech/vector/modules/loss.py
+++ b/paddlespeech/vector/modules/loss.py
@@ -132,7 +132,7 @@ class NCELoss(nn.Layer):
 
     def forward(self, output, target):
         """Forward inference
-        
+
         Args:
             output (tensor): the model output, which is the input of loss function
         """
@@ -161,7 +161,7 @@ class NCELoss(nn.Layer):
         """Post processing the score of post model(output of nn) of batchsize data
         """
         scores = self.get_scores(idx, scores)
-        scale = paddle.to_tensor([self.Z_offset], dtype='float32')
+        scale = paddle.to_tensor([self.Z_offset], dtype='float64')
         scores = paddle.add(scores, -scale)
         prob = paddle.exp(scores)
         if sep_target:
@@ -225,3 +225,65 @@ class NCELoss(nn.Layer):
         loss = -(model_loss + noise_loss)
 
         return loss
+
+
+class FocalLoss(nn.Layer):
+    """This criterion is an implementation of Focal Loss, which is proposed in
+    "Focal Loss for Dense Object Detection" (Lin et al., 2017).
+
+        Loss(x, class) = -alpha * (1 - softmax(x)[class])^gamma * log(softmax(x)[class])
+
+    The losses are averaged across observations for each minibatch.
+
+    Args:
+        alpha (float): the scalar weighting factor for this criterion.
+        gamma (float): gamma >= 0; reduces the relative loss for well-classified
+            examples (p > .5), putting more focus on hard, misclassified examples.
+        size_average (bool): By default, the losses are averaged over observations
+            for each minibatch. If size_average is set to False, the losses are
+            instead summed for each minibatch.
+    """
+
+    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
+        super(FocalLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.size_average = size_average
+        self.ce = nn.CrossEntropyLoss(
+            ignore_index=ignore_index, reduction="none")
+
+    def forward(self, outputs, targets):
+        """Forward inference.
+
+        Args:
+            outputs (Tensor): the model output logits.
+            targets (Tensor): the ground-truth class labels.
+        """
+        ce_loss = self.ce(outputs, targets)
+        pt = paddle.exp(-ce_loss)
+        focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
+        if self.size_average:
+            return focal_loss.mean()
+        else:
+            return focal_loss.sum()
+
+
+if __name__ == "__main__":
+    import numpy as np
+    from paddlespeech.vector.utils.vector_utils import Q_from_tokens
+    paddle.set_device("cpu")
+
+    input_data = paddle.uniform([5, 100], dtype="float64")
+    label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
+
+    input = paddle.to_tensor(input_data)
+    label = paddle.to_tensor(label_data)
+
+    loss1 = FocalLoss()
+    loss = loss1(input, label)
+    print("loss: %.5f" % (loss))
+
+    Q = Q_from_tokens(100)
+    loss2 = NCELoss(Q)
+    loss = loss2(input, label)
+    print("loss: %.5f" % (loss))
diff --git a/paddlespeech/vector/utils/vector_utils.py b/paddlespeech/vector/utils/vector_utils.py
index dcf0f1aa..d6659e3f 100644
--- a/paddlespeech/vector/utils/vector_utils.py
+++ b/paddlespeech/vector/utils/vector_utils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import paddle
 
 
 def get_chunks(seg_dur, audio_id, audio_duration):
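
A quick sanity check for the new FocalLoss (a sketch, not part of the patch): with alpha=1 and gamma=0 the focal term (1 - pt)**gamma is 1, so the loss degenerates to plain cross entropy, and any gamma > 0 can only down-weight well-classified examples. The snippet below assumes the patch is applied and the class is importable as paddlespeech.vector.modules.loss.FocalLoss; the shapes (batch of 8, 10 classes) are illustrative.

import numpy as np
import paddle
import paddle.nn as nn
from paddlespeech.vector.modules.loss import FocalLoss

paddle.set_device("cpu")
logits = paddle.uniform([8, 10], dtype="float32")
labels = paddle.to_tensor(
    np.random.randint(0, 10, size=(8, )).astype(np.int64))

# With alpha=1 and gamma=0 the focal loss must equal the mean cross entropy.
ce_mean = nn.CrossEntropyLoss(reduction="mean")(logits, labels)
fl_plain = FocalLoss(alpha=1, gamma=0)(logits, labels)
assert paddle.allclose(fl_plain, ce_mean)

# With gamma > 0, (1 - pt)**gamma <= 1 down-weights easy examples, so the
# focal loss can never exceed the plain cross-entropy loss.
fl_focused = FocalLoss(alpha=1, gamma=2)(logits, labels)
assert float(fl_focused) <= float(ce_mean)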