
Modify the Dice loss #3825

Open · wants to merge 1 commit into base: develop

Conversation

@zifuwanggg commented Oct 11, 2024

PR types

[Bug fixes]

PR changes

[Models]

Description

The Dice loss in paddleseg.models.losses.dice_loss and paddleseg.models.losses.maskformer_loss is modified based on JDTLoss and segmentation_models.pytorch.

The original Dice loss is incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_1 + |y|_1 - |x-y|_1}{2}$. This reformulation has been proven to be equivalent to the original version when the ground truth is binary (i.e., one-hot hard labels). Moreover, since the new version is minimized if and only if the prediction is identical to the ground truth, even when the ground truth includes fractional values, it resolves the issue with soft labels [1, 2].
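For intuition, the binary-case equivalence follows from a pointwise identity (assuming, as with softmax outputs, that predictions lie in $[0, 1]$): for $x_i \in [0, 1]$ and $y_i \in \{0, 1\}$, $\frac{x_i + y_i - |x_i - y_i|}{2} = \min(x_i, y_i) = x_i y_i$, so summing over pixels recovers the original intersection $\sum_i x_i y_i$ exactly.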

In summary, there are three scenarios:

  • [Scenario 1] $x$ is nonnegative and $y$ is binary: The new version is identical to the original version.
  • [Scenario 2] Both $x$ and $y$ are nonnegative: The new version differs from the original version. The new version is minimized if and only if $x=y$, whereas the original version may not be, making it incorrect (a one-pixel worked example follows this list).
  • [Scenario 3] Either $x$ or $y$ is negative: The new version differs from the original version. Again, the new version is minimized if and only if $x=y$, whereas the original version may not be, making it incorrect.
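As a concrete one-pixel instance of Scenario 2, take $y = 0.5$. The original score $\frac{2xy}{x+y}$ equals $0.5$ at $x = 0.5$ but $\frac{2 \cdot 0.5}{1.5} \approx 0.67$ at $x = 1$, so it is not maximized (and the loss not minimized) at $x = y$. The rewritten score, $\frac{x + y - |x - y|}{x + y}$, equals $1$ at $x = 0.5$ and $\approx 0.67$ at $x = 1$, peaking exactly at $x = y$.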

Due to these differences, particularly in Scenarios 2 and 3, some tests may fail with the new version. The failures are expected since the original version is incorrectly defined for non-binary ground truth.

Example

import paddle
import paddle.nn.functional as F

paddle.seed(0)

b, c, h, w = 4, 3, 32, 32
axis = (0, 2, 3)

pred = F.softmax(paddle.rand((b, c, h, w)), axis=1)  # predicted class probabilities
soft_label = F.softmax(paddle.rand((b, c, h, w)), axis=1)  # soft labels with fractional values
hard_label = paddle.randint(low=0, high=c, shape=(b, h, w))  # integer class indices
one_hot_label = paddle.transpose(F.one_hot(hard_label, c), perm=(0, 3, 1, 2))  # binary one-hot labels

def dice_old(x, y, axis):
    # Original formulation: intersection as the element-wise product.
    cardinality = paddle.sum(x, axis=axis) + paddle.sum(y, axis=axis)
    intersection = paddle.sum(x * y, axis=axis)
    return 2 * intersection / cardinality

def dice_new(x, y, axis):
    # New formulation: intersection as (|x|_1 + |y|_1 - |x - y|_1) / 2,
    # which equals the original when y is binary and x is in [0, 1].
    cardinality = paddle.sum(x, axis=axis) + paddle.sum(y, axis=axis)
    difference = paddle.sum(paddle.abs(x - y), axis=axis)
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality
  
print(dice_old(pred, one_hot_label, axis), dice_new(pred, one_hot_label, axis))
print(dice_old(pred, soft_label, axis), dice_new(pred, soft_label, axis))
print(dice_old(pred, pred, axis), dice_new(pred, pred, axis))

# Per-class Dice scores (old vs. new):
# [0.3356, 0.3308, 0.3319] [0.3356, 0.3308, 0.3319]
# [0.3326, 0.3323, 0.3340] [0.8668, 0.8670, 0.8675]
# [0.3505, 0.3512, 0.3513] [1., 1., 1.]
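
The motivating one-pixel case (ground truth 0.5) can also be checked directly with the functions above; this is an illustrative snippet, not part of the patch:

# One-pixel check: ground truth y = 0.5.
y = paddle.to_tensor([0.5])
x_match = paddle.to_tensor([0.5])  # prediction equal to the soft target
x_over = paddle.to_tensor([1.0])   # prediction pushed past the target

print(dice_old(x_match, y, axis=0), dice_old(x_over, y, axis=0))
# 0.5000, 0.6667: the original score is higher at x = 1 than at x = y
print(dice_new(x_match, y, axis=0), dice_new(x_over, y, axis=0))
# 1.0000, 0.6667: the new score is maximal exactly at x = y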

References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.

paddle-bot bot commented Oct 11, 2024

Thanks for your contribution!

@CLAassistant commented Oct 11, 2024

CLA assistant check
All committers have signed the CLA.
