import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
import numpy as np
import random
import wandb  # Required by the wandb.log / wandb.Table calls below
from tqdm import tqdm
from ignite.metrics import PSNR
from ignite.metrics import SSIM

# Model Variants
from model.dummy_sr_model import DummySRModel
from utils.ckpt import load_checkpoint
from utils.util import get_luminance as get_luminance_from_rgb
from dataset import RGBDataset, YUVDataset
from utils.trainer import trainFor, testModel, sparsityLearning
from utils.loss import CharbonnierLoss
from torch.nn import MSELoss
from arguments import get_arguments
from pruners import pruner_constructor


def generateRunName(args):
    pruning_steps = args.pruning_steps
    pruning_target_ratio = args.pruning_target_ratio  # Not used in the name, kept for reference
    # Construct the run name
    name = f"Pruning(x{args.scale}({pruning_steps}))-RetrainedOn(e{args.epochs}-{args.crop}x{args.crop})"
    return name


def loadModel(args, checkpoints_path):
    # Create the model
    model = DummySRModel(args)
    # Load pretrained weights
    if args.weigths:
        model_dict, epoch, mse = load_checkpoint(args.weigths)
        if model_dict is None:
            raise Exception("The ckpt does not have the model state_dict!")
        model.load_state_dict(model_dict['model'])
        # Save a copy of the original (unpruned) model
        ckpt_path = os.path.join(checkpoints_path, 'unpruned_model.pth')
        if not os.path.exists(ckpt_path):
            torch.save({
                'model': model_dict['model'],
            }, ckpt_path)
    return model


if __name__ == '__main__':
    args = get_arguments()
    print(args)

    # Set up random seeds for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    ### Prepare output dirs
    params = vars(args)
    params['dataset'] = os.path.basename(os.path.normpath(args.loader))
    run_name = generateRunName(args)
    run_dir = os.path.join(args.runs, run_name)
    print("Output folder root:", run_dir)
    log_metrics_path = os.path.join(run_dir, 'pruning_results.csv')
    original_test_path = os.path.join(run_dir, 'original')
    checkpoints_path = os.path.join(run_dir, 'checkpoints')
    param_file = os.path.join(run_dir, 'params.csv')
    os.makedirs(original_test_path, exist_ok=True)
    os.makedirs(checkpoints_path, exist_ok=True)
    pd.DataFrame(params, index=[0]).to_csv(param_file, index=False)

    # Initialize experiment tracking (wandb.log is used below);
    # the project can be set through the WANDB_PROJECT environment variable
    wandb.init(name=run_name, config=params)

    # Load model
    model = loadModel(args, checkpoints_path)

    # Get dataset
    if args.loader.lower() == 'div2k_rgb':
        train_loader, val_loader, test_loader = RGBDataset(args)
        loss_function = CharbonnierLoss(args.loss_epsylon)
        get_luminance = get_luminance_from_rgb
    elif args.loader.lower() == 'custom_yuv':
        train_loader, val_loader, test_loader = YUVDataset(args)
        loss_function = MSELoss()
        # YUV data already carries luminance, so the metrics use the output as-is
        get_luminance = lambda output: output
    else:
        raise Exception("Unsupported dataset")

    # Define device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Define a test input to evaluate the model complexity on an SD-sized frame
    example_input_sd = torch.randn(1, 3, 964, 540).to(device)

    # Needed by some pruners
    pruner = pruner_constructor(args, model, torch.randn(1, 3, 10, 10), device)

    log_metrics = pd.DataFrame()
    base_macs_sd, base_nparams = tp.utils.count_ops_and_params(model, example_input_sd)

    # Evaluate the original model on the test set
    psnr_pretrain = PSNR(data_range=args.data_range, output_transform=get_luminance, device=device)
    ssim_pretrain = SSIM(data_range=args.data_range, output_transform=get_luminance, device=device)
    test_loss, test_table = testModel(
        loader=test_loader,
        model=model,
        args=args,
        psnr=psnr_pretrain,
        ssim=ssim_pretrain,
        data_range=args.data_range,
        device=device,
        loss_function=loss_function
    )
    prune_iter_metrics = {}
    prune_iter_metrics["pruning_step"] = 0
    prune_iter_metrics["pruning_ratio"] = 0
