import os
import random

from torch.utils.data import Dataset
import yuvio
import numpy as np
import cv2
import torch

# User-facing interpolation names -> OpenCV flags, used when the U/V planes
# have to be resized to the Y-plane resolution (see SRDataset._stack_yuv).
INTERPOLATION_MAP = {
    'nearest': cv2.INTER_NEAREST,
    'linear': cv2.INTER_LINEAR,
    'cubic': cv2.INTER_CUBIC,
    'area': cv2.INTER_AREA,
    'lanczos': cv2.INTER_LANCZOS4,
}


class SRDataset(Dataset):
    """Paired low/high-resolution raw YUV frames for super-resolution training.

    Files ending in ``.yuv`` (case-insensitive) are listed from
    ``settings.low_sr_dir`` and ``settings.high_sr_dir``, sorted by name, and
    paired by position in the sorted lists. Each item is a dict with keys
    ``'low_sr'`` and ``'high_sr'`` holding float32 tensors in (C, H, W)
    layout, scaled from the pixel format's full range to ``[0, model_range]``.

    Settings read from ``settings`` (attribute access):
        low_sr_dir, high_sr_dir: directories containing the ``.yuv`` files.
        width, height: frame size passed to ``yuvio.imread`` for low-res files.
        high_width, high_height: frame size for high-res files. Defaults to
            ``width`` / ``height``, i.e. both files are read with the same
            size, as in earlier versions of this class.
        pix_fmt: pixel format string (default ``'yuv420p'``).
        scaling: factor mapping low-res patch coordinates to high-res ones
            (default 2).
        model_range: upper end of the output value range (default 255.0).
        luminance_only: return only the Y plane (default False).
        clone_luminance_as_rgb: with ``luminance_only``, repeat Y into three
            channels (default False).
        upsample_uv: resize U/V to the Y size when they differ (default False).
        uv_interpolation: key into ``INTERPOLATION_MAP`` (default
            ``'nearest'``; unknown names fall back to nearest).
        patch_size: ``(h, w)`` low-res patch size, an int for a square patch,
            or None to return whole frames.

    Args:
        settings: configuration object providing the attributes above.
        transform: optional callable applied to both tensors. The RNGs are
            reseeded identically before each call, so random transforms draw
            the same parameters for the low- and high-res tensors.

    Raises:
        ValueError: if the two directories hold different numbers of
            ``.yuv`` files.
    """

    def __init__(self, settings, transform=None):
        self.low_sr_dir = settings.low_sr_dir
        self.high_sr_dir = settings.high_sr_dir
        self.width = settings.width
        self.height = settings.height
        # Optional high-res frame size; the defaults reproduce the historical
        # behaviour of reading both frames with the low-res dimensions.
        self.high_width = getattr(settings, 'high_width', self.width)
        self.high_height = getattr(settings, 'high_height', self.height)
        self.pix_fmt = getattr(settings, 'pix_fmt', 'yuv420p')
        self.scale = getattr(settings, 'scaling', 2)
        self.model_range = getattr(settings, 'model_range', 255.0)
        self.luminance_only = getattr(settings, 'luminance_only', False)
        self.clone_luminance_as_rgb = getattr(settings, 'clone_luminance_as_rgb', False)
        self.upsample_uv = getattr(settings, 'upsample_uv', False)
        self.uv_interpolation = getattr(settings, 'uv_interpolation', 'nearest')

        patch_size = getattr(settings, 'patch_size', None)
        if isinstance(patch_size, int):
            # A single int means a square patch.
            patch_size = (patch_size, patch_size)
        self.patch_size = patch_size  # (h, w) in low-res pixels, or None

        self.transform = transform

        self.low_sr_images = self._list_yuv(self.low_sr_dir)
        self.high_sr_images = self._list_yuv(self.high_sr_dir)
        # An explicit raise (not assert) so the check survives `python -O`.
        if len(self.low_sr_images) != len(self.high_sr_images):
            raise ValueError("Mismatch in number of images.")

    @staticmethod
    def _list_yuv(directory):
        """Return the sorted names of ``.yuv`` files (case-insensitive) in *directory*."""
        return sorted(
            name for name in os.listdir(directory) if name.lower().endswith('.yuv')
        )

    def _stack_yuv(self, yuv):
        """Stack the Y, U and V planes of *yuv* into one (H, W, 3) array.

        When the chroma planes differ in shape from Y, they are resized to the
        Y size if ``upsample_uv`` is set; otherwise a ValueError is raised.
        """
        y = yuv.y
        u = yuv.u
        v = yuv.v
        if y.shape != u.shape or y.shape != v.shape:
            if self.upsample_uv:
                interp_flag = INTERPOLATION_MAP.get(self.uv_interpolation, cv2.INTER_NEAREST)
                # cv2.resize takes dsize as (width, height).
                u = cv2.resize(u, (y.shape[1], y.shape[0]), interpolation=interp_flag)
                v = cv2.resize(v, (y.shape[1], y.shape[0]), interpolation=interp_flag)
            else:
                raise ValueError(f"U/V channel size {u.shape}/{v.shape} does not match Y channel size {y.shape} and upsample_uv is False.")
        return np.stack([y, u, v], axis=-1)

    def _get_luminance(self, yuv):
        """Return the Y plane as (H, W, 1), or (H, W, 3) if ``clone_luminance_as_rgb``."""
        y = yuv.y[..., np.newaxis]  # (H, W, 1)
        if self.clone_luminance_as_rgb:
            y = np.repeat(y, 3, axis=-1)  # (H, W, 3)
        return y

    def __len__(self):
        return len(self.low_sr_images)

    def _random_patch(self, low_sr_img, high_sr_img):
        """Crop aligned random patches from an (H, W, C) low/high-res pair.

        The low-res patch is ``patch_size``; the high-res patch starts at the
        low-res offset times ``scale`` and spans ``patch_size`` times ``scale``.

        Raises:
            ValueError: if the patch exceeds the low-res frame, or if some
                random position would run past the high-res frame (NumPy
                slicing would otherwise silently return a truncated patch).
        """
        ph, pw = self.patch_size
        H, W = low_sr_img.shape[:2]
        if ph > H or pw > W:
            raise ValueError(f"Patch size {self.patch_size} is larger than image size {(H, W)}.")

        scale = self.scale
        hp, hw = int(scale * ph), int(scale * pw)
        high_h, high_w = high_sr_img.shape[:2]
        # The largest possible offset is at top = H - ph / left = W - pw, so
        # checking those guarantees every draw yields a full high-res patch.
        if int(scale * (H - ph)) + hp > high_h or int(scale * (W - pw)) + hw > high_w:
            raise ValueError(
                f"High-res frame size {(high_h, high_w)} is too small for low-res "
                f"frame size {(H, W)} at scale {scale}; check the 'scaling', "
                f"'high_width' and 'high_height' settings."
            )

        top = np.random.randint(0, H - ph + 1)
        left = np.random.randint(0, W - pw + 1)
        low_patch = low_sr_img[top:top + ph, left:left + pw, :]

        ht, hl = int(scale * top), int(scale * left)
        high_patch = high_sr_img[ht:ht + hp, hl:hl + hw, :]
        return low_patch, high_patch

    def _to_tensor(self, img, format_range):
        """Scale an (H, W, C) array from [0, format_range] to [0, model_range]; return (C, H, W) float32."""
        img = img.astype(np.float32) * (self.model_range / format_range)
        return torch.from_numpy(np.transpose(img, (2, 0, 1)))

    def _apply_paired_transform(self, low_sr_img, high_sr_img):
        """Apply ``self.transform`` to both tensors under the same RNG seed.

        The torch, NumPy and Python ``random`` global RNGs are all reseeded
        before each call, so random transforms draw identical parameters for
        both tensors. Note this resets those global RNG states.
        """
        # int(): random.seed rejects NumPy integer types on Python >= 3.11.
        seed = int(np.random.randint(0, 10**9))
        outputs = []
        for img in (low_sr_img, high_sr_img):
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            outputs.append(self.transform(img))
        return outputs[0], outputs[1]

    def __getitem__(self, idx):
        """Load pair *idx* and return ``{'low_sr': Tensor, 'high_sr': Tensor}``."""
        low_sr_path = os.path.join(self.low_sr_dir, self.low_sr_images[idx])
        high_sr_path = os.path.join(self.high_sr_dir, self.high_sr_images[idx])

        # NOTE(review): call form kept exactly as before; confirm it matches
        # the installed yuvio.imread signature (size argument and order).
        low_sr_yuv = yuvio.imread(low_sr_path, (self.height, self.width), self.pix_fmt)
        high_sr_yuv = yuvio.imread(high_sr_path, (self.high_height, self.high_width), self.pix_fmt)

        # Extract channels as (H, W, C) arrays.
        if self.luminance_only:
            low_sr_img = self._get_luminance(low_sr_yuv)
            high_sr_img = self._get_luminance(high_sr_yuv)
        else:
            low_sr_img = self._stack_yuv(low_sr_yuv)
            high_sr_img = self._stack_yuv(high_sr_yuv)

        if self.patch_size is not None:
            low_sr_img, high_sr_img = self._random_patch(low_sr_img, high_sr_img)

        # Full-scale code value of the pixel format, e.g. 255 for 8-bit data.
        # Both frames share pix_fmt, so the low-res bit depth applies to both.
        format_range = (2 ** low_sr_yuv.yuv_format.bitdepth()) - 1
        low_sr_img = self._to_tensor(low_sr_img, format_range)
        high_sr_img = self._to_tensor(high_sr_img, format_range)

        if self.transform is not None:
            low_sr_img, high_sr_img = self._apply_paired_transform(low_sr_img, high_sr_img)

        return {'low_sr': low_sr_img, 'high_sr': high_sr_img}