Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
EVC
Ufv1 0
Commits
fc11b5cf
Commit
fc11b5cf
authored
Nov 23, 2025
by
valentini
Browse files
Carica un nuovo file
parent
7ab20e15
Changes
1
Hide whitespace changes
Inline
Side-by-side
Reference Software/UFV1.0-Training/src/dataloader/yuv.py
0 → 100644
View file @
fc11b5cf
import
os
from
torch.utils.data
import
Dataset
import
yuvio
import
numpy
as
np
import
cv2
import
torch
# Lookup table from an interpolation name (as given in the settings) to the
# corresponding OpenCV resize flag.
INTERPOLATION_MAP = dict(
    nearest=cv2.INTER_NEAREST,
    linear=cv2.INTER_LINEAR,
    cubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)
class SRDataset(Dataset):
    """Dataset of paired low-/high-resolution raw ``.yuv`` frames.

    The ``.yuv`` files found in ``settings.low_sr_dir`` and
    ``settings.high_sr_dir`` are sorted by name and paired by position, so
    both directories must contain the same number of ``.yuv`` files.

    Attributes read from ``settings`` (default in parentheses):
        low_sr_dir, high_sr_dir: directories holding input / target files.
        width, height: frame size handed to ``yuvio.imread`` (required).
        pix_fmt ('yuv420p'): pixel format handed to ``yuvio.imread``.
        scaling (2): factor mapping low-res patch coordinates to high-res.
        model_range (255.0): upper bound of the normalized sample values.
        luminance_only (False): return only the Y plane.
        clone_luminance_as_rgb (False): repeat Y three times along the
            channel axis (only used in luminance mode).
        upsample_uv (False): resize U/V to the Y size when they differ.
        uv_interpolation ('nearest'): key into ``INTERPOLATION_MAP``.
        patch_size (None): ``(h, w)`` of the random low-res crop, or None
            to return whole frames.

    Args:
        settings: object exposing the attributes listed above.
        transform: optional callable applied to both tensors of a sample,
            each time after re-seeding the torch and NumPy global RNGs with
            the same seed.
    """

    def __init__(self, settings, transform=None):
        self.low_sr_dir = settings.low_sr_dir
        self.high_sr_dir = settings.high_sr_dir
        self.width = getattr(settings, 'width')
        self.height = getattr(settings, 'height')
        self.pix_fmt = getattr(settings, 'pix_fmt', 'yuv420p')
        self.scale = getattr(settings, 'scaling', 2)
        # Default to 255.0 if not specified
        self.model_range = getattr(settings, 'model_range', 255.0)
        self.luminance_only = getattr(settings, 'luminance_only', False)
        self.clone_luminance_as_rgb = getattr(
            settings, 'clone_luminance_as_rgb', False
        )
        self.upsample_uv = getattr(settings, 'upsample_uv', False)
        self.uv_interpolation = getattr(settings, 'uv_interpolation', 'nearest')
        # Should be (h, w) or None
        self.patch_size = getattr(settings, 'patch_size', None)
        self.transform = transform

        self.low_sr_images = self._list_yuv_files(self.low_sr_dir)
        self.high_sr_images = self._list_yuv_files(self.high_sr_dir)
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so
        # the exception type seen by existing callers does not change.
        assert len(self.low_sr_images) == len(self.high_sr_images), "Mismatch in number of images."

    @staticmethod
    def _list_yuv_files(directory):
        """Return the sorted names of the ``.yuv`` files in ``directory``."""
        return sorted(
            [f for f in os.listdir(directory) if f.lower().endswith('.yuv')]
        )

    def _stack_yuv(self, yuv):
        """Stack the Y, U and V planes of ``yuv`` into one (H, W, 3) array.

        When the chroma planes differ in shape from Y they are resized to the
        Y size if ``upsample_uv`` is set; an unknown ``uv_interpolation`` name
        falls back to nearest-neighbour.

        Raises:
            ValueError: if the plane shapes differ and ``upsample_uv`` is False.
        """
        y = yuv.y
        u = yuv.u
        v = yuv.v
        if y.shape != u.shape or y.shape != v.shape:
            if not self.upsample_uv:
                raise ValueError(
                    f"U/V channel size {u.shape}/{v.shape} does not match Y channel size {y.shape} and upsample_uv is False."
                )
            interp_flag = INTERPOLATION_MAP.get(
                self.uv_interpolation, cv2.INTER_NEAREST
            )
            # cv2.resize takes the target size as (width, height).
            target_size = (y.shape[1], y.shape[0])
            u = cv2.resize(u, target_size, interpolation=interp_flag)
            v = cv2.resize(v, target_size, interpolation=interp_flag)
        return np.stack([y, u, v], axis=-1)

    def _get_luminance(self, yuv):
        """Return the Y plane as (H, W, 1), or (H, W, 3) when cloned as RGB."""
        y = yuv.y[..., np.newaxis]  # (H, W, 1)
        if self.clone_luminance_as_rgb:
            y = np.repeat(y, 3, axis=-1)  # (H, W, 3)
        return y

    def _crop_aligned_patches(self, low_sr_img, high_sr_img):
        """Cut a random low-res patch and the matching high-res patch.

        The low-res patch has size ``patch_size``; its origin and size are
        multiplied by ``scale`` (truncated to int) to locate the high-res
        patch. Both inputs are returned unchanged when ``patch_size`` is None.

        Raises:
            ValueError: if the patch does not fit in the low-res image, or if
                the scaled patch does not fit in the high-res image.
        """
        if self.patch_size is None:
            return low_sr_img, high_sr_img

        ph, pw = self.patch_size
        H, W = low_sr_img.shape[:2]
        if ph > H or pw > W:
            raise ValueError(
                f"Patch size {self.patch_size} is larger than image size {(H, W)}."
            )

        # Draw order (top, then left) is kept stable for RNG reproducibility.
        top = np.random.randint(0, H - ph + 1)
        left = np.random.randint(0, W - pw + 1)

        # Scale coordinates for high_sr_img
        scale = self.scale
        ht, hp = int(scale * top), int(scale * ph)
        hl, hw = int(scale * left), int(scale * pw)

        # NumPy slicing silently truncates out-of-range slices, which would
        # yield a high-res patch that no longer corresponds to the low-res one.
        hi_h, hi_w = high_sr_img.shape[:2]
        if ht + hp > hi_h or hl + hw > hi_w:
            raise ValueError(
                f"High-res patch rows {ht}:{ht + hp}, cols {hl}:{hl + hw} "
                f"exceed high-res image size {(hi_h, hi_w)}; check 'scaling' "
                f"and the frame dimensions."
            )

        low_patch = low_sr_img[top:top + ph, left:left + pw, :]
        high_patch = high_sr_img[ht:ht + hp, hl:hl + hw, :]
        return low_patch, high_patch

    def _to_model_tensor(self, img, format_range):
        """Convert an (H, W, C) array to a float32 (C, H, W) tensor.

        Values are multiplied by ``model_range / format_range`` so that the
        source range [0, format_range] maps onto [0, model_range].
        """
        scaled = img.astype(np.float32) * (self.model_range / format_range)
        return torch.from_numpy(np.transpose(scaled, (2, 0, 1)))

    def _apply_seeded_transform(self, img, seed):
        """Run ``self.transform`` on ``img`` after seeding torch and NumPy."""
        torch.manual_seed(seed)
        np.random.seed(seed)
        return self.transform(img)

    def __len__(self):
        """Return the number of low-/high-resolution pairs."""
        return len(self.low_sr_images)

    def __getitem__(self, idx):
        """Load, crop, normalize and optionally transform the pair at ``idx``.

        Returns:
            dict with keys ``'low_sr'`` and ``'high_sr'``, each a float32
            tensor laid out as (C, H, W).
        """
        low_sr_path = os.path.join(self.low_sr_dir, self.low_sr_images[idx])
        high_sr_path = os.path.join(self.high_sr_dir, self.high_sr_images[idx])

        # NOTE(review): both frames are read with the same (height, width);
        # with scaling != 1 the high-res file would be expected to be larger.
        # Confirm against the data layout and the yuvio.imread signature.
        low_sr_yuv = yuvio.imread(low_sr_path, (self.height, self.width), self.pix_fmt)
        high_sr_yuv = yuvio.imread(high_sr_path, (self.height, self.width), self.pix_fmt)

        # Extract channels
        extract = self._get_luminance if self.luminance_only else self._stack_yuv
        low_sr_img = extract(low_sr_yuv)
        high_sr_img = extract(high_sr_yuv)

        # Random patch extraction
        low_sr_img, high_sr_img = self._crop_aligned_patches(low_sr_img, high_sr_img)

        # Convert to float32, normalize to [0,model_range], and transpose to (C, H, W)
        format_range = (2 ** low_sr_yuv.yuv_format.bitdepth()) - 1
        low_sr_img = self._to_model_tensor(low_sr_img, format_range)
        high_sr_img = self._to_model_tensor(high_sr_img, format_range)

        if self.transform:
            # Ensure the same random seed for both transforms
            seed = np.random.randint(0, 10**9)
            low_sr_img = self._apply_seeded_transform(low_sr_img, seed)
            high_sr_img = self._apply_seeded_transform(high_sr_img, seed)

        return {'low_sr': low_sr_img, 'high_sr': high_sr_img}
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment