[vec][loss] add FocalLoss to deal with class imbalances, test=doc fix #1721

pull/1722/head
qingen 2 years ago
parent 00febff734
commit 6a7245657f

@ -132,7 +132,7 @@ class NCELoss(nn.Layer):
def forward(self, output, target):
"""Forward inference
Args:
output (tensor): the model output, which is the input of loss function
"""
@ -161,7 +161,7 @@ class NCELoss(nn.Layer):
"""Post processing the score of post model(output of nn) of batchsize data
"""
scores = self.get_scores(idx, scores)
scale = paddle.to_tensor([self.Z_offset], dtype='float32')
scale = paddle.to_tensor([self.Z_offset], dtype='float64')
scores = paddle.add(scores, -scale)
prob = paddle.exp(scores)
if sep_target:
@ -225,3 +225,65 @@ class NCELoss(nn.Layer):
loss = -(model_loss + noise_loss)
return loss
class FocalLoss(nn.Layer):
"""This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.size_average = size_average
self.ce = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction="none")
def forward(self, outputs, targets):
"""Forword inference.
Args:
outputs: input tensor
target: target label tensor
"""
ce_loss = self.ce(outputs, targets)
pt = paddle.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
if self.size_average:
return focal_loss.mean()
else:
return focal_loss.sum()
if __name__ == "__main__":
import numpy as np
from paddlespeech.vector.utils.vector_utils import Q_from_tokens
paddle.set_device("cpu")
input_data = paddle.uniform([5, 100], dtype="float64")
label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
loss1 = FocalLoss()
loss = loss1.forward(input, label)
print("loss: %.5f" % (loss))
Q = Q_from_tokens(100)
loss2 = NCELoss(Q)
loss = loss2.forward(input, label)
print("loss: %.5f" % (loss))

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def get_chunks(seg_dur, audio_id, audio_duration):

Loading…
Cancel
Save