| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor, einsum |
| | import torch.nn .functional as F |
| | from misc.torchutils import class2one_hot,simplex |
| | from models.darnet_help.loss_help import FocalLoss, dernet_dice_loss |
| |
|
def cross_entropy(input, target, weight=None, reduction='mean', ignore_index=255):
    """
    logSoftmax_with_loss: per-pixel softmax cross-entropy.

    If the logits' spatial size differs from the target's, the logits are
    bilinearly resized to the target size before the loss is computed.

    :param input: torch.Tensor of logits, N*C*H*W
    :param target: torch.Tensor of class indices, N*1*H*W or N*H*W
    :param weight: optional torch.Tensor of per-class weights, length C
    :param reduction: forwarded to ``F.cross_entropy`` ('mean', 'sum' or 'none')
    :param ignore_index: target value excluded from the loss (default 255)
    :return: torch.Tensor, the loss (a scalar unless reduction='none')
    """
    target = target.long()
    if target.dim() == 4:
        # Drop the singleton channel axis: N*1*H*W -> N*H*W.
        target = torch.squeeze(target, dim=1)
    # Compare BOTH spatial dims; checking only the width skipped the resize
    # when just the height differed, and F.cross_entropy then raised.
    if input.shape[-2:] != target.shape[-2:]:
        input = F.interpolate(input, size=target.shape[-2:], mode='bilinear',
                              align_corners=True)
    return F.cross_entropy(input=input, target=target, weight=weight,
                           ignore_index=ignore_index, reduction=reduction)
| |
|
| |
|
def dice_loss(predicts, target, weight=None, num_classes=7, idc=(0, 1)):
    """Soft Dice loss averaged over a chosen subset of classes.

    :param predicts: torch.Tensor of logits; softmax is taken over dim 1 and
        the einsum patterns below require a 4-D (b, c, w, h) layout
    :param target: label tensor, one-hot encoded here with ``class2one_hot``
    :param weight: unused; kept so the signature matches the other losses
    :param num_classes: class count passed to ``class2one_hot`` (default 7,
        the value that used to be hard-coded)
    :param idc: indices of the classes whose Dice terms are averaged
        (default classes 0 and 1, as before)
    :return: scalar torch.Tensor, mean of (1 - Dice) over batch and the
        selected classes
    """
    idc = list(idc)  # a list makes the channel selection advanced indexing
    probs = torch.softmax(predicts, dim=1)

    target = class2one_hot(target, num_classes)
    assert simplex(probs) and simplex(target)

    pc = probs[:, idc, ...].type(torch.float32)
    tc = target[:, idc, ...].type(torch.float32)
    intersection: Tensor = einsum("bcwh,bcwh->bc", pc, tc)
    # Sum of both maps' masses per (batch, class), the Dice denominator.
    union: Tensor = (einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc))

    # The 1e-10 terms avoid 0/0 when a class is absent from both maps.
    divided: Tensor = torch.ones_like(intersection) - (2 * intersection + 1e-10) / (union + 1e-10)

    return divided.mean()
| |
|
def ce_dice(input, target, weight=None):
    """Equal blend of the two losses: 0.5 * cross-entropy + 0.5 * Dice.

    ``weight`` is accepted for interface compatibility but is not used.
    """
    return 0.5 * cross_entropy(input, target) + 0.5 * dice_loss(input, target)
| |
|
def dice(input, target, weight=None):
    """Dice loss on its own; a thin alias of ``dice_loss`` (``weight`` is ignored)."""
    result = dice_loss(input, target)
    return result
| |
|
def ce2_dice1(input, target, weight=None):
    """Cross-entropy weighted 2:1 against Dice: ce + 0.5 * dice.

    ``weight`` is accepted for interface compatibility but is not used.
    """
    return cross_entropy(input, target) + 0.5 * dice_loss(input, target)
| |
|
def ce1_dice2(input, target, weight=None):
    """Dice weighted 2:1 against cross-entropy: 0.5 * ce + dice.

    ``weight`` is accepted for interface compatibility but is not used.
    """
    return 0.5 * cross_entropy(input, target) + dice_loss(input, target)
| |
|
def ce_scl(input, target, weight=None):
    """0.5 * cross-entropy + 0.5 * Dice.

    Currently computes exactly the same loss as ``ce_dice``; no additional
    term is added here. ``weight`` is accepted but not used.
    """
    return ce_dice(input, target)