prune_iter_metrics["parameters_(M)"] = base_nparams / 1e6 prune_iter_metrics["inference_SD_HD_flops(G)"] = base_macs_sd / 1e9 prune_iter_metrics["mse"] = test_loss prune_iter_metrics['ssim'] = float(ssim_pretrain.compute()) prune_iter_metrics['psnr'] = float(psnr_pretrain.compute()) test_table.to_csv(os.path.join(original_test_path, 'original_test.csv'), index=False) test_table.groupby('sequence').mean().reset_index().to_csv(os.path.join(original_test_path, 'original_test_by_sequence.csv'), index=False) log_metrics = log_metrics.append(prune_iter_metrics, ignore_index=True) log_metrics.to_csv(log_metrics_path, index=False) # Save depenency graph visualization tp.utils.draw_dependency_graph(pruner.DG, save_as=os.path.join(original_test_path, 'draw_dep_graph.png'), title=None) tp.utils.draw_groups(pruner.DG, save_as=os.path.join(original_test_path, 'draw_groups.png'), title=None) tp.utils.draw_computational_graph(pruner.DG, save_as=os.path.join(original_test_path, 'draw_comp_graph.png'), title=None) # Save original model structure macs_sd, nparams = tp.utils.count_ops_and_params(model, example_input_sd) with open(os.path.join(original_test_path, 'model_details.txt'), 'w') as model_details_file: model_details_file.write(f"{model}\n") wandb.log({"model_description": f"{model}"}) model_details_file.write( " Iter %d/%d, Params: %.2f M => %.2f M\n" % (0, args.pruning_steps, base_nparams / 1e6, nparams / 1e6) ) model_details_file.write( " Iter %d/%d, MACs SD_Input: %.2f G => %.2f G\n" % (0, args.pruning_steps, base_macs_sd / 1e9, macs_sd / 1e9) ) #################################### # Pruning Cycles ################### #################################### training_iter = 0 for i in tqdm(range(1, args.pruning_steps + 1)): step_path = os.path.join(run_dir, 'pruning_iter_{}'.format(i)) os.makedirs(step_path, exist_ok=True) # Learning Sparsity (Some pruning techniques require a treaning step to learn the sparsity) if args.pruning_method == "growing_reg": sparsityLearning( model=model, pruner=pruner, loader=train_loader, args=args, loss_function = loss_function ) # Pruning Step pruner.step() macs_sd, nparams = tp.utils.count_ops_and_params(model, example_input_sd) with open(os.path.join(step_path, 'model_details.txt'), 'w') as model_details_file: model_details_file.write(f"{model}\n") wandb.log({"model_description": f"{model}"}) model_details_file.write( " Iter %d/%d, Params: %.2f M => %.2f M\n" % (i, args.pruning_steps, base_nparams / 1e6, nparams / 1e6) ) model_details_file.write( " Iter %d/%d, MACs SD_Input: %.2f G => %.2f G\n" % (i, args.pruning_steps, base_macs_sd / 1e9, macs_sd / 1e9) ) # Model finetuing to recover the loast performacies best_model_current_pruning = model # If noting better is found the initial model is the betst best_mse_current_pruning = None best_optimizer_current_pruning = None if args.epochs and args.epochs > 0: print("Retraining for recovery!") (best_model, best_mse, last_epoch_model, last_epoch_mse, optimizer, run_logs) = trainFor( model=model, train_dataloader=train_loader, val_dataloader=val_loader, device=device, args=args, epochs=args.epochs, run_folder=step_path, pruner=pruner, loss_function=loss_function ) model.load_state_dict(best_model.state_dict()) # Load weights of the best model!! 
wandb.log({"train_logs": wandb.Table(dataframe=run_logs)}) run_logs.to_csv(os.path.join(step_path, f'traning_log.csv'),index=False) best_model_current_pruning = best_model best_mse_current_pruning = best_mse # Test Best Retrained Model And log results prune_iter_metrics = {} prune_iter_metrics["pruning_step"] = i prune_iter_metrics["pruning_rateo"] = 1 - (nparams / base_nparams) prune_iter_metrics["parameters_(M)"] = nparams / 1e6 prune_iter_metrics["inference_SD_HD_flops(G)"] = macs_sd / 1e9 psnr_test = PSNR(data_range=args.data_range, device=device) ssim_test = SSIM(data_range=args.data_range, device=device) test_loss, test_table = testModel( loader=test_loader, model=model, args=args, psnr=psnr_test, ssim=ssim_test, data_range=args.data_range, device=device, loss_function=loss_function ) test_table.to_csv(os.path.join(step_path, f'test.csv'),index=False) test_table.groupby('sequence').mean().reset_index().to_csv(os.path.join(step_path, f'test_by_sequence.csv'), index=False) prune_iter_metrics["mse"]= test_loss prune_iter_metrics['ssim'] = float(ssim_test.compute()) prune_iter_metrics['psnr'] = float(psnr_test.compute()) log_metrics = log_metrics.append(prune_iter_metrics, ignore_index=True) log_metrics.to_csv(log_metrics_path, index=False) # Saving the fine tuned (BEST) model ckpt_path = os.path.join(checkpoints_path, 'pruned_iteraion_{}.pth'.format(i)) if not os.path.exists(ckpt_path): torch.save({ 'metrics': test_loss, 'model': tp.state_dict(best_model_current_pruning), }, ckpt_path)