import torch
import torch.nn as nn


class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a smooth (differentiable) approximation of L1.

    Computes ``sum(sqrt((x - y)**2 + eps**2))`` and multiplies the result by
    the normalization factor returned by ``get_outnorm(x, out_norm)``.

    Args:
        eps: Smoothing constant. It is squared before being added inside the
            square root, so the default ``1e-6`` contributes ``1e-12``.
        out_norm: Normalization spec forwarded to ``get_outnorm``. The
            default ``'bci'`` divides by batch * channels * height * width,
            i.e. a per-element mean for a 4-D (B, C, H, W) input.
    """

    def __init__(self, eps: float = 1e-6, out_norm: str = 'bci'):
        super().__init__()
        self.eps = eps
        self.out_norm = out_norm

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the normalized Charbonnier loss between ``x`` and ``y``.

        The normalization factor is derived from the shape of ``x``.
        """
        norm = get_outnorm(x, self.out_norm)
        loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps ** 2))
        return loss * norm


def get_outnorm(x: torch.Tensor, out_norm: str = '') -> float:
    """Return a scalar factor used to normalize a summed loss.

    Each letter present in ``out_norm`` divides the factor by part of
    ``x.shape``:

    * ``'b'`` -- ``x.shape[0]`` (batch size)
    * ``'c'`` -- ``x.shape[-3]`` (number of channels)
    * ``'i'`` -- ``x.shape[-2] * x.shape[-1]`` (image/map size)

    Letters may be combined (``'bi'``, ``'bci'``, ...); any other characters
    are ignored. The batch/channel/size meanings assume an NCHW-style layout;
    the function itself only indexes the shape.

    Args:
        x: Tensor whose shape supplies the divisors.
        out_norm: Normalization spec. An empty string means no normalization.

    Returns:
        The normalization factor as a Python float (``1.0`` when no
        normalization is requested).
    """
    if not out_norm:
        return 1.0

    shape = x.shape
    norm = 1.0
    if 'b' in out_norm:
        norm /= shape[0]
    if 'c' in out_norm:
        norm /= shape[-3]
    if 'i' in out_norm:
        norm /= shape[-1] * shape[-2]
    return norm