| |
|
| |
|
def weighted_BCE_logits(logit_pixel, truth_pixel, weight_pos=0.25, weight_neg=0.75):
    """Class-balanced binary cross-entropy computed from logits.

    Positives (truth > 0.5) and negatives (truth < 0.5) are each averaged over
    their own count and then combined with ``weight_pos`` / ``weight_neg``.
    Elements whose truth is exactly 0.5 belong to neither group and add nothing.

    :param logit_pixel: torch.Tensor of raw logits, any shape
    :param truth_pixel: torch.Tensor of targets with the same number of elements
    :param weight_pos: weight of the positive-class mean loss
    :param weight_neg: weight of the negative-class mean loss
    :return: scalar torch.Tensor
    :raises ValueError: if the inputs do not have the same number of elements
    """
    # reshape (not view): view(-1) raises on non-contiguous inputs such as the
    # result of permute()/transpose().
    logit = logit_pixel.reshape(-1)
    truth = truth_pixel.reshape(-1)
    if logit.shape != truth.shape:
        raise ValueError(
            f"logit and truth sizes differ: {tuple(logit.shape)} vs {tuple(truth.shape)}")

    loss = F.binary_cross_entropy_with_logits(logit.float(), truth.float(), reduction='none')

    pos = (truth > 0.5).float()
    neg = (truth < 0.5).float()
    # Counts stay tensors: calling .item() would force a device->host sync.
    # The 1e-12 keeps an empty class from dividing by zero (its numerator is 0).
    pos_num = pos.sum() + 1e-12
    neg_num = neg.sum() + 1e-12
    return (weight_pos * pos * loss / pos_num + weight_neg * neg * loss / neg_num).sum()
| |
|
class ChangeSimilarity(nn.Module):
    """Cosine-embedding loss between two multi-class prediction maps.

    input: x1, x2 multi-class predictions of shape (b, c, h, w), c = class_num
    label_change: changed part; any non-zero value counts as "changed"
    """

    def __init__(self, reduction='mean'):
        super(ChangeSimilarity, self).__init__()
        self.loss_f = nn.CosineEmbeddingLoss(margin=0., reduction=reduction)

    def forward(self, x1, x2, label_change):
        """Compare per-pixel class distributions of ``x1`` and ``x2``.

        :param x1: torch.Tensor of logits, (b, c, h, w)
        :param x2: torch.Tensor of logits with the same shape as ``x1``
        :param label_change: tensor with b*h*w elements; non-zero = changed
        :return: output of ``nn.CosineEmbeddingLoss`` with the configured reduction
        """
        b, c, h, w = x1.size()
        # Softmax over classes, then one row of c probabilities per pixel.
        p1 = F.softmax(x1, dim=1).permute(0, 2, 3, 1).reshape(b * h * w, c)
        p2 = F.softmax(x2, dim=1).permute(0, 2, 3, 1).reshape(b * h * w, c)

        # Use the boolean mask for BOTH classes so the target is strictly +/-1.
        # Previously the changed part used the raw label values, so a label of 2
        # gave target -2, which CosineEmbeddingLoss silently ignores.
        changed = label_change.bool()
        # +1 (unchanged) penalises 1 - cos; -1 (changed) penalises max(0, cos).
        target = ((~changed).float() - changed.float()).reshape(b * h * w)

        return self.loss_f(p1, p2, target)
| |
|
def hybrid_loss(predictions, target, weight=(0, 2, 0.2, 0.2, 0.2, 0.2)):
    """Weighted sum of (cross-entropy + Dice) over several predictions.

    :param predictions: iterable of logit tensors, each scored against ``target``
    :param target: ground-truth label tensor shared by all predictions
    :param weight: per-prediction weights, indexed by position; it must have at
        least as many entries as ``predictions`` (otherwise IndexError)
    :return: the accumulated loss (the int 0 if ``predictions`` is empty)
    """
    # NOTE(review): the default reads "0, 2", which may be a typo for "0.2".
    # The values are kept unchanged (now as an immutable tuple); please confirm.
    loss = 0
    for i, prediction in enumerate(predictions):
        ce = cross_entropy(prediction, target)
        # Named dice_ so it does not shadow the module-level ``dice`` function.
        dice_ = dice_loss(prediction, target)
        loss += weight[i] * (ce + dice_)
    return loss
| |
|
class BCL(nn.Module):
    """
    batch-balanced contrastive loss

    Label convention expected by ``forward``:
      0   -> no-change (the "+1"/similar class): distance pulled towards 0
      1   -> change    (the "-1"/dissimilar class): distance pushed past margin
      255 -> ignored
    """
    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        """Return the batch-balanced contrastive loss.

        :param distance: torch.Tensor of distances, broadcast-compatible with ``label``
        :param label: torch.Tensor with 0 (no-change), 1 (change) or 255 (ignore);
            it is not modified
        :return: scalar torch.Tensor
        """
        # Explicit masks instead of rewriting ``label`` in place: the old code
        # clobbered the caller's tensor. Pixels labelled 255 appear in neither
        # mask; the old formula gave them a negative (1 - 255) / 2 weight in the
        # margin term, which pulled the loss down.
        no_change = (label == 0).float()
        change = (label == 1).float()

        # Each class is normalised by its own count; +1e-4 avoids dividing by 0.
        pos_num = torch.sum(no_change) + 0.0001
        neg_num = torch.sum(change) + 0.0001

        loss_1 = torch.sum(no_change * torch.pow(distance, 2)) / pos_num
        loss_2 = torch.sum(change *
                           torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
                           ) / neg_num
        return loss_1 + loss_2