# Prepare the dataset for the training phase
import torch.utils.data as data
import torchvision.transforms as T
import yuvio
import glob
import glob2
import torch
import logging
import random
import numpy as np
import cv2
import re
import os.path as path
from utils.transforms import paralellCrop
from PIL import Image

ORIGNAL_BASE_PATH = '/trinity/home/mangelini/data/datasets/OrignialRaw/Original_Sequencies'


def getOriginalSequencePathFromSequenceName(name):
    """Return the absolute path of the pristine (uncompressed) YUV sequence.

    Raises KeyError when *name* is not one of the known test sequences.
    """
    sequence_paths = {
        'BasketballDrive': 'BasketballDrive_1920x1080_50.yuv',
        'FoodMarket4': 'FoodMarket4_3840x2160_60fps_10bit_420.yuv',
        'BQTerrace': 'BQTerrace_1920x1080_60.yuv',
        'MarketPlace': 'MarketPlace_1920x1080_60fps_10bit_420.yuv',
        'Cactus': 'Cactus_1920x1080_50.yuv',
        'ParkRunning3': 'ParkRunning3_3840x2160_50fps_10bit_420.yuv',
        'Campfire': 'Campfire_3840x2160_30fps_10bit_bt709_420_videoRange.yuv',
        'RitualDance': 'RitualDance_1920x1080_60fps_10bit_420.yuv',
        'CatRobot': 'CatRobot_3840x2160_60fps_10bit_420_jvet.yuv',
        'Tango2': 'Tango2_3840x2160_60fps_10bit_420.yuv',
        'DaylightRoad2': 'DaylightRoad2_3840x2160_60fps_10bit_420.yuv',
    }
    return f'{ORIGNAL_BASE_PATH}/{sequence_paths[name]}'


# Bit depth of every supported ffmpeg-style pixel format, built once at import
# time (replaces the previous ~60-branch if-chain with a single dict lookup).
_FORMAT_DEPTHS = {'nv12': 8, 'v210': 10, 'yuyv422': 8, 'uyvy422': 8, 'yvyu422': 8}
for _base in ('gray', 'yuv420p', 'yuv422p', 'yuv444p'):
    _FORMAT_DEPTHS[_base] = 8
    for _bits in (9, 10, 12, 14, 16):
        for _endian in ('le', 'be'):
            _FORMAT_DEPTHS['{}{}{}'.format(_base, _bits, _endian)] = _bits


def getDepthFromFormat(format):
    """Return the per-sample bit depth of *format*, or 0 when unknown."""
    return _FORMAT_DEPTHS.get(format, 0)


def extract_max_from_bitsize(bitsize):
    """Maximum value of the nearest power-of-two container holding *bitsize*
    bits (e.g. 10 -> 2**16 - 1)."""
    # find nearest 2 multiple
    size = 2
    while size < bitsize:
        size *= 2
    return (2 ** size) - 1


def is_image_file(filename):
    """True when *filename* has one of the supported image extensions."""
    return filename.endswith((".png", ".jpg", ".jpeg"))


def load_img(filepath):
    """Open an image and return only its luma (Y) channel as a PIL image."""
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y


def frame_to_tensor(input_image):
    """HWC image array -> float32 tensor with axes permuted to (C, W, H).

    BUGFIX: ``np.float`` (removed in NumPy 1.24) replaced by ``np.float64``.
    """
    img_array = np.asarray(input_image).copy().astype(np.float64)
    return torch.tensor(img_array, dtype=torch.float32).permute((2, 1, 0))


def yuv_to_tensor(input_image):
    """Stacked (C, H, W) YUV planes -> float32 tensor permuted to (C, W, H).

    BUGFIX: ``np.float`` (removed in NumPy 1.24) replaced by ``np.float64``.
    """
    img_array = np.asarray(input_image).copy().astype(np.float64)
    return torch.tensor(img_array, dtype=torch.float32).permute((0, 2, 1))


