Commit ce775a4b authored by valentini's avatar valentini
Browse files

Carica un nuovo file

parent d0740e50
import torch
import torch.nn as nn
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-6, out_norm:str='bci'):
super(CharbonnierLoss, self).__init__()
self.eps = eps
self.out_norm = out_norm
def forward(self, x, y):
norm = get_outnorm(x, self.out_norm)
loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
return loss*norm
def get_outnorm(x:torch.Tensor, out_norm:str='') -> torch.Tensor:
""" Common function to get a loss normalization value. Can
normalize by either the batch size ('b'), the number of
channels ('c'), the image size ('i') or combinations
('bi', 'bci', etc)
"""
# b, c, h, w = x.size()
img_shape = x.shape
if not out_norm:
return 1
norm = 1
if 'b' in out_norm:
# normalize by batch size
# norm /= b
norm /= img_shape[0]
if 'c' in out_norm:
# normalize by the number of channels
# norm /= c
norm /= img_shape[-3]
if 'i' in out_norm:
# normalize by image/map size
# norm /= h*w
norm /= img_shape[-1]*img_shape[-2]
return norm
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment