Commit 99896d6f authored by valentini's avatar valentini
Browse files

Carica un nuovo file

parent 899316af
import torch
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as P
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import pandas as pd
import yuvio
import os
from tqdm import tqdm, trange
from ignite.metrics import PSNR
from ignite.metrics import SSIM
from utils.util import get_luminance, calc_psnr, quantize
from utils.loss import CharbonnierLoss
import os
import copy
import torch_pruning as tp
import logging
class ForwardManager():
    """Inference dispatcher for a super-resolution model.

    Depending on flags in ``args`` it routes a forward pass through:
      * plain ``model(x)``   -- always used while training,
      * :meth:`forward_chop` -- tiled inference for inputs too large for memory,
      * :meth:`forward_x8`   -- x8 self-ensemble (flip/transpose augmentations).
    """

    def __init__(self, model, training, args, max_size=16384):
        self.scale = args.scale                  # SR upscaling factor
        self.self_ensemble = args.self_ensemble  # enable x8 ensemble at eval time
        self.chop = args.chop                    # enable tiled inference at eval time
        self.precision = args.precision          # 'single' or 'half'
        self.device = torch.device('cuda')
        self.n_GPUs = torch.cuda.device_count()
        self.training = training
        self.model = model
        if self.precision == 'half':
            self.model.half()
        if self.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(self.n_GPUs))
        # largest tile area (h * w) processed in one shot by forward_chop
        self.max_inference_size = max_size

    def forward(self, x):
        """Dispatch `x` to the configured inference strategy."""
        if self.self_ensemble and not self.training:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward
            return self.forward_x8(x, forward_function)
        elif self.chop and not self.training:
            return self.forward_chop(x)
        else:
            return self.model(x)

    def get_model(self):
        """Return the bare model, unwrapping DataParallel when present.

        Fixed: the original returned ``self.model.module`` whenever
        ``n_GPUs != 1``, which crashed on CPU-only hosts (``n_GPUs == 0``)
        where the model is never wrapped in DataParallel.
        """
        if self.n_GPUs > 1:
            return self.model.module
        return self.model

    def forward_chop(self, x, shave=None, min_size=None):
        """Tile `x` into 4 overlapping quadrants, upscale each (recursing
        while a quadrant is still larger than `min_size`), and stitch the
        results back together, discarding the `shave` overlap border.
        """
        scale = self.scale
        if not shave:
            shave = scale
        if min_size is None:
            min_size = self.max_inference_size
        # at least one batch group even on CPU-only hosts
        # (the original min(self.n_GPUs, 4) was 0 on CPU, making range() step 0)
        n_GPUs = max(min(self.n_GPUs, 4), 1)
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]
        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
        else:
            # quadrants still too big: recurse until they fit
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size)
                for patch in lr_list
            ]
        # all sizes scale up by the SR factor in the output space
        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size
        shave *= scale
        output = x.new(b, c, h, w)
        # stitch the four upscaled quadrants, dropping the shaved overlap
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
        return output

    def forward_x8(self, x, forward_function):
        """x8 self-ensemble: run `forward_function` over all 8 flip/transpose
        variants of `x`, undo each augmentation, and average the predictions.
        """
        def _transform(v, op):
            # augmentations are applied on CPU/numpy, then moved back
            if self.precision != 'single':
                v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
            ret = torch.Tensor(tfnp).to(self.device)
            if self.precision == 'half':
                ret = ret.half()
            return ret

        # build the 8 variants: identity, v, h, hv, t, tv, th, thv
        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        sr_list = [forward_function(aug) for aug in lr_list]
        # undo the augmentations (in reverse order) before averaging
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')
        output_cat = torch.cat(sr_list, dim=0)
        output = output_cat.mean(dim=0, keepdim=True)
        return output

    def __call__(self, x):
        # Fixed: the original dropped the result (implicitly returned None)
        return self.forward(x)