# support only 444
def load_yuv_img(path, bitsFormart='uint8', w=1920, h=1080):
    """Read one packed raw 4:4:4 frame of size w*h from *path*."""
    # Read entire file into YUV
    YUV = np.fromfile(path, dtype=bitsFormart)
    # return the image as a frame of w*h
    return YUV.reshape(h, w, 3)


def getContainerType(depth):
    """Numpy dtype name of the unsigned container for *depth* bits."""
    return 'uint{}'.format(depth)


class DatasetFromFolderYUV(data.Dataset):
    """Paired low-res / high-res dataset of single-frame .yuv patch files."""

    def __init__(self, low_res_folder, high_res_folder, size, scale,
                 yuvFormat='yuv444p10le', take_channels=3, transform=None,
                 limit=None, return_names=False, filter_only=None, y_only=False,
                 prefix_low_res=None, low_res_filter=None, prefix_high_res=None,
                 high_res_filter=None, normalize=False, random_frame=False,
                 seed=888):
        super(DatasetFromFolderYUV, self).__init__()
        self.transform = transform
        self.return_names = return_names
        # Patches are fixed at 128x128 regardless of the `size` argument.
        self.width = 128
        self.heigth = 128
        self.yuvFormat = yuvFormat
        self.depth = getDepthFromFormat(self.yuvFormat)
        # Smallest power-of-two container strictly larger than the bit depth.
        self.bucketSize = min(x for x in (8, 16, 32, 64, 128) if x > self.depth)
        self.yOnly = y_only
        # BUGFIX: was max(min(take_channels, 1), 3), which always yields 3.
        # Clamp the requested channel count to [1, 3] as intended.
        self.takeChannels = min(max(take_channels, 1), 3)
        self.scale = int(scale)
        self.normalize = normalize
        self.random_frame = random_frame
        self.random = np.random.RandomState(seed)
        self.low_res_folder = low_res_folder
        self.high_res_folder = high_res_folder
        self.prefix_low_res = prefix_low_res
        self.prefix_high_res = prefix_high_res
        self.low_res_filter = low_res_filter
        self.high_res_filter = high_res_filter
        self.low_res_filenames = glob.glob("{}/**/*.yuv".format(low_res_folder), recursive=True)
        if filter_only:
            self.low_res_filenames = [x for x in self.low_res_filenames if filter_only in x]
        if limit:
            self.low_res_filenames = self.low_res_filenames[:limit]

    def _getMax(self):
        """Maximum representable sample value for the configured bit depth."""
        return (2 ** self.depth) - 1

    def _removeBroken(self):
        """Keep only inputs whose high-res counterpart can actually be read."""
        not_broken = []
        for input_path in self.low_res_filenames:
            try:
                self.getYUV(input_path, 1)
                target_path = (input_path
                               .replace(self.prefix_low_res, self.prefix_high_res)
                               .replace(self.low_res_folder, self.high_res_folder)
                               .replace('-qp22', ''))
                target_path = self.removeClasses(target_path)
                target_path = target_path.replace('1920x1088', '3840x2176')  # HD-4K
                target_path = target_path.replace('960x544', '1920x1088')  # SD-HD
                self.getYUV(target_path, self.scale)
                # BUGFIX: was `not_broken += input_path`, which extended the
                # list with the path's individual characters.
                not_broken.append(input_path)
            except Exception:  # narrowed from a bare `except:`
                logging.warning(f'Skipping ({input_path}) impossible to find HD counterpart of broken data')
        return not_broken

    def getDataRange(self):
        """Upper bound of the pixel values handed to the model."""
        if self.normalize:
            return 1.0
        return self._getMax()

    def removeClasses(self, text):
        """Shift the bvi-dvc class letter in the last two path segments
        (B->A for HD names, otherwise C->B); other paths pass through."""
        sliced = text.split('/')
        if not 'bvi-dvc' in text:
            return text
        elif '1920x1088' in text:
            sliced[-1] = sliced[-1].replace("B", "A", 1)
            sliced[-2] = sliced[-2].replace("B", "A", 1)
        else:
            sliced[-1] = sliced[-1].replace("C", "B", 1)
            sliced[-2] = sliced[-2].replace("C", "B", 1)
        return '/'.join(sliced)

    def swapFpsDepth(self, text):
        """Swap the fps / bit-depth tokens in ultravideo file and folder
        names; other paths (and 4K names) pass through untouched."""
        if not 'ultravideo' in text or "3840x2176" in text:
            return text
        sliced = text.split('/')
        tmpSlice = sliced[-1].split("_")
        tmpSlice[-2], tmpSlice[-3] = tmpSlice[-3], tmpSlice[-2]
        sliced[-1] = "_".join(tmpSlice)
        tmpSlice = sliced[-2].split("_")
        tmpSlice[-1], tmpSlice[-2] = tmpSlice[-2], tmpSlice[-1]
        sliced[-2] = "_".join(tmpSlice)
        return '/'.join(sliced)

    def getYUV(self, path, scale):
        """Read one frame and return stacked planes, shape (C, H, W).

        In Y-only mode the luma plane is replicated three times and then
        sliced down to `takeChannels`.
        """
        frame = yuvio.imread(path, self.width * scale, self.heigth * scale, self.yuvFormat)
        Y, U, V = frame.split()
        planes = np.stack([Y, Y, Y]) if self.yOnly else np.stack([Y, U, V])
        if self.yOnly:
            return planes[0:self.takeChannels]
        return planes

    def getTargetPath(self, in_path):
        """Derive the high-res path for a low-res input, applying the known
        dataset-specific naming fixups until an existing file is found."""
        target_path = (in_path
                       .replace(self.prefix_low_res, self.prefix_high_res)
                       .replace(self.low_res_folder, self.high_res_folder)
                       .replace('-qp22', ''))
        target_path = target_path.replace('1920x1088', '3840x2176')  # HD-4K
        target_path = target_path.replace('960x544', '1920x1088')  # SD-HD
        if not path.exists(target_path):
            target_path = self.removeClasses(target_path)
        if not path.exists(target_path):
            target_path = self.swapFpsDepth(target_path)
        return target_path

    def __getitem__(self, index):
        input_path = self.low_res_filenames[index]
        input = self.getYUV(input_path, 1)
        target_path = self.getTargetPath(input_path)
        target = self.getYUV(target_path, self.scale)
        # Convert yuv in pytorch shape
        input = yuv_to_tensor(input)
        target = yuv_to_tensor(target)
        # Normalize: map 10-bit values into the 8-bit range used by the model.
        input = input.mul(255 / 1023)
        target = target.mul(255 / 1023)
        if self.transform:
            seed = self.random.randint(2147483647)  # make a seed with numpy generator
            random.seed(seed)  # apply this seed to img transforms
            torch.manual_seed(seed)  # needed for torchvision 0.7
            input = self.transform(input)
            random.seed(seed)  # Force same transform for the target
            torch.manual_seed(seed)  # Force same transform for the target
            target = self.transform(target)
        if self.return_names:
            return (input_path, input), (target_path, target)
        return (input, target)

    def __len__(self):
        return len(self.low_res_filenames)


