Commit 305950ef authored by valentini

Upload a new file

parent d684063b
from torchvision import transforms
from torch.utils.data import DataLoader
from dataset.RGBDataset import Div2KTrain, Banchmark, CropDataset
from dataset.YUVDataset import *
from utils.transforms import DescreteRandomRotation
from torch.utils.data import random_split
import os
import torch
import numpy as np
def RGBDataset(args, useBGR=False, test_only=False):
    # Dataset setup ############################################
    # Image data augmentation used during retraining ###########
    transform = transforms.Compose([
        DescreteRandomRotation(angles=[0, 90, 180, -90]),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
    ])
    # Select the data loaders for the datasets (the main difference is the folder structure and the loading process)
    dataset = Div2KTrain(scaling=args.scale, useBGR=useBGR)
    dataset_test = Banchmark(scaling=args.scale, useBGR=useBGR)
    if args.test_mode:
        # In test mode keep only a fraction of each dataset to speed up the run
        use_only = 0.2
        kept_num = round(len(dataset) * use_only)
        ignored_num = round(len(dataset) * (1 - use_only))
        dataset, _ = random_split(dataset, [kept_num, ignored_num], generator=torch.Generator().manual_seed(args.seed))
        use_only = 0.1
        kept_num = round(len(dataset_test) * use_only)
        ignored_num = round(len(dataset_test) * (1 - use_only))
        dataset_test, _ = random_split(dataset_test, [kept_num, ignored_num], generator=torch.Generator().manual_seed(args.seed))
    if not test_only:
        # Split the dataset into train and validation
        train_fraction = 0.8
        train_num = round(len(dataset) * train_fraction)
        val_num = round(len(dataset) * (1 - train_fraction))
        train_data, val_data = random_split(dataset, [train_num, val_num], generator=torch.Generator().manual_seed(args.seed))
        # Random crops (with augmentation and noise) for training, center crops for validation
        train_data = CropDataset(train_data, scaling=args.scale, transform=transform, seed=args.seed, crop=args.crop, cropMode="random", noiseInensity=args.noise)
        val_data = CropDataset(val_data, scaling=args.scale, transform=None, seed=None, crop=args.crop, cropMode="center")
        # Create the loader for the training set
        train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch, num_workers=0, pin_memory=True, worker_init_fn=lambda id: np.random.seed(id))
        # Create the loader for the validation set (to select the model after pruning)
        val_loader = DataLoader(val_data, shuffle=False, batch_size=args.batch, num_workers=0, pin_memory=True, worker_init_fn=lambda id: np.random.seed(id))
        # Create the loader for the test set (to evaluate the model performance after selection)
        test_loader = DataLoader(dataset_test, shuffle=False, batch_size=1, num_workers=0, pin_memory=True, worker_init_fn=lambda id: np.random.seed(id))
        return (train_loader, val_loader, test_loader)
    else:
        # Test-only path: no training/validation loaders are built
        test_loader = DataLoader(dataset_test, shuffle=False, batch_size=1, num_workers=0, pin_memory=True, worker_init_fn=lambda id: np.random.seed(id))
        return test_loader
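

# Usage sketch (illustrative only): RGBDataset reads an argparse-style namespace.
# The attribute names below (scale, seed, batch, crop, noise, test_mode) mirror the
# fields accessed above; the values are assumptions for demonstration, not project
# defaults.
#
#   from argparse import Namespace
#   args = Namespace(scale=2, seed=42, batch=16, crop=48, noise=0.0, test_mode=False)
#   train_loader, val_loader, test_loader = RGBDataset(args, useBGR=False, test_only=False)
#   for batch in train_loader:
#       ...  # training step

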
def YUVDataset(args, test_only=False):
    # Dataset setup ############################################
    # Image data augmentation used during retraining ###########
    transform = transforms.Compose([
        DescreteRandomRotation(angles=[0, 90, 180, -90]),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
    ])
    if not test_only:
        dataset = DatasetFromFolderYUV(
            scale=args.scale,
            size=args.yuv_size,
            yuvFormat=args.yuv_format,
            y_only=args.only_y_channel,
            filter_only=args.filter,
            low_res_folder=os.path.normpath(args.low_res_data),
            prefix_low_res=args.prefix_lowres,
            high_res_folder=os.path.normpath(args.high_res_data),
            prefix_high_res=args.prefix_highres,
            transform=transform,
            seed=args.seed,
        )
    # Select the raw (YUV) test set
    if args.yuv_testset.lower() == 'evcintra':
        dataset_test = TestSequenciesEVCIntra(yOnly=args.only_y_channel)
    elif args.yuv_testset.lower() == 'evcsdhd':
        dataset_test = TestSequenciesEVCSDHD(yOnly=args.only_y_channel)
    elif args.yuv_testset.lower() == 'evchd4k':
        dataset_test = TestSequenciesEVCHD4K(yOnly=args.only_y_channel)
    elif args.yuv_testset.lower() == 'vvcsdhd':
        dataset_test = TestSequenciesVVCSDHD(yOnly=args.only_y_channel)
    elif args.yuv_testset.lower() == 'vvchd4k':
        dataset_test = TestSequenciesVVCHD4K(yOnly=args.only_y_channel)
    else:
        raise Exception(f"Unsupported raw testset: {args.yuv_testset}. See the command help for details.")
if not test_only:
#Split the dataset in train and validation
train_percentual = 0.8
train_num = round(len(dataset) * train_percentual)
val_num = round(len(dataset) * (1-train_percentual))
train_data, val_data = random_split(dataset, [train_num, val_num], generator=torch.Generator().manual_seed(args.seed))
val_data.dataset.transform = None # Deactivate Transform in validation
#create the loader for the training set
train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch, num_workers=1, pin_memory=True, worker_init_fn = lambda id: np.random.seed(id))
#create the loader for the validation set
val_loader = DataLoader(val_data, shuffle=False, batch_size=args.batch, num_workers=1, pin_memory=True, worker_init_fn = lambda id: np.random.seed(id))
test_loader = DataLoader(dataset_test, shuffle=False, batch_size=1, num_workers=1, pin_memory=True, worker_init_fn = lambda id: np.random.seed(id))
return(train_loader, val_loader, test_loader)
else:
test_loader = DataLoader(dataset_test, shuffle=False, batch_size=1, num_workers=1, pin_memory=True, worker_init_fn = lambda id: np.random.seed(id))
return test_loader
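

# Minimal smoke test (a sketch, not part of the training pipeline): building only the
# test loader requires just the attributes read on the test_only path. The values of
# demo_args below are illustrative assumptions, not project defaults.
if __name__ == "__main__":
    from argparse import Namespace
    demo_args = Namespace(yuv_testset="evcintra", only_y_channel=True)
    loader = YUVDataset(demo_args, test_only=True)
    print(f"test batches: {len(loader)}")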