LMFLOSS: a new hybrid loss function for imbalanced medical image classification (with code)

Paper: https://arxiv.org/pdf/2212.12741.pdf

Code: https://github.com/SanaNazari/LMFLoss

1. What is it?

LMFLOSS is a hybrid loss function for imbalanced medical image classification. It is a linear combination of Focal Loss and LDAM Loss, designed to better handle imbalanced datasets. Focal Loss improves model performance by emphasizing hard-to-classify samples, while LDAM Loss uses the class distribution of the dataset to assign larger margins to rarer classes.

2. Why?

First, let's briefly review how previous methods have tackled class imbalance. Broadly, they fall into two categories: data-centric and algorithm-centric solutions.

Data Strategy

There are two main data-centric methods to solve class imbalance: oversampling and undersampling. Oversampling attempts to generate artificial data points for a minority class, while undersampling aims to eliminate samples from the majority class.
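
The two data-level strategies can be sketched in a few lines of plain Python. This is a minimal illustration, not part of the paper: it uses random oversampling, which duplicates existing minority samples rather than synthesizing new ones (methods such as SMOTE interpolate genuinely new points), and random undersampling. The function names and toy file names are hypothetical.

```python
import random
from collections import Counter


def random_oversample(samples, labels, seed=0):
    """Duplicate randomly chosen samples of each smaller class until every
    class has as many samples as the largest one."""
    rng = random.Random(seed)
    counts = Counter(labels)
    target = max(counts.values())
    out_x, out_y = list(samples), list(labels)
    for cls, n in counts.items():
        pool = [x for x, y in zip(samples, labels) if y == cls]
        for _ in range(target - n):
            out_x.append(rng.choice(pool))
            out_y.append(cls)
    return out_x, out_y


def random_undersample(samples, labels, seed=0):
    """Keep a random subset of each class so every class shrinks to the
    size of the smallest one."""
    rng = random.Random(seed)
    counts = Counter(labels)
    target = min(counts.values())
    out_x, out_y = [], []
    for cls in counts:
        pool = [x for x, y in zip(samples, labels) if y == cls]
        for x in rng.sample(pool, target):
            out_x.append(x)
            out_y.append(cls)
    return out_x, out_y


# Toy data: 8 "normal" images and 2 "lesion" images (names are made up).
xs = [f"img{i}" for i in range(10)]
ys = ["normal"] * 8 + ["lesion"] * 2
_, ys_over = random_oversample(xs, ys)
_, ys_under = random_undersample(xs, ys)
print(Counter(ys_over))   # 8 samples of each class
print(Counter(ys_under))  # 2 samples of each class
```

Oversampling keeps every original sample but risks overfitting to the repeated minority images; undersampling discards majority data, which is costly when medical datasets are already small.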

Algorithm Strategy

Algorithm-level strategies, especially in deep learning, mainly focus on developing loss functions that cope with class imbalance rather than manipulating the data directly. A simple approach is to assign a weight to each class so that misclassifying a minority-class sample is penalized more heavily than misclassifying a majority-class sample. Another approach is to adaptively assign each training sample its own weight so that hard samples get higher weights.
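
The first approach, per-class weighting, can be sketched for a single sample as below. Inverse-frequency weights (normalized to average 1) are one common choice and an assumption here; the paper does not prescribe this scheme. The class counts and logits are hypothetical.

```python
import math


def inverse_frequency_weights(class_counts):
    """w_j proportional to 1 / n_j, normalized so the weights average to 1."""
    inv = [1.0 / n for n in class_counts]
    scale = len(inv) / sum(inv)
    return [w * scale for w in inv]


def weighted_cross_entropy(logits, target, weights):
    """-w_y * log softmax(logits)_y for a single sample."""
    m = max(logits)  # subtract the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -weights[target] * (logits[target] - log_sum_exp)


counts = [900, 100]   # hypothetical: class 0 is the majority, class 1 the minority
w = inverse_frequency_weights(counts)
print(w)              # approximately [0.2, 1.8]
logits = [2.0, 0.0]   # the model favors class 0
print(weighted_cross_entropy(logits, 0, w))  # small: correctly classified majority sample
print(weighted_cross_entropy(logits, 1, w))  # large: misclassified minority sample
```

The weights are fixed per class, so every minority sample is up-weighted equally regardless of difficulty; the second approach (per-sample adaptive weights) is what Focal Loss adds.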

The authors propose a new loss function, Large Margin aware Focal (LMF) Loss, to alleviate class imbalance in medical imaging. It dynamically accounts for both hard samples and the class distribution.

3. How does it work?

3.1 Focal Loss

Any discussion of loss functions for class imbalance has to mention Focal Loss. Classification models are usually trained with cross-entropy loss (BCE in the binary case), which treats all classes equally: every sample contributes to learning with the same weight. Focal Loss modifies cross-entropy by introducing two factors: \alpha, which reweights classes to account for their sample counts, and \gamma, which down-weights easy samples so that the model concentrates on hard ones (often those of the minority class). The formula is:

FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)

where p_t is the model's predicted probability for the true class and \alpha_t is the weight of that class.
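
A per-sample sketch of focal loss, FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t), makes the effect of \gamma concrete. The helper names are illustrative; \alpha_t defaults to 1 so only the \gamma term acts.

```python
import math


def focal_loss(p_t, gamma=2.0, alpha_t=1.0):
    """FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t) for one sample,
    where p_t is the predicted probability of the true class."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def cross_entropy(p_t):
    return -math.log(p_t)


for p_t in (0.9, 0.5, 0.1):  # easy, medium and hard sample
    ratio = focal_loss(p_t) / cross_entropy(p_t)
    print(f"p_t={p_t}: CE={cross_entropy(p_t):.4f}  FL={focal_loss(p_t):.4f}  FL/CE={ratio:.2f}")
```

The FL/CE ratio is exactly (1 - p_t)^{\gamma}, here 0.01, 0.25 and 0.81, so an easy sample's loss shrinks a hundredfold while a hard sample keeps most of its loss. With \gamma = 0 and \alpha_t = 1, focal loss reduces to plain cross-entropy.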

3.2 LDAM Loss

The paper "Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss" (Cao et al., NeurIPS 2019) proposes another way to alleviate class imbalance, the label-distribution-aware margin (LDAM) loss. Its authors propose regularizing the minority classes more strongly than the majority classes in order to reduce the minority classes' generalization error. In this way, the loss preserves the model's ability to learn the majority classes while putting more emphasis on the minority classes. Rather than encouraging majority-class training samples to lie far from the decision boundary, LDAM focuses on the minimum margin of each class, targeting the per-class test error, i.e., the error under a uniform label distribution. In other words, it encourages relatively large margins only for the minority classes. Concretely, for classes 1, 2, …, k the authors derive class-dependent margins:

\gamma_j = \frac{C}{n_j^{1/4}}
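
The margin formula \gamma_j = C / n_j^{1/4} can be evaluated directly. This sketch assumes C = 1 and a hypothetical long-tailed set of class sizes.

```python
def ldam_margins(class_counts, C=1.0):
    """gamma_j = C / n_j**(1/4): rarer classes receive larger margins."""
    return [C / n ** 0.25 for n in class_counts]


counts = [10000, 1000, 100, 10]        # hypothetical long-tailed class sizes
margins = ldam_margins(counts)
print([round(m, 3) for m in margins])  # [0.1, 0.178, 0.316, 0.562]
```

Because of the fourth root, each tenfold drop in class size multiplies the margin by only 10^{1/4} ≈ 1.78, so the minority classes get noticeably larger margins without the extreme values that a 1/n_j rule would produce.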

Here j ∈ {1, …, k} indexes the classes, n_j is the number of training samples in class j, and C is a fixed constant. Now consider a sample pair (x, y), where x is the input and y its label, and a model f whose output (logit) for class j is z_j = f(x)_j. Let u = e^{z_y - p_y}, where for each class j ∈ {1, …, k}, p_j = \frac{C}{n_j^{1/4}} (i.e., the margin \gamma_j). The LDAM loss is then defined as:

\mathcal{L}_{LDAM}((x, y); f) = -\log \frac{u}{u + \sum_{j \neq y} e^{z_j}}
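
For a single sample, the LDAM loss is cross-entropy computed after subtracting the margin p_y from the true-class logit only. The sketch below assumes C = 1, two hypothetical classes, and made-up logits; setting all margins to zero recovers ordinary cross-entropy.

```python
import math


def ldam_loss(logits, y, margins):
    """-log( u / (u + sum_{j != y} exp(z_j)) ) with u = exp(z_y - p_y)."""
    shifted = list(logits)
    shifted[y] -= margins[y]  # subtract the margin from the true-class logit only
    m = max(shifted)          # log-sum-exp trick for numerical stability
    log_denom = m + math.log(sum(math.exp(z - m) for z in shifted))
    return -(shifted[y] - log_denom)


def cross_entropy(logits, y):
    return ldam_loss(logits, y, [0.0] * len(logits))


counts = [1000, 10]                          # hypothetical: class 1 is the minority
margins = [1.0 / n ** 0.25 for n in counts]  # p_j = C / n_j^(1/4) with C = 1
logits = [1.0, 1.5]
print(cross_entropy(logits, 1), ldam_loss(logits, 1, margins))
print(cross_entropy(logits, 0), ldam_loss(logits, 0, margins))
```

The margin raises the loss of both samples, but the minority-class sample (y = 1) gets the larger increase, so the model is pushed to separate that class from the boundary by a wider gap before its loss drops.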

3.3 LMF Loss

Focal Loss places more emphasis on samples that the model finds hard to classify, and samples from minority classes typically fall into this category. LDAM Loss, in contrast, sets class-dependent margins based on the class distribution of the dataset. The authors hypothesize that exploiting both properties at once yields better results than using either one alone. The proposed Large Margin aware Focal (LMF) loss is therefore a linear combination of Focal Loss and LDAM Loss weighted by two hyperparameters:

\mathcal{L}_{LMF} = \alpha \mathcal{L}_{Focal} + \beta \mathcal{L}_{LDAM}

Here, α and β are constants treated as tunable hyperparameters (this α is unrelated to the class-weighting factor α in Focal Loss). The proposed loss thus jointly optimizes two separate loss functions within a single framework. Through trial and error, the authors found that giving both components equal weight produced good results.
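
Putting the two terms together for one sample gives a minimal sketch of \mathcal{L}_{LMF} = \alpha \mathcal{L}_{Focal} + \beta \mathcal{L}_{LDAM} with equal default weights, mirroring the alpha/beta convention of the code in Section 3.4. It omits the focal class weights \alpha_t and the logit scale s used there; the class counts and logits are hypothetical.

```python
import math


def log_softmax(logits, y):
    m = max(logits)  # log-sum-exp trick for numerical stability
    return logits[y] - (m + math.log(sum(math.exp(z - m) for z in logits)))


def focal(logits, y, gamma=2.0):
    """-(1 - p_y)**gamma * log p_y (no class weights alpha_t)."""
    log_p = log_softmax(logits, y)
    return -((1.0 - math.exp(log_p)) ** gamma) * log_p


def ldam(logits, y, margins):
    """Cross-entropy after subtracting the margin from the true-class logit."""
    shifted = [z - margins[j] if j == y else z for j, z in enumerate(logits)]
    return -log_softmax(shifted, y)


def lmf(logits, y, margins, alpha=1.0, beta=1.0, gamma=2.0):
    """L_LMF = alpha * L_Focal + beta * L_LDAM (equal weights by default)."""
    return alpha * focal(logits, y, gamma) + beta * ldam(logits, y, margins)


margins = [1.0 / n ** 0.25 for n in (1000, 10)]  # hypothetical class counts, C = 1
logits = [1.0, 1.5]
print(lmf(logits, 1, margins))
```

Setting beta = 0 or alpha = 0 recovers the pure focal or pure LDAM term, which is a convenient way to ablate the two components.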

3.4 Code Implementation

# -*- coding: utf-8 -*-
"""
Created on Wed May 24 17:03:06 2023

@author: Sana
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

try:
    # Inside an MMClassification-style code base, LOSSES is the framework's
    # loss registry; outside such a package this relative import fails.
    from ..builder import LOSSES
except ImportError:
    LOSSES = None


def register_loss(cls):
    """Register cls with LOSSES when available; otherwise return it unchanged."""
    return LOSSES.register_module()(cls) if LOSSES is not None else cls


class FocalLoss(nn.Module):

    def __init__(self, alpha=None, gamma=2):
        super().__init__()
        # alpha: optional per-class weights (alpha_t in the focal loss formula).
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer(
            'alpha', None if alpha is None else torch.as_tensor(alpha, dtype=torch.float))
        self.gamma = gamma

    def forward(self, output, target):
        if self.alpha is not None:
            assert self.alpha.numel() == output.size(1), \
                'Length of weight tensor must match the number of classes'
        # Per-sample -log(p_t); reduction='none' applies the focal factor to each
        # sample instead of to the batch-averaged loss.
        ce = F.cross_entropy(output, target, reduction='none')
        p_t = torch.exp(-ce)
        focal_loss = (1 - p_t) ** self.gamma * ce
        if self.alpha is not None:
            focal_loss = self.alpha[target] * focal_loss

        return focal_loss.mean()


class LDAMLoss(nn.Module):

    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        """
        max_m: The appropriate value for max_m depends on the specific dataset and the severity of the class imbalance.
        You can start with a small value and gradually increase it to observe the impact on the model's performance.
        If the model struggles with class separation or experiences underfitting, increasing max_m might help. However,
        be cautious not to set it too high, as it can cause overfitting or make the model too conservative.

        s: The choice of s depends on the desired scale of the logits and the specific requirements of your problem.
        It can be used to adjust the balance between the margin and the original logits. A larger s value amplifies
        the impact of the logits and can be useful when dealing with highly imbalanced datasets.
        You can experiment with different values of s to find the one that works best for your dataset and model.

        """
        super().__init__()
        # Per-class margin proportional to n_j^(-1/4), rescaled so the largest equals max_m.
        m_list = 1.0 / np.sqrt(np.sqrt(np.asarray(cls_num_list, dtype=np.float64)))
        m_list = m_list * (max_m / np.max(m_list))
        # Buffers instead of torch.cuda.FloatTensor, so the loss also runs on CPU.
        self.register_buffer('m_list', torch.tensor(m_list, dtype=torch.float))
        self.register_buffer(
            'weight', None if weight is None else torch.as_tensor(weight, dtype=torch.float))
        assert s > 0
        self.s = s

    def forward(self, x, target):
        # Boolean mask of each sample's true class (torch.where needs a bool condition).
        index = F.one_hot(target, num_classes=x.size(1)).bool()
        # Margin of each sample's true class, shape (batch, 1).
        batch_m = self.m_list[target].view(-1, 1)
        x_m = x - batch_m

        # Subtract the margin from the true-class logit only, then scale by s.
        output = torch.where(index, x_m, x)
        return F.cross_entropy(self.s * output, target, weight=self.weight)


@register_loss
class LMFLoss(nn.Module):
    def __init__(self, cls_num_list, weight=None, alpha=1, beta=1, gamma=2, max_m=0.5, s=30):
        super().__init__()
        self.focal_loss = FocalLoss(weight, gamma)
        self.ldam_loss = LDAMLoss(cls_num_list, max_m, weight, s)
        # alpha scales the focal term, beta the LDAM term.
        self.alpha = alpha
        self.beta = beta

    def forward(self, output, target):
        focal_loss_output = self.focal_loss(output, target)
        ldam_loss_output = self.ldam_loss(output, target)
        total_loss = self.alpha * focal_loss_output + self.beta * ldam_loss_output
        return total_loss
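
The margin computation in LDAMLoss can be related back to \gamma_j = C / n_j^{1/4} from Section 3.2: dividing n_j^{-1/4} by its maximum and multiplying by max_m is the same as choosing C = max_m · n_min^{1/4}, so the rarest class gets a margin of exactly max_m. The pure-Python check below mirrors the constructor; the helper names and class sizes are hypothetical. The class additionally multiplies the margin-adjusted logits by s before the cross-entropy, a scale factor that the LDAM definition in Section 3.2 does not include.

```python
def margins_as_in_code(cls_num_list, max_m=0.5):
    """Mirror of the LDAMLoss constructor: n_j^(-1/4), rescaled so the largest margin is max_m."""
    raw = [1.0 / n ** 0.25 for n in cls_num_list]
    return [r * (max_m / max(raw)) for r in raw]


def margins_from_formula(cls_num_list, C):
    """gamma_j = C / n_j^(1/4)."""
    return [C / n ** 0.25 for n in cls_num_list]


counts = [5000, 500, 50]        # hypothetical class sizes
max_m = 0.5
C = max_m * min(counts) ** 0.25  # constant implied by the rescaling
a = margins_as_in_code(counts, max_m)
b = margins_from_formula(counts, C)
print(a)
print(b)
```

In practice this means max_m, not C, is the knob to tune: it directly sets the margin of the rarest class, and all other margins follow from the class-size ratios.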

Reference: The successor of Focal Loss | LMFLOSS: a new hybrid loss function dedicated to solving imbalanced medical image classification