def trainFor(
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    model: torch.nn.Module,
    device: torch.device,
    run_folder: str,
    epochs: int,
    loss_function,
    args,
    pruner = None,
    optimizer = None,
    scheduler = None,
):
    """Train `model` for `epochs` epochs, validating after every epoch.

    Saves a checkpoint under ``run_folder/checkpoints`` each time the
    validation loss improves, appends per-epoch metrics to
    ``run_folder/train_log.csv``, and saves a final checkpoint at the end.

    Returns:
        (best_model, best_validation_mse, last_epoch_model,
         last_epoch_validation_mse, optimizer, log)
    """
    # paths
    ckpt_path = os.path.join(run_folder, 'checkpoints')
    os.makedirs(ckpt_path, exist_ok=True)
    # create the optimizer/scheduler when the caller did not supply them
    if not optimizer:
        optimizer = Adam(model.parameters(), lr=args.lr)
    if not scheduler:
        scheduler = ReduceLROnPlateau(optimizer, patience=args.patience, factor=0.5, verbose=True)
    model = model.to(device)
    # epoch-level metric trackers (computed on the luminance channel)
    psnr_train = PSNR(data_range=args.data_range, output_transform=get_luminance, device=device)
    psnr_val = PSNR(data_range=args.data_range, output_transform=get_luminance, device=device)
    ssim_train = SSIM(data_range=args.data_range, output_transform=get_luminance, device=device)
    ssim_val = SSIM(data_range=args.data_range, output_transform=get_luminance, device=device)
    log = pd.DataFrame()
    best_validation_mse = None
    start_epoch = 1
    best_model = model
    # defaults in case the epoch loop body never runs (epochs < start_epoch);
    # the original raised NameError on the final checkpoint save in that case
    epoch = start_epoch - 1
    cur_loss = val_loss = float('nan')
    for epoch in trange(start_epoch, epochs + 1):
        # reset the epoch-level metric accumulators
        psnr_train.reset()
        psnr_val.reset()
        ssim_train.reset()
        ssim_val.reset()
        metrics = {'epoch': epoch}
        cur_loss = trainEval(
            loader=train_dataloader,
            model=model,
            optimizer=scheduler.optimizer,
            device=device,
            args=args,
            loss_function=loss_function,
            bTrain=True,
            psnr=psnr_train,
            ssim=ssim_train,
            # forward the caller's pruner so sparsity regularization is applied
            # (the original hard-coded None here, silently ignoring `pruner`)
            pruner=pruner,
            data_range=args.data_range
        )
        val_loss = trainEval(
            loader=val_dataloader,
            model=model,
            optimizer=scheduler.optimizer,
            args=args,
            device=device,
            loss_function=loss_function,
            bTrain=False,
            psnr=psnr_val,
            ssim=ssim_val,
            data_range=args.data_range
        )
        # log metrics to the CSV file
        metrics['mse_train'] = cur_loss
        metrics['mse_val'] = val_loss
        metrics['psnr_train'] = float(psnr_train.compute())
        metrics['psnr_val'] = float(psnr_val.compute())
        metrics['ssim_train'] = float(ssim_train.compute())
        metrics['ssim_val'] = float(ssim_val.compute())
        if best_validation_mse is None or (val_loss < best_validation_mse):
            best_validation_mse = val_loss
            best_model = copy.deepcopy(model)
            ckpt = os.path.join(ckpt_path, 'ckpt_e{}.pth'.format(epoch))
            # NOTE(review): this pickles the whole module/optimizer objects;
            # state_dict() checkpoints are more portable -- kept for
            # compatibility with existing checkpoint consumers.
            torch.save({
                'epoch': epoch,
                'mse_train': cur_loss,
                'mse_val': val_loss,
                'model': model,
                'optimizer': scheduler.optimizer,
            }, ckpt)
        # DataFrame.append was removed in pandas 2.0; use concat instead
        log = pd.concat([log, pd.DataFrame([metrics])], ignore_index=True)
        log.to_csv(os.path.join(run_folder, 'train_log.csv'), index=False)
        scheduler.step(val_loss)
    ## save last epoch (separate name so ckpt_path keeps pointing at the dir)
    final_ckpt = os.path.join(ckpt_path, 'final_ckpt_e{}.pth'.format(epoch))
    if not os.path.exists(final_ckpt):
        torch.save({
            'epoch': epoch,
            'mse_train': cur_loss,
            'mse_val': val_loss,
            'model': model,
            'optimizer': scheduler.optimizer,
        }, final_ckpt)
    last_epoch_model = model
    last_epoch_validation_mse = best_validation_mse
    return (best_model, best_validation_mse, last_epoch_model, last_epoch_validation_mse, scheduler.optimizer, log)
