from torchvision import transforms as transforms
from torch.utils.data import DataLoader
from dataset.RGBDataset import Div2KTrain, Banchmark, CropDataset
from dataset.YUVDataset import *
from utils.transforms import DescreteRandomRotation
from torch.utils.data import random_split
import os
import torch
import numpy as np
# Imported after the star import above so nothing it exports can shadow them.
import copy
from torch.utils.data import Subset


def _seed_worker(worker_id):
    """Seed NumPy inside a DataLoader worker process.

    PyTorch seeds torch in each worker with ``base_seed + worker_id`` and draws
    a new ``base_seed`` whenever a loader iterator is created (i.e. each epoch),
    so deriving the NumPy seed from ``torch.initial_seed()`` gives every worker
    and every epoch its own NumPy stream. Seeding with the bare worker id
    replays the same NumPy random sequence every epoch.

    Kept at module level (not a lambda) so it can be pickled when workers are
    started with the ``spawn`` method.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def _split_counts(total, first_fraction):
    """Return ``[first, total - first]`` with ``first = round(total * first_fraction)``.

    The second count is obtained by subtraction so the two always add up to
    ``total`` as ``random_split`` requires; rounding both parts independently
    can miss it (e.g. 5 items at 0.1 -> round(0.5) + round(4.5) == 0 + 4).
    """
    first = round(total * first_fraction)
    return [first, total - first]


def _seeded_split(dataset, first_fraction, seed):
    """Randomly split ``dataset`` in two, reproducibly for a given ``seed``.

    The first part holds ~``first_fraction`` of the items, the second the rest.
    """
    return random_split(dataset,
                        _split_counts(len(dataset), first_fraction),
                        generator=torch.Generator().manual_seed(seed))


def _augmentation():
    """Build the random rotation / flip augmentation used for training samples."""
    return transforms.Compose([
        DescreteRandomRotation(angles=[0, 90, 180, -90]),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
    ])


def _make_loader(data, *, shuffle, batch_size, num_workers):
    """Wrap ``data`` in a pinned-memory DataLoader with per-worker NumPy seeding."""
    return DataLoader(data, shuffle=shuffle, batch_size=batch_size,
                      num_workers=num_workers, pin_memory=True,
                      worker_init_fn=_seed_worker)


def RGBDataset(args, useBGR=False, test_only=False):
    """Create the DataLoaders for the RGB (Div2K / benchmark) datasets.

    Args:
        args: parsed options; reads ``scale``, ``test_mode``, ``seed``, ``crop``,
            ``noise`` and ``batch``.
        useBGR: forwarded to the ``Div2KTrain`` and ``Banchmark`` datasets.
        test_only: when True, only the test loader is built.

    Returns:
        ``(train_loader, val_loader, test_loader)``. With ``test_only=True`` the
        first two entries are ``None`` (previously this path raised because
        they were never created).
    """
    # Select the data loader for the datasets (the biggest difference is the
    # folder structure and the loading process).
    dataset = Div2KTrain(scaling=args.scale, useBGR=useBGR)
    dataset_test = Banchmark(scaling=args.scale, useBGR=useBGR)

    if args.test_mode:
        # Reduced run: keep a seeded 20% of the training data and 10% of the test data.
        dataset, _ = _seeded_split(dataset, 0.2, args.seed)
        dataset_test, _ = _seeded_split(dataset_test, 0.1, args.seed)

    if test_only:
        test_loader = _make_loader(dataset_test, shuffle=False, batch_size=1, num_workers=0)
        return (None, None, test_loader)

    # Split the dataset 80/20 into train and validation.
    train_data, val_data = _seeded_split(dataset, 0.8, args.seed)
    # Training: random crops, augmentation and noise; validation: deterministic center crops.
    train_data = CropDataset(train_data, scaling=args.scale, transform=_augmentation(),
                             seed=args.seed, crop=args.crop, cropMode="random",
                             noiseInensity=args.noise)
    val_data = CropDataset(val_data, scaling=args.scale, transform=None, seed=None,
                           crop=args.crop, cropMode="center")

    # Loader for the training set.
    train_loader = _make_loader(train_data, shuffle=True, batch_size=args.batch, num_workers=0)
    # Loader for the validation set (to select the model after prune).
    val_loader = _make_loader(val_data, shuffle=False, batch_size=args.batch, num_workers=0)
    # Loader for the test set (to evaluate the model performance after selection).
    test_loader = _make_loader(dataset_test, shuffle=False, batch_size=1, num_workers=0)
    return (train_loader, val_loader, test_loader)


def YUVDataset(args, test_only=False):
    """Create the DataLoaders for the raw YUV datasets.

    Args:
        args: parsed options; reads ``scale``, ``yuv_size``, ``yuv_format``,
            ``only_y_channel``, ``filter``, ``low_res_data``, ``prefix_lowres``,
            ``high_res_data``, ``prefix_highres``, ``seed``, ``yuv_testset``
            and ``batch``.
        test_only: when True, only the test loader is built and returned.

    Returns:
        ``(train_loader, val_loader, test_loader)``, or just ``test_loader``
        when ``test_only`` is True.

    Raises:
        ValueError: if ``args.yuv_testset`` names an unsupported test set.
    """
    if not test_only:
        dataset = DatasetFromFolderYUV(
            scale=args.scale,
            size=args.yuv_size,
            yuvFormat=args.yuv_format,
            y_only=args.only_y_channel,
            filter_only=args.filter,
            low_res_folder=os.path.normpath(args.low_res_data),
            prefix_low_res=args.prefix_lowres,
            high_res_folder=os.path.normpath(args.high_res_data),
            prefix_high_res=args.prefix_highres,
            transform=_augmentation(),
            seed=args.seed,
        )

    testset = args.yuv_testset.lower()
    if testset == 'evcintra':
        dataset_test = TestSequenciesEVCIntra(yOnly=args.only_y_channel)
    elif testset == 'evcsdhd':
        dataset_test = TestSequenciesEVCSDHD(yOnly=args.only_y_channel)
    elif testset == 'evchd4k':
        dataset_test = TestSequenciesEVCHD4K(yOnly=args.only_y_channel)
    elif testset == 'vvcsdhd':
        dataset_test = TestSequenciesVVCSDHD(yOnly=args.only_y_channel)
    elif testset == 'vvchd4k':
        dataset_test = TestSequenciesVVCHD4K(yOnly=args.only_y_channel)
    else:
        raise ValueError(f"Unsupported Raw Testset:{args.yuv_testset}. see comand info for detrails.")

    test_loader = _make_loader(dataset_test, shuffle=False, batch_size=1, num_workers=1)
    if test_only:
        return test_loader

    # Split the dataset 80/20 into train and validation.
    train_data, val_subset = _seeded_split(dataset, 0.8, args.seed)
    # Both Subsets returned by random_split share one underlying dataset, so
    # setting `.dataset.transform = None` on the validation Subset would also
    # switch off augmentation for training. Validation gets its own shallow
    # copy with the transform removed, indexed by the same validation indices.
    # NOTE(review): assumes DatasetFromFolderYUV reads `self.transform` per
    # sample and is safe to shallow-copy — confirm.
    val_dataset = copy.copy(dataset)
    val_dataset.transform = None  # Deactivate transform in validation only
    val_data = Subset(val_dataset, val_subset.indices)

    # Loader for the training set.
    train_loader = _make_loader(train_data, shuffle=True, batch_size=args.batch, num_workers=1)
    # Loader for the validation set.
    val_loader = _make_loader(val_data, shuffle=False, batch_size=args.batch, num_workers=1)
    return (train_loader, val_loader, test_loader)