"""DataLoader factory helpers for RGB (DIV2K/benchmark) and raw YUV datasets."""
from torchvision import transforms as transforms
from torch.utils.data import DataLoader
from dataset.RGBDataset import Div2KTrain, Banchmark, CropDataset 
from dataset.YUVDataset import *
from utils.transforms import DescreteRandomRotation
from torch.utils.data import random_split
import os
import torch
import numpy as np



def _subsample(dataset, fraction, seed):
    """Return a random, seeded subset keeping `fraction` of `dataset`.

    The complement size is computed by subtraction (not a second round())
    so the two split lengths always sum to len(dataset).
    """
    kept_num = round(len(dataset) * fraction)
    subset, _ = random_split(
        dataset,
        [kept_num, len(dataset) - kept_num],
        generator=torch.Generator().manual_seed(seed),
    )
    return subset


def RGBDataset(args, useBGR=False, test_only=False):
    """Build DataLoaders for the RGB super-resolution datasets.

    Args:
        args: parsed CLI namespace; fields used: scale, test_mode, seed,
            batch, crop, noise.
        useBGR: forward channels in BGR order to the underlying datasets.
        test_only: when True, skip the train/val split and return only the
            benchmark test loader.

    Returns:
        (train_loader, val_loader, test_loader) when test_only is False,
        otherwise test_loader alone.
    """
    # Image data augmentation, applied to training crops only.
    transform = transforms.Compose([
        DescreteRandomRotation(angles=[0, 90, 180, -90]),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
    ])

    # The two dataset classes differ mainly in folder structure / loading.
    dataset = Div2KTrain(scaling=args.scale, useBGR=useBGR)
    dataset_test = Banchmark(scaling=args.scale, useBGR=useBGR)

    if args.test_mode:
        # Shrink both datasets for quick smoke-test runs.
        dataset = _subsample(dataset, 0.2, args.seed)
        dataset_test = _subsample(dataset_test, 0.1, args.seed)

    # Test loader evaluates the model after selection (batch of 1 image).
    test_loader = DataLoader(dataset_test, shuffle=False, batch_size=1,
                             num_workers=0, pin_memory=True,
                             worker_init_fn=lambda worker_id: np.random.seed(worker_id))

    if test_only:
        # BUG FIX: the original returned (train_loader, val_loader,
        # test_loader) here although neither train_loader nor val_loader
        # was ever defined on this path, raising NameError. Return only
        # the test loader, matching YUVDataset's test_only contract.
        return test_loader

    # Split the dataset into train and validation (deterministic via seed).
    train_percentual = 0.8
    train_num = round(len(dataset) * train_percentual)
    val_num = len(dataset) - train_num  # subtraction: lengths always sum correctly
    train_data, val_data = random_split(
        dataset, [train_num, val_num],
        generator=torch.Generator().manual_seed(args.seed))
    # Random crops + augmentation for training; deterministic center crops
    # (no transform, no noise) for validation.
    train_data = CropDataset(train_data, scaling=args.scale, transform=transform,
                             seed=args.seed, crop=args.crop, cropMode="random",
                             noiseInensity=args.noise)
    val_data = CropDataset(val_data, scaling=args.scale, transform=None,
                           seed=None, crop=args.crop, cropMode="center")

    # Loader for the training set.
    train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch,
                              num_workers=0, pin_memory=True,
                              worker_init_fn=lambda worker_id: np.random.seed(worker_id))
    # Loader for the validation set (used to select the model after prune).
    val_loader = DataLoader(val_data, shuffle=False, batch_size=args.batch,
                            num_workers=0, pin_memory=True,
                            worker_init_fn=lambda worker_id: np.random.seed(worker_id))

    return (train_loader, val_loader, test_loader)
    
def YUVDataset(args, test_only=False):
    """Build DataLoaders for the raw YUV datasets.

    Args:
        args: parsed CLI namespace; fields used: scale, yuv_size, yuv_format,
            only_y_channel, filter, low_res_data, prefix_lowres,
            high_res_data, prefix_highres, seed, batch, yuv_testset.
        test_only: when True, skip the train/val split and return only the
            test loader.

    Returns:
        (train_loader, val_loader, test_loader) when test_only is False,
        otherwise test_loader alone.

    Raises:
        Exception: when args.yuv_testset names an unknown test set.
    """
    import copy  # stdlib; needed only for the validation-split fix below

    # Image data augmentation, applied during (re)training only.
    transform = transforms.Compose([
        DescreteRandomRotation(angles=[0, 90, 180, -90]),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
    ])

    if not test_only:
        dataset = DatasetFromFolderYUV(
                scale=args.scale,
                size=args.yuv_size,
                yuvFormat=args.yuv_format,
                y_only=args.only_y_channel,
                filter_only=args.filter,
                low_res_folder = os.path.normpath(args.low_res_data),
                prefix_low_res = args.prefix_lowres,
                high_res_folder = os.path.normpath(args.high_res_data),
                prefix_high_res = args.prefix_highres,
                transform=transform,
                seed=args.seed,
            )

    # Dispatch table (lower-cased keys) replaces the if/elif chain.
    testsets = {
        'evcintra': TestSequenciesEVCIntra,
        'evcsdhd': TestSequenciesEVCSDHD,
        'evchd4k': TestSequenciesEVCHD4K,
        'vvcsdhd': TestSequenciesVVCSDHD,
        'vvchd4k': TestSequenciesVVCHD4K,
    }
    try:
        dataset_test = testsets[args.yuv_testset.lower()](yOnly=args.only_y_channel)
    except KeyError:
        raise Exception(f"Unsupported Raw Testset:{args.yuv_testset}. see command info for details.")

    test_loader = DataLoader(dataset_test, shuffle=False, batch_size=1,
                             num_workers=1, pin_memory=True,
                             worker_init_fn=lambda worker_id: np.random.seed(worker_id))
    if test_only:
        return test_loader

    # Split the dataset into train and validation (deterministic via seed).
    train_percentual = 0.8
    train_num = round(len(dataset) * train_percentual)
    val_num = len(dataset) - train_num  # subtraction: lengths always sum correctly
    train_data, val_data = random_split(
        dataset, [train_num, val_num],
        generator=torch.Generator().manual_seed(args.seed))

    # BUG FIX: the original set val_data.dataset.transform = None, but both
    # random_split subsets wrap the SAME dataset object, so that also
    # silently disabled augmentation for training. Give the validation
    # subset its own shallow copy with transforms off instead.
    # NOTE(review): assumes DatasetFromFolderYUV tolerates copy.copy — the
    # copy shares sample storage and only the `transform` attribute differs.
    val_dataset = copy.copy(val_data.dataset)
    val_dataset.transform = None
    val_data.dataset = val_dataset

    # Loader for the training set.
    train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch,
                              num_workers=1, pin_memory=True,
                              worker_init_fn=lambda worker_id: np.random.seed(worker_id))
    # Loader for the validation set.
    val_loader = DataLoader(val_data, shuffle=False, batch_size=args.batch,
                            num_workers=1, pin_memory=True,
                            worker_init_fn=lambda worker_id: np.random.seed(worker_id))

    return (train_loader, val_loader, test_loader)