yuv.py 4.78 KB
Newer Older
valentini's avatar
valentini committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
from torch.utils.data import Dataset
import yuvio
import numpy as np
import cv2
import torch

# Lookup from the ``uv_interpolation`` setting name to the OpenCV resize flag
# used when chroma planes are upsampled to the luma size.
INTERPOLATION_MAP = dict(
    nearest=cv2.INTER_NEAREST,
    linear=cv2.INTER_LINEAR,
    cubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)

class SRDataset(Dataset):
    """Paired low-/high-resolution YUV frames for super-resolution training.

    Every ``*.yuv`` file (extension matched case-insensitively) in
    ``settings.low_sr_dir`` is paired, after sorting by file name, with the
    file at the same sorted position in ``settings.high_sr_dir``. Only the
    number of files is checked; the names themselves are not compared.

    Attributes read from ``settings``:
        low_sr_dir, high_sr_dir: directories holding the ``.yuv`` files.
        width, height: frame size passed to ``yuvio.imread`` for the
            low-resolution files (required).
        hr_width, hr_height: frame size passed to ``yuvio.imread`` for the
            high-resolution files (default: ``width`` / ``height``, i.e. the
            same size as the low-resolution frames).
        pix_fmt: pixel format passed to ``yuvio.imread`` (default
            ``'yuv420p'``); used for both files.
        scaling: HR/LR size ratio used to map a low-resolution patch onto the
            high-resolution frame (default ``2``).
        model_range: samples are rescaled to ``[0, model_range]``
            (default ``255.0``).
        luminance_only: keep only the Y plane (default ``False``).
        clone_luminance_as_rgb: with ``luminance_only``, repeat Y into three
            channels (default ``False``).
        upsample_uv: resize U/V to the Y size when their shapes differ
            (default ``False``); otherwise mismatched planes raise.
        uv_interpolation: key into ``INTERPOLATION_MAP`` for that resize
            (default ``'nearest'``; unknown keys also fall back to nearest).
        patch_size: ``None`` for full frames, an ``(h, w)`` pair, or a single
            int for square patches, in low-resolution pixels.

    Each item is ``{'low_sr': Tensor, 'high_sr': Tensor}``; before any
    ``transform`` is applied the tensors are float32 with layout ``(C, H, W)``.
    """

    def __init__(self, settings, transform=None):
        self.low_sr_dir = settings.low_sr_dir
        self.high_sr_dir = settings.high_sr_dir
        self.width = settings.width
        self.height = settings.height
        # The HR frame size may differ from the LR size. The defaults keep the
        # historical behaviour of reading both files at the same size; set
        # these when the HR files really are ``scaling`` times larger.
        self.hr_width = getattr(settings, 'hr_width', self.width)
        self.hr_height = getattr(settings, 'hr_height', self.height)
        self.pix_fmt = getattr(settings, 'pix_fmt', 'yuv420p')
        self.scale = getattr(settings, 'scaling', 2)
        self.model_range = getattr(settings, 'model_range', 255.0)
        self.luminance_only = getattr(settings, 'luminance_only', False)
        self.clone_luminance_as_rgb = getattr(settings, 'clone_luminance_as_rgb', False)
        self.upsample_uv = getattr(settings, 'upsample_uv', False)
        self.uv_interpolation = getattr(settings, 'uv_interpolation', 'nearest')
        self.patch_size = getattr(settings, 'patch_size', None)  # (h, w), int, or None
        self.transform = transform

        self.low_sr_images = self._list_yuv(self.low_sr_dir)
        self.high_sr_images = self._list_yuv(self.high_sr_dir)

        # Pairing is purely positional, so the counts must agree. Raised
        # explicitly (not ``assert``) so the check survives ``python -O``.
        if len(self.low_sr_images) != len(self.high_sr_images):
            raise ValueError(
                f"Mismatch in number of images: {len(self.low_sr_images)} in "
                f"{self.low_sr_dir!r} vs {len(self.high_sr_images)} in "
                f"{self.high_sr_dir!r}."
            )

    @staticmethod
    def _list_yuv(directory):
        """Return the sorted names of ``*.yuv`` entries (case-insensitive) in *directory*."""
        return sorted(f for f in os.listdir(directory) if f.lower().endswith('.yuv'))

    def _stack_yuv(self, yuv):
        """Stack the Y, U and V planes of *yuv* into one ``(H, W, 3)`` array.

        When the chroma planes differ in shape from Y and ``upsample_uv`` is
        set, U and V are resized to the Y size with
        ``INTERPOLATION_MAP[self.uv_interpolation]`` (nearest for unknown keys).

        Raises:
            ValueError: if the planes differ in shape and ``upsample_uv`` is False.
        """
        y, u, v = yuv.y, yuv.u, yuv.v
        if y.shape != u.shape or y.shape != v.shape:
            if not self.upsample_uv:
                raise ValueError(f"U/V channel size {u.shape}/{v.shape} does not match Y channel size {y.shape} and upsample_uv is False.")
            interp_flag = INTERPOLATION_MAP.get(self.uv_interpolation, cv2.INTER_NEAREST)
            target = (y.shape[1], y.shape[0])  # cv2.resize expects (width, height)
            u = cv2.resize(u, target, interpolation=interp_flag)
            v = cv2.resize(v, target, interpolation=interp_flag)
        return np.stack([y, u, v], axis=-1)

    def _get_luminance(self, yuv):
        """Return the Y plane as ``(H, W, 1)``, or ``(H, W, 3)`` when cloned as RGB."""
        y = yuv.y[..., np.newaxis]
        if self.clone_luminance_as_rgb:
            y = np.repeat(y, 3, axis=-1)
        return y

    def _random_patch_pair(self, low_sr_img, high_sr_img):
        """Crop a random patch from *low_sr_img* and the co-located patch from *high_sr_img*.

        ``self.patch_size`` (an ``(h, w)`` pair or a single int) is measured in
        low-resolution pixels; the high-resolution window covers the same
        region scaled by ``self.scale``. Offsets come from torch's global RNG,
        which DataLoader seeds differently in each worker process.

        Raises:
            ValueError: if the patch is larger than the low-resolution image,
                or the high-resolution image is too small to hold every scaled
                window (rather than returning a silently truncated HR patch).
        """
        if np.ndim(self.patch_size) == 0:
            ph = pw = self.patch_size
        else:
            ph, pw = self.patch_size
        H, W = low_sr_img.shape[:2]
        if ph > H or pw > W:
            raise ValueError(f"Patch size {self.patch_size} is larger than image size {(H, W)}.")

        scale = self.scale
        # Largest bottom/right HR edge any crop can reach (offset is maximal).
        rows_needed = int(scale * (H - ph)) + int(scale * ph)
        cols_needed = int(scale * (W - pw)) + int(scale * pw)
        HH, HW = high_sr_img.shape[:2]
        if HH < rows_needed or HW < cols_needed:
            raise ValueError(
                f"High-res image size {(HH, HW)} cannot hold {scale}x-scaled patches of "
                f"low-res image size {(H, W)}; need at least {(rows_needed, cols_needed)} "
                f"(check hr_width/hr_height and scaling)."
            )

        top = int(torch.randint(0, H - ph + 1, (1,)).item())
        left = int(torch.randint(0, W - pw + 1, (1,)).item())
        ht, hp = int(scale * top), int(scale * ph)
        hl, hw = int(scale * left), int(scale * pw)
        return (
            low_sr_img[top:top + ph, left:left + pw, :],
            high_sr_img[ht:ht + hp, hl:hl + hw, :],
        )

    def _to_tensor(self, img, format_range):
        """Rescale *img* from ``[0, format_range]`` to ``[0, model_range]`` and return a float32 ``(C, H, W)`` tensor."""
        img = img.astype(np.float32) * (self.model_range / format_range)
        return torch.from_numpy(np.transpose(img, (2, 0, 1)))

    def __len__(self):
        return len(self.low_sr_images)

    def __getitem__(self, idx):
        """Load pair *idx*, optionally crop an aligned random patch, and return tensors."""
        low_sr_path = os.path.join(self.low_sr_dir, self.low_sr_images[idx])
        high_sr_path = os.path.join(self.high_sr_dir, self.high_sr_images[idx])

        low_sr_yuv = yuvio.imread(low_sr_path, (self.height, self.width), self.pix_fmt)
        high_sr_yuv = yuvio.imread(high_sr_path, (self.hr_height, self.hr_width), self.pix_fmt)

        if self.luminance_only:
            low_sr_img = self._get_luminance(low_sr_yuv)
            high_sr_img = self._get_luminance(high_sr_yuv)
        else:
            low_sr_img = self._stack_yuv(low_sr_yuv)
            high_sr_img = self._stack_yuv(high_sr_yuv)

        if self.patch_size is not None:
            low_sr_img, high_sr_img = self._random_patch_pair(low_sr_img, high_sr_img)

        # Both files are read with the same pix_fmt, so one bit depth serves both.
        format_range = (2 ** low_sr_yuv.yuv_format.bitdepth()) - 1
        low_sr_img = self._to_tensor(low_sr_img, format_range)
        high_sr_img = self._to_tensor(high_sr_img, format_range)

        if self.transform:
            # Re-seed torch and NumPy with the same value before each call so a
            # random transform draws the same random numbers for LR and HR. The
            # seed is taken from torch's RNG (per-worker seeded by DataLoader)
            # instead of NumPy's global RNG, which forked workers share.
            seed = int(torch.randint(0, 10**9, (1,)).item())
            torch.manual_seed(seed)
            np.random.seed(seed)
            low_sr_img = self.transform(low_sr_img)
            torch.manual_seed(seed)
            np.random.seed(seed)
            high_sr_img = self.transform(high_sr_img)

        return {'low_sr': low_sr_img, 'high_sr': high_sr_img}