Commit fc11b5cf authored by valentini

Upload a new file

parent 7ab20e15
import os
from torch.utils.data import Dataset
import yuvio
import numpy as np
import cv2
import torch
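
# Map user-facing interpolation names to the corresponding OpenCV resize flags
# (used when upsampling subsampled chroma planes to the luma resolution).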
INTERPOLATION_MAP = {
    'nearest': cv2.INTER_NEAREST,
    'linear': cv2.INTER_LINEAR,
    'cubic': cv2.INTER_CUBIC,
    'area': cv2.INTER_AREA,
    'lanczos': cv2.INTER_LANCZOS4,
}

class SRDataset(Dataset):
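    """Paired low/high-resolution YUV dataset for super-resolution training.

    Reads matching .yuv files from `low_sr_dir` and `high_sr_dir`, optionally keeps only
    the luminance channel, extracts aligned random patches, rescales pixel values to
    [0, model_range], and returns each pair as (C, H, W) float32 tensors.
    """
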
    def __init__(self, settings, transform=None):
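        """Read dataset options from `settings` and index the .yuv files found in both directories."""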
        self.low_sr_dir = settings.low_sr_dir
        self.high_sr_dir = settings.high_sr_dir
        self.width = getattr(settings, 'width')
        self.height = getattr(settings, 'height')
        self.pix_fmt = getattr(settings, 'pix_fmt', 'yuv420p')
        self.scale = getattr(settings, 'scaling', 2)
        self.model_range = getattr(settings, 'model_range', 255.0)  # Default to 255.0 if not specified
        self.luminance_only = getattr(settings, 'luminance_only', False)
        self.clone_luminance_as_rgb = getattr(settings, 'clone_luminance_as_rgb', False)
        self.upsample_uv = getattr(settings, 'upsample_uv', False)
        self.uv_interpolation = getattr(settings, 'uv_interpolation', 'nearest')
        self.patch_size = getattr(settings, 'patch_size', None)  # Should be (h, w) or None
        self.transform = transform
        self.low_sr_images = sorted([
            f for f in os.listdir(self.low_sr_dir)
            if f.lower().endswith('.yuv')
        ])
        self.high_sr_images = sorted([
            f for f in os.listdir(self.high_sr_dir)
            if f.lower().endswith('.yuv')
        ])
        assert len(self.low_sr_images) == len(self.high_sr_images), "Mismatch in number of images."

    def _stack_yuv(self, yuv):
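        """Stack the Y, U and V planes into an (H, W, 3) array.

        If the chroma planes are subsampled, they are upsampled to the luma resolution
        when `upsample_uv` is enabled; otherwise a ValueError is raised.
        """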
        y = yuv.y
        u = yuv.u
        v = yuv.v
        if y.shape != u.shape or y.shape != v.shape:
            if self.upsample_uv:
                interp_flag = INTERPOLATION_MAP.get(self.uv_interpolation, cv2.INTER_NEAREST)
                u = cv2.resize(u, (y.shape[1], y.shape[0]), interpolation=interp_flag)
                v = cv2.resize(v, (y.shape[1], y.shape[0]), interpolation=interp_flag)
            else:
                raise ValueError(f"U/V channel size {u.shape}/{v.shape} does not match Y channel size {y.shape} and upsample_uv is False.")
        return np.stack([y, u, v], axis=-1)

    def _get_luminance(self, yuv):
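        """Return the luminance plane as (H, W, 1), replicated to (H, W, 3) when `clone_luminance_as_rgb` is set."""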
        y = yuv.y[..., np.newaxis]  # (H, W, 1)
        if self.clone_luminance_as_rgb:
            y = np.repeat(y, 3, axis=-1)  # (H, W, 3)
        return y

    def __len__(self):
        return len(self.low_sr_images)

    def __getitem__(self, idx):
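        """Load the idx-th low/high-resolution frame pair and return it as a dict of (C, H, W) float tensors."""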
        low_sr_path = os.path.join(self.low_sr_dir, self.low_sr_images[idx])
        high_sr_path = os.path.join(self.high_sr_dir, self.high_sr_images[idx])
        low_sr_yuv = yuvio.imread(low_sr_path, (self.height, self.width), self.pix_fmt)
        # The high-resolution frame is `scale` times larger than the low-resolution one
        # (the patch coordinates below are scaled accordingly), so read it at the scaled size.
        high_sr_yuv = yuvio.imread(high_sr_path, (self.height * self.scale, self.width * self.scale), self.pix_fmt)
        # Extract channels
        if self.luminance_only:
            low_sr_img = self._get_luminance(low_sr_yuv)
            high_sr_img = self._get_luminance(high_sr_yuv)
        else:
            low_sr_img = self._stack_yuv(low_sr_yuv)
            high_sr_img = self._stack_yuv(high_sr_yuv)
        # Random patch extraction
        if self.patch_size is not None:
            ph, pw = self.patch_size
            H, W, C = low_sr_img.shape
            if ph > H or pw > W:
                raise ValueError(f"Patch size {self.patch_size} is larger than image size {(H, W)}.")
            top = np.random.randint(0, H - ph + 1)
            left = np.random.randint(0, W - pw + 1)
            low_sr_img = low_sr_img[top:top+ph, left:left+pw, :]
            # Scale coordinates for high_sr_img
            scale = self.scale
            ht, hp = int(scale * top), int(scale * ph)
            hl, hw = int(scale * left), int(scale * pw)
            high_sr_img = high_sr_img[ht:ht+hp, hl:hl+hw, :]
        # Convert to float32, rescale to [0, model_range], and transpose to (C, H, W)
        format_range = (2 ** low_sr_yuv.yuv_format.bitdepth()) - 1
        low_sr_img = low_sr_img.astype(np.float32) * (self.model_range / format_range)
        high_sr_img = high_sr_img.astype(np.float32) * (self.model_range / format_range)
        low_sr_img = np.transpose(low_sr_img, (2, 0, 1))
        high_sr_img = np.transpose(high_sr_img, (2, 0, 1))
        # Convert to torch tensors
        low_sr_img = torch.from_numpy(low_sr_img)
        high_sr_img = torch.from_numpy(high_sr_img)
        if self.transform:
            # Ensure the same random seed for both transforms
            seed = np.random.randint(0, int(1e9))
            torch.manual_seed(seed)
            np.random.seed(seed)
            low_sr_img = self.transform(low_sr_img)
            torch.manual_seed(seed)
            np.random.seed(seed)
            high_sr_img = self.transform(high_sr_img)
        return {'low_sr': low_sr_img, 'high_sr': high_sr_img}
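
if __name__ == "__main__":
    # Minimal usage sketch: builds the dataset from a throwaway settings object and
    # inspects one batch. All values below (paths, frame size, options) are placeholders
    # chosen for illustration, not values taken from the original configuration.
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    settings = SimpleNamespace(
        low_sr_dir='data/low_sr',      # placeholder directory of low-resolution .yuv frames
        high_sr_dir='data/high_sr',    # placeholder directory of high-resolution .yuv frames
        width=960,                     # low-resolution frame width (assumed)
        height=540,                    # low-resolution frame height (assumed)
        pix_fmt='yuv420p',
        scaling=2,
        model_range=1.0,
        luminance_only=False,
        clone_luminance_as_rgb=False,
        upsample_uv=True,
        uv_interpolation='cubic',
        patch_size=(64, 64),
    )
    dataset = SRDataset(settings)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    # Expected shapes with these settings: (4, 3, 64, 64) and (4, 3, 128, 128)
    print(batch['low_sr'].shape, batch['high_sr'].shape)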