class DatasetFromYUVSequence(data.Dataset):
    """Paired low-res / high-res dataset of multi-frame .yuv sequences."""

    def __init__(self, low_res_folder, high_res_folder, size, scale,
                 yuvFormat='yuv444p10le', transform=None, limit=None,
                 return_names=False, filter_only=None, y_only=False,
                 prefix_low_res=None, prefix_high_res=None, random_frame=False,
                 crop=None, seed=888):
        super(DatasetFromYUVSequence, self).__init__()
        self.transform = transform
        self.return_names = return_names
        # `size` is a "WxH" string, e.g. "1920x1088".
        self.width = int(size.split("x")[0])
        self.heigth = int(size.split("x")[1])
        self.yuvFormat = yuvFormat
        self.depth = getDepthFromFormat(self.yuvFormat)
        self.bucketSize = min(x for x in (8, 16, 32, 64, 128) if x > self.depth)
        self.yOnly = y_only
        self.scale = int(scale)
        self.random_frame = random_frame
        self.crop = int(crop) if crop else -1  # -1 disables cropping
        self.random = np.random.RandomState(seed=seed)
        self.low_res_folder = low_res_folder
        self.high_res_folder = high_res_folder
        self.prefix_low_res = prefix_low_res
        self.prefix_high_res = prefix_high_res
        self.low_res_filenames = glob2.glob("{}/**/*.yuv".format(low_res_folder))
        if filter_only:
            self.low_res_filenames = [x for x in self.low_res_filenames if filter_only in x]
        self.low_res_filenames = [x for x in self.low_res_filenames if size in x]
        self.low_res_filenames = self.removeBrokenPairs()
        if limit:
            self.low_res_filenames = self.low_res_filenames[:limit]

    def _getMax(self):
        """Maximum representable sample value for the configured bit depth."""
        return (2 ** self.depth) - 1

    def getDataRange(self):
        return 255

    def loadSequence(self, path_img, scale):
        """Read every frame of a sequence at the given scale factor."""
        return yuvio.mimread(path_img, self.width * scale, self.heigth * scale, self.yuvFormat)

    def toNumpy(self, yuvIOimg):
        """Stack the frame's planes: (Y,Y,Y) in Y-only mode, else (Y,U,V)."""
        Y, U, V = yuvIOimg.split()
        if self.yOnly:
            return np.stack([Y, Y, Y])
        return np.stack([Y, U, V])

    def prepare(self, data, seed):
        """Tensor-ify, apply the (seeded) transform and normalize one frame."""
        data = yuv_to_tensor(data)
        if self.transform:
            random.seed(seed)  # same seed on input and target -> same transform
            torch.manual_seed(seed)  # needed for torchvision 0.7
            data = self.transform(data)
        # Normalize for transfer learning
        data = data.mul(255 / 1023)
        return data

    def removeBrokenPairs(self):
        """Drop inputs whose high-res counterpart does not exist on disk."""
        new_low_res = []
        for low_res in self.low_res_filenames:
            high_res_path = self.getHighResPath(low_res)
            if path.exists(high_res_path):
                new_low_res.append(low_res)
            else:
                logging.warning(f'Could not find the HD counterpart for {low_res} at -> {high_res_path}')
        return new_low_res

    def sequenceSepcificIssuesFix(self, path, file):
        """Per-dataset filename fixups (bvi class letters, ultravideo token
        order); files from other datasets pass through unchanged."""
        if path.count('/bvi') > 0:
            if '1920x1088' in file:
                return file.replace("B", "A", 1)
            return file.replace("C", "B", 1)
        elif path.count('/ultravideo') > 0:
            file_segment = file.replace(".yuv", "").split("_")
            file_segment[-1], file_segment[-2] = file_segment[-2], file_segment[-1]
            new_filename = "_".join(file_segment)
            return f"{new_filename}.yuv"
        return file

    def getHighResPath(self, img_path):
        """Map a low-res sequence path to its expected high-res counterpart."""
        target_folder = path.dirname(img_path).replace(self.low_res_folder, self.high_res_folder)
        target_name = path.basename(img_path)
        target_name = re.sub("-qp[0-9][0-9]", "", target_name)  # Remove the -qp## from name
        target_name = self.sequenceSepcificIssuesFix(target_folder, target_name)
        target_name = target_name.replace(self.prefix_low_res, self.prefix_high_res)
        target_name = target_name.replace('1920x1088', '3840x2176')  # HD-4K
        target_name = target_name.replace('960x544', '1920x1088')  # SD-HD
        return path.join(target_folder, target_name)

    def __getitem__(self, index):
        input_path = self.low_res_filenames[index]
        inputs = self.loadSequence(input_path, 1)
        target_path = self.getHighResPath(input_path)
        targets = self.loadSequence(target_path, self.scale)
        random_frame_index = 0
        if self.random_frame:
            frame_count = min(len(targets), len(inputs))
            # BUGFIX: was `random.random * frame_count` — the method was never
            # called (TypeError) and would have produced a float index anyway.
            # Draw an integer frame index from the dataset's seeded RNG.
            random_frame_index = int(self.random.randint(frame_count))
        # Extract the paired frame and convert to stacked numpy planes.
        input = self.toNumpy(inputs[random_frame_index])
        target = self.toNumpy(targets[random_frame_index])
        if self.crop and self.crop > 10:
            input, target, crop_info = paralellCrop(input, target, crop=self.crop,
                                                    scale=self.scale, random=random)
        seed = self.random.randint(1000000)
        input = self.prepare(input, seed)
        target = self.prepare(target, seed)
        if self.return_names:
            return (input_path, input), (target_path, target)
        return (input, target)

    def __len__(self):
        return len(self.low_res_filenames)


class TestSequencies(data.Dataset):
    """Evaluation dataset: encoded low-res frames vs pristine targets."""

    loaded = {}

    def __init__(self, lowres_path, target_path, scale, width, heigth, yuvFormat, yOnly=True):
        super(TestSequencies, self).__init__()
        self.yOnly = yOnly
        self.scale = scale
        self.width = width
        self.heigth = heigth
        self.yuvFormat = yuvFormat
        self.data_path = lowres_path
        self.target_path = target_path
        self.low_res_filenames = glob2.glob(f'{self.data_path}/**/*.yuv')
        self.low_res_filenames.sort()

    def _getMax(self):
        return 1023

    def getDataRange(self):
        return 255

    def load(self, path_img, scale):
        """Read one frame of a single-frame .yuv file at the given scale."""
        return yuvio.imread(path_img, self.width * scale, self.heigth * scale, self.yuvFormat)

    def toNumpy(self, yuvIOimg):
        """Stack the frame's planes: (Y,Y,Y) in Y-only mode, else (Y,U,V)."""
        Y, U, V = yuvIOimg.split()
        if self.yOnly:
            return np.stack([Y, Y, Y])
        return np.stack([Y, U, V])

    def prepare(self, data):
        data = yuv_to_tensor(data)
        # Normalize for transfer learning
        data = data.mul(255 / 1023)
        return data

    def getSequenceName(self, img_path):
        """Parent-folder name with any "_prop_qp-##" suffix stripped."""
        return re.sub("_prop_qp-[0-9][0-9]", "", path.dirname(img_path).split("/")[-1])

    def getQP(self, img_path):
        """Two-digit QP embedded in the parent folder name, or 0 if absent."""
        base_sequence = path.dirname(img_path).split("/")[-1]
        for i in range(10, 100):
            if f'_qp{i}' in base_sequence:
                return i
        return 0

    def getTargetPath(self, img_path):
        sequenceName = self.getSequenceName(img_path)
        target_folder = path.join(self.target_path, f'{sequenceName}_3840x2176_50fps_10bit_420p')
        target_name = path.basename(img_path)
        return path.join(target_folder, target_name)

    def _loadOriginalFrame(self, sequence, frame_no):
        # Shared helper for the subclasses: read a single frame of the
        # pristine original sequence at target scale (always 10-bit 4:2:0).
        original_path = getOriginalSequencePathFromSequenceName(sequence)
        return yuvio.mimread(original_path, self.width * self.scale,
                             self.heigth * self.scale, 'yuv420p10le',
                             index=frame_no, count=1)[0]

    def _buildSample(self, input_path, input, target):
        # Shared helper: numpy-ify both frames, normalize the input only,
        # and attach the sample metadata dict.
        input = self.toNumpy(input)
        target = self.toNumpy(target)
        input = self.prepare(input)
        target = yuv_to_tensor(target)
        return (input, target, {
            "sequence_name": self.getSequenceName(input_path),
            "qp": self.getQP(input_path),
            "frame": path.basename(input_path).replace(".yuv", "")
        })

    def __getitem__(self, index):
        input_path = self.low_res_filenames[index]
        input = self.load(input_path, 1)
        target_path = self.getTargetPath(input_path)
        target = self.load(target_path, self.scale)
        return self._buildSample(input_path, input, target)

    def __len__(self):
        return len(self.low_res_filenames)


class TestSequenciesEVCIntra(TestSequencies):
    """EVC intra-coded HD test set against prepared 4K targets."""

    def __init__(self, yOnly=True):
        super(TestSequenciesEVCIntra, self).__init__(
            '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/low_res_420',
            '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/target',
            2, 1920, 1088, 'yuv420p10le', yOnly
        )


class TestSequenciesEVCHD4K(TestSequencies):
    """EVC HD->4K test set; targets are read from the pristine originals."""

    def __init__(self, yOnly=True):
        super(TestSequenciesEVCHD4K, self).__init__(
            '/trinity/home/mangelini/data/datasets/MPAI-TestSets/HR/EVC_420',
            '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/target',
            2, 1920, 1088, 'yuv420p10le', yOnly
        )

    def __getitem__(self, index):
        input_path = self.low_res_filenames[index]
        input = self.load(input_path, 1)
        # File name pattern: frame<no>_<sequence>_... .yuv
        basename = input_path.split("/")[-1]
        frame_no = int(basename.split("_")[0].replace("frame", ""))
        sequence = basename.split("_")[1]
        target = self._loadOriginalFrame(sequence, frame_no)
        return self._buildSample(input_path, input, target)


class TestSequenciesVVCHD4K(TestSequencies):
    """VVC HD->4K test set; sequence name comes from the grandparent folder."""

    def __init__(self, yOnly=True):
        super(TestSequenciesVVCHD4K, self).__init__(
            '/trinity/home/mangelini/data/datasets/MPAI-TestSets/HR/VVC',
            '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/target',
            2, 1920, 1080, 'gray10le', yOnly
        )

    def __getitem__(self, index):
        input_path = self.low_res_filenames[index]
        input = self.load(input_path, 1)
        # The sequence name lives two folders up; the frame number is in the
        # file name itself (frame<no>_...).
        basename = input_path.split("/")[-3]
        frame_no = int(input_path.split("/")[-1].split("_")[0].replace("frame", ""))
        sequence = basename.split("_")[1]
        target = self._loadOriginalFrame(sequence, frame_no)
        return self._buildSample(input_path, input, target)


class TestSequenciesEVCSDHD(TestSequencies):
    """EVC SD->HD test set; targets are read from the pristine originals."""

    def __init__(self, yOnly=True):
        super(TestSequenciesEVCSDHD, self).__init__(
            '/trinity/home/mangelini/data/datasets/MPAI-TestSets/SR/EVC_420',
            '', 2, 1920 // 2, 1080 // 2, 'yuv420p10le', yOnly
        )

    def __getitem__(self, index):
        input_path = self.low_res_filenames[index]
        input = self.load(input_path, 1)
        # File name pattern: frame<no>_<sequence>_... .yuv
        basename = input_path.split("/")[-1]
        frame_no = int(basename.split("_")[0].replace("frame", ""))
        sequence = basename.split("_")[1]
        target = self._loadOriginalFrame(sequence, frame_no)
        return self._buildSample(input_path, input, target)


class TestSequenciesVVCSDHD(TestSequencies):
    """VVC SD->HD test set; targets are read from the pristine originals."""

    def __init__(self, yOnly=True):
        super(TestSequenciesVVCSDHD, self).__init__(
            '/trinity/home/mangelini/data/datasets/MPAI-TestSets/SR/VVC',
            '', 2, 1920 // 2, 1080 // 2, 'yuv420p10le', yOnly
        )

    def __getitem__(self, index):
        input_path = self.low_res_filenames[index]
        input = self.load(input_path, 1)
        # File name pattern: frame<no>_<sequence>_... .yuv
        basename = input_path.split("/")[-1]
        frame_no = int(basename.split("_")[0].replace("frame", ""))
        sequence = basename.split("_")[1]
        target = self._loadOriginalFrame(sequence, frame_no)
        return self._buildSample(input_path, input, target)