#training for a single epoch
def trainEval(loader, model, optimizer, device: torch.device, loss_function, args, bTrain = True, psnr=None, ssim=None, data_range=1.0, pruner=None, training_iter=0):
    """Run one epoch over `loader`: a training pass when `bTrain`, otherwise
    an evaluation pass under no_grad.

    Args:
        loader: DataLoader yielding (input, target) batches.
        model: the network; forwarded through a ForwardManager.
        optimizer: stepped once per batch, only when bTrain is True.
        device: device for the per-batch PSNR/SSIM trackers.
        loss_function: callable(model_out, target) -> scalar loss tensor.
        args: run config; reads args.data_range and args.update_reg_interval
            (ForwardManager reads scale/self_ensemble/chop/precision from it).
        bTrain: True -> train mode + backward/step; False -> eval under no_grad.
        psnr, ssim: optional epoch-level ignite metrics updated per sample.
        data_range: data range for the local (progress-bar only) metrics.
        pruner: optional torch_pruning pruner; its regularize() hook is called
            after backward for sparsity learning.
        training_iter: starting iteration counter, used to schedule
            pruner.update_reg() every args.update_reg_interval iterations.

    Returns:
        The average loss over the epoch.
    """
    forward = ForwardManager(model, bTrain, args)
    # per-batch metrics, only shown on the progress bar
    local_psnr = PSNR(data_range=data_range, device=device)
    local_ssim = SSIM(data_range=data_range, device=device)
    if bTrain:
        model.train()
    else:
        model.eval()
    total_loss = 0.0
    counter = 0
    progress = tqdm(loader)
    for input, target in progress:
        if bTrain:  # train: forward with gradients enabled
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()
            model_out = forward.forward(input)
        else:  # eval: forward without building the autograd graph
            with torch.no_grad():
                if torch.cuda.is_available():
                    input = input.cuda()
                    target = target.cuda()
                model_out = forward.forward(input)
        # if "yuv" in args.loader:
        #     target = target * (1023/255)
        #     model_out = model_out * (1023/255)
        # quantize the out for SSIM and PSNR calc
        q_out = quantize(model_out, args.data_range)
        local_psnr.reset()
        local_ssim.reset()
        # metrics are updated one sample at a time
        for b in range(q_out.shape[0]):
            local_psnr.update((q_out[b:b+1], target[b:b+1]))
            local_ssim.update((q_out[b:b+1], target[b:b+1]))
            if psnr: psnr.update((q_out[b:b+1], target[b:b+1]))
            if ssim: ssim.update((q_out[b:b+1], target[b:b+1]))
        # loss is computed on the raw (un-quantized) output
        loss = loss_function(model_out, target)
        if bTrain:
            optimizer.zero_grad()
            loss.backward()
            if pruner is not None: pruner.regularize(model)  # for sparsity learning
            optimizer.step()
        total_loss += loss.item()
        counter += 1
        progress.set_postfix({
            'avg_loss': total_loss / counter,
            'loss_iteration': loss.item(),
            'psnr_iteration': float(local_psnr.compute()),
            'ssim_iteration': float(local_ssim.compute())
        })
        training_iter += 1
        if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner) and training_iter % args.update_reg_interval == 0:
            pruner.update_reg()  # increase the strength of regularization
            # print(pruner.group_reg[pruner._groups[0]])
    return total_loss / counter
# sparsity-learning loop: train with the pruner's regularizer until its
# stop condition is met (the original comment here was a copy-paste of
# "training for a single epoch", which does not describe this function)
def sparsityLearning(loader, model,pruner, args, loss_function, training_iter=0):
    """Train `model` with the pruner's sparsity regularization until the
    pruner-specific stop condition is satisfied.

    For a GrowingRegPruner, regularization strength is increased every
    args.update_reg_interval iterations and training continues until the
    group regularization gammas cross args.target_regularization; for any
    other pruner type a single full epoch is run.

    Returns:
        (model, pruner)
    """
    pruner.update_regularizor()  # regenerate the regularizer so it matches an already-pruned model
    model.train()
    # a fresh optimizer is created here rather than reusing the caller's
    optimizer = Adam(model.parameters(), lr=args.lr)
    total_loss = 0.0
    counter = 0
    stop_condition_satisfied = False
    while (not stop_condition_satisfied):
        for input, target in loader:
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()
            model_out = model(input)
            if "yuv" in args.loader:
                # rescale 8-bit values into the 10-bit YUV range
                target = target * (1023/255)
                model_out = model_out * (1023/255)
            # quantize the out for SSIM and PSNR calc
            loss = loss_function(model_out, target)
            optimizer.zero_grad()
            loss.backward()
            pruner.regularize(model)  # for sparsity learning
            optimizer.step()
            total_loss += loss.item()
            counter += 1
            training_iter += 1
            if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner) and training_iter % args.update_reg_interval == 0:
                pruner.update_reg()  # increase the strength of regularization
        # stop-condition check after each full pass over the loader
        stop_condition_satisfied = True
        for i, group in enumerate(pruner._groups):
            gamma = pruner.group_reg[group]
            # NOTE(review): this overwrites the flag on every iteration, so
            # only the LAST group's gamma decides whether to stop; and the
            # `<` comparison stops as soon as min(gamma) drops BELOW the
            # target -- confirm both the aggregation and the direction.
            stop_condition_satisfied = torch.min(gamma) < args.target_regularization
        # Generic case is to stop after a full train epoch for sparsity learning
        if pruner is not None and not isinstance(pruner, tp.pruner.GrowingRegPruner):
            stop_condition_satisfied=True
    return model, pruner
def forward_one_batch(
    train_dataloader: DataLoader,
    model: torch.nn.Module,
    device: torch.device,
    run_folder: str,
    loss_function,
    args,
    batch_idx =0,
    optimizer = None,
    scheduler = None,
    local_psnr = None,
    local_ssim = None,
):
    """Run a single training step (one batch) and return its loss.

    The function returns from inside the batch loop, so exactly one
    optimizer step is performed per call; `batch_idx + 1` is returned so
    callers can iterate batch-by-batch across repeated calls.

    Returns:
        (loss_value, scheduler, optimizer, next_batch_idx)
    """
    # paths
    ckpt_path = os.path.join(run_folder, 'checkpoints')
    os.makedirs(ckpt_path, exist_ok=True)
    # create the optimizer/scheduler when the caller did not supply them
    if not optimizer: optimizer = Adam(model.parameters(), lr=args.lr)
    if not scheduler: scheduler = ReduceLROnPlateau(optimizer, patience=args.patience, factor=0.5, verbose=True)
    model = model.to(device)
    forward = ForwardManager(model, True, args)
    model.train()
    # NOTE(review): a torch DataLoader is not indexable, so this presumably
    # expects `train_dataloader` to be a sequence of loaders (one per batch)
    # -- confirm against the callers.
    for input, target in train_dataloader[batch_idx]:
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()
        model_out = forward.forward(input)
        if "yuv" in args.loader:
            # rescale 8-bit values into the 10-bit YUV range
            target = target * (1023/255)
            model_out = model_out * (1023/255)
        # quantize the out for SSIM and PSNR calc
        q_out = quantize(model_out, args.data_range)
        if local_psnr: local_psnr.reset()
        if local_ssim: local_ssim.reset()
        for b in range(q_out.shape[0]):
            if local_psnr: local_psnr.update((q_out[b:b+1], target[b:b+1]))
            if local_ssim: local_ssim.update((q_out[b:b+1], target[b:b+1]))
        loss = loss_function(model_out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # return after the FIRST batch (single-step semantics)
        return loss.item(), scheduler, optimizer, batch_idx +1
def eval(loader, model, loss_function, device: torch.device, args, data_range=1.0):
    """Evaluate `model` on `loader` and return the average loss plus the
    populated PSNR/SSIM metric trackers.

    Args:
        loader: DataLoader yielding (input, target) batches.
        loss_function: callable(model_out, target) -> scalar loss tensor.
        args: run config; reads args.loader (ForwardManager reads more).
        data_range: data range for the PSNR/SSIM trackers.

    Returns:
        (avg_loss, local_psnr, local_ssim)
    """
    # NOTE(review): training=True makes ForwardManager always call the plain
    # model, bypassing chop/self-ensemble even at evaluation -- confirm that
    # is intended (testModel passes False here).
    forward = ForwardManager(model, True, args)
    local_psnr = PSNR(data_range=data_range, device=device)
    local_ssim = SSIM(data_range=data_range, device=device)
    # Fixed: this is a pure evaluation pass (everything runs under no_grad,
    # no backward/step), so the model must be in eval mode. The original
    # called model.train(), which enables dropout and keeps updating
    # BatchNorm running statistics, skewing both loss and metrics.
    model.eval()
    total_loss = 0.0
    counter = 0
    progress = tqdm(loader)
    for input, target in progress:
        with torch.no_grad():
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()
            model_out = forward.forward(input)
        if "yuv" in args.loader:
            # rescale 8-bit values into the 10-bit YUV range
            target = target * (1023/255)
            model_out = model_out * (1023/255)
        # quantize the out for SSIM and PSNR calc
        # NOTE(review): the data range is hard-coded to 255 here while the
        # rest of the file uses args.data_range -- confirm which is intended.
        q_out = quantize(model_out, 255)
        for b in range(q_out.shape[0]):
            local_psnr.update((q_out[b:b+1], target[b:b+1]))
            local_ssim.update((q_out[b:b+1], target[b:b+1]))
        loss = loss_function(model_out, target)
        total_loss += loss.item()
        counter += 1
        progress.set_postfix({
            'avg_loss': total_loss / counter,
            'loss_iteration': loss.item(),
            'psnr_iteration': float(local_psnr.compute()),
            'ssim_iteration': float(local_ssim.compute())
        })
    return total_loss / counter, local_psnr, local_ssim
def testModel(loader, model, psnr, ssim, args, loss_function, device: torch.device, data_range=1.0, metricTable:pd.DataFrame=None, out= None):
    """Evaluate `model` on a test loader with x8 self-ensemble forced on,
    collecting per-frame loss/PSNR/SSIM rows into a DataFrame.

    Args:
        loader: DataLoader yielding (input, target, info), where info carries
            at least 'sequence_name' and 'frame'.
        psnr, ssim: optional aggregate ignite metrics updated per batch.
        metricTable: optional existing per-frame table to prepend rows to.
        out: optional directory path; when set, predictions are written
            there as PNG files named '<sequence>_<frame>.png'.

    Returns:
        (avg_loss, metricTable)
    """
    forward = ForwardManager(model, False, args)
    # Fixed: `bool(DataFrame)` raises ValueError ("truth value ... is
    # ambiguous"), so the original `if not metricTable:` crashed whenever a
    # table was actually passed in. Compare against None instead.
    if metricTable is None:
        metricTable = pd.DataFrame(columns=['sequence','frame','loss','psnr','ssim'])
    local_ssim = SSIM(data_range=data_range, output_transform=get_luminance, device=device)
    custom_psnr = []
    model.eval()
    forward.self_ensemble = True  # force x8 self-ensemble for testing
    total_loss = 0.0
    counter = 0
    progress = tqdm(loader)
    for input, target, info in progress:
        with torch.no_grad():
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()
            model_out = forward.forward(input)
        if "yuv" in args.loader:
            # collapse channels to their mean and rescale to the 10-bit range
            target = target.mean(dim=1).unsqueeze(0)
            model_out = model_out.mean(dim=1).unsqueeze(0) * (1023/255)
            q_out = quantize(model_out, args.data_range)
            psnr_score = calc_psnr(q_out, target, args.scale, args.data_range, skip_luminance=True)
        else:
            q_out = quantize(model_out, args.data_range)
            psnr_score = calc_psnr(q_out, target, args.scale, args.data_range)
        custom_psnr.append(psnr_score)
        local_ssim.reset()
        local_ssim.update((q_out, target))
        if psnr: psnr.update((q_out, target))
        if ssim: ssim.update((q_out, target))
        loss = loss_function(model_out, target)
        new_row = pd.DataFrame({
            'sequence': info['sequence_name'],
            'frame': info['frame'],
            'loss': loss.item(),
            'psnr': float(psnr_score),
            'ssim': float(local_ssim.compute())},
            index=[0]
        )
        # newest result is prepended to the table
        metricTable = pd.concat([new_row, metricTable.loc[:]]).reset_index(drop=True)
        if out:
            (B, C, H, W) = q_out.shape
            batch_np = q_out.cpu().numpy()  # already clamped between 0 and 255
            for b in range(B):
                # Fixed: the original reassigned the batch array itself on the
                # first iteration, so any b > 0 indexed an already-transposed
                # single image instead of the batch.
                img = batch_np[b].transpose((1, 2, 0))
                if not loader.dataset.useBGR:
                    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # CV save in BGR
                s = info['sequence_name']
                n = info['frame']
                # NOTE(review): with batch size > 1 every image in the batch
                # gets the same file name -- confirm test loaders use batch 1.
                cv2.imwrite(f'{out}/{s}_{n}.png', img)
        total_loss += loss.item()
        counter += 1
        progress.set_postfix({
            'loss': total_loss / counter,
            'psnr_iteration': sum(custom_psnr)/len(custom_psnr),
            'ssim_iteration': float(local_ssim.compute())
        })
    return total_loss / counter, metricTable
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment