Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
EVC
Ufv1 0
Commits
305950ef
Commit
305950ef
authored
Nov 27, 2025
by
valentini
Browse files
Carica un nuovo file
parent
d684063b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Reference Software/UFV1.0-Pruning/src/dataset/__init__.py
0 → 100644
View file @
305950ef
from
torchvision
import
transforms
as
transforms
from
torch.utils.data
import
DataLoader
from
dataset.RGBDataset
import
Div2KTrain
,
Banchmark
,
CropDataset
from
dataset.YUVDataset
import
*
from
utils.transforms
import
DescreteRandomRotation
from
torch.utils.data
import
random_split
import
os
import
torch
import
numpy
as
np
def
RGBDataset
(
args
,
useBGR
=
False
,
test_only
=
False
):
# Dataset Set Up ###########################################
############ Image data augmentation for retrain ###########
transform
=
transforms
.
Compose
([
DescreteRandomRotation
(
angles
=
[
0
,
90
,
180
,
-
90
]),
transforms
.
RandomVerticalFlip
(),
transforms
.
RandomHorizontalFlip
(),
])
# Select the Data load er for the datasets (biggest difference is the folder structure and the loading process)
dataset
=
Div2KTrain
(
scaling
=
args
.
scale
,
useBGR
=
useBGR
)
dataset_test
=
Banchmark
(
scaling
=
args
.
scale
,
useBGR
=
useBGR
)
if
args
.
test_mode
:
use_only
=
0.2
kept_num
=
round
(
len
(
dataset
)
*
use_only
)
igonored_num
=
round
(
len
(
dataset
)
*
(
1
-
use_only
))
dataset
,
_
=
random_split
(
dataset
,
[
kept_num
,
igonored_num
],
generator
=
torch
.
Generator
().
manual_seed
(
args
.
seed
))
use_only
=
0.1
kept_num
=
round
(
len
(
dataset_test
)
*
use_only
)
igonored_num
=
round
(
len
(
dataset_test
)
*
(
1
-
use_only
))
dataset_test
,
_
=
random_split
(
dataset_test
,
[
kept_num
,
igonored_num
],
generator
=
torch
.
Generator
().
manual_seed
(
args
.
seed
))
if
not
test_only
:
#Split the dataset in train and validation
train_percentual
=
0.8
train_num
=
round
(
len
(
dataset
)
*
train_percentual
)
val_num
=
round
(
len
(
dataset
)
*
(
1
-
train_percentual
))
train_data
,
val_data
=
random_split
(
dataset
,
[
train_num
,
val_num
],
generator
=
torch
.
Generator
().
manual_seed
(
args
.
seed
))
train_data
=
CropDataset
(
train_data
,
scaling
=
args
.
scale
,
transform
=
transform
,
seed
=
args
.
seed
,
crop
=
args
.
crop
,
cropMode
=
"random"
,
noiseInensity
=
args
.
noise
)
val_data
=
CropDataset
(
val_data
,
scaling
=
args
.
scale
,
transform
=
None
,
seed
=
None
,
crop
=
args
.
crop
,
cropMode
=
"center"
)
# create the loader for the training set
train_loader
=
DataLoader
(
train_data
,
shuffle
=
True
,
batch_size
=
args
.
batch
,
num_workers
=
0
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
# create the loader for the validation set (to select the model after prune)
val_loader
=
DataLoader
(
val_data
,
shuffle
=
False
,
batch_size
=
args
.
batch
,
num_workers
=
0
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
# create the loader for the test set (to evaluate the model performaces after selection)
test_loader
=
DataLoader
(
dataset_test
,
shuffle
=
False
,
batch_size
=
1
,
num_workers
=
0
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
return
(
train_loader
,
val_loader
,
test_loader
)
else
:
test_loader
=
DataLoader
(
dataset_test
,
shuffle
=
False
,
batch_size
=
1
,
num_workers
=
0
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
return
(
train_loader
,
val_loader
,
test_loader
)
def
YUVDataset
(
args
,
test_only
=
False
):
# Dataset Set Up ###########################################
############ Image data augmentation for retrain ###########
transform
=
transforms
.
Compose
([
DescreteRandomRotation
(
angles
=
[
0
,
90
,
180
,
-
90
]),
transforms
.
RandomVerticalFlip
(),
transforms
.
RandomHorizontalFlip
(),
])
if
not
test_only
:
dataset
=
DatasetFromFolderYUV
(
scale
=
args
.
scale
,
size
=
args
.
yuv_size
,
yuvFormat
=
args
.
yuv_format
,
y_only
=
args
.
only_y_channel
,
filter_only
=
args
.
filter
,
low_res_folder
=
os
.
path
.
normpath
(
args
.
low_res_data
),
prefix_low_res
=
args
.
prefix_lowres
,
high_res_folder
=
os
.
path
.
normpath
(
args
.
high_res_data
),
prefix_high_res
=
args
.
prefix_highres
,
transform
=
transform
,
seed
=
args
.
seed
,
)
if
(
args
.
yuv_testset
.
lower
()
==
'evcintra'
):
dataset_test
=
TestSequenciesEVCIntra
(
yOnly
=
args
.
only_y_channel
)
elif
(
args
.
yuv_testset
.
lower
()
==
'evcsdhd'
):
dataset_test
=
TestSequenciesEVCSDHD
(
yOnly
=
args
.
only_y_channel
)
elif
(
args
.
yuv_testset
.
lower
()
==
'evchd4k'
):
dataset_test
=
TestSequenciesEVCHD4K
(
yOnly
=
args
.
only_y_channel
)
elif
(
args
.
yuv_testset
.
lower
()
==
'vvcsdhd'
):
dataset_test
=
TestSequenciesVVCSDHD
(
yOnly
=
args
.
only_y_channel
)
elif
(
args
.
yuv_testset
.
lower
()
==
'vvchd4k'
):
dataset_test
=
TestSequenciesVVCHD4K
(
yOnly
=
args
.
only_y_channel
)
else
:
raise
Exception
(
f
"Unsupported Raw Testset:
{
args
.
yuv_testset
}
. see comand info for detrails."
)
if
not
test_only
:
#Split the dataset in train and validation
train_percentual
=
0.8
train_num
=
round
(
len
(
dataset
)
*
train_percentual
)
val_num
=
round
(
len
(
dataset
)
*
(
1
-
train_percentual
))
train_data
,
val_data
=
random_split
(
dataset
,
[
train_num
,
val_num
],
generator
=
torch
.
Generator
().
manual_seed
(
args
.
seed
))
val_data
.
dataset
.
transform
=
None
# Deactivate Transform in validation
#create the loader for the training set
train_loader
=
DataLoader
(
train_data
,
shuffle
=
True
,
batch_size
=
args
.
batch
,
num_workers
=
1
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
#create the loader for the validation set
val_loader
=
DataLoader
(
val_data
,
shuffle
=
False
,
batch_size
=
args
.
batch
,
num_workers
=
1
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
test_loader
=
DataLoader
(
dataset_test
,
shuffle
=
False
,
batch_size
=
1
,
num_workers
=
1
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
return
(
train_loader
,
val_loader
,
test_loader
)
else
:
test_loader
=
DataLoader
(
dataset_test
,
shuffle
=
False
,
batch_size
=
1
,
num_workers
=
1
,
pin_memory
=
True
,
worker_init_fn
=
lambda
id
:
np
.
random
.
seed
(
id
))
return
test_loader
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment