Commit 891d58a1 authored by valentini's avatar valentini
Browse files

Carica un nuovo file

parent d0662c89
# Prepare the dataset for the training phase
import torch.utils.data as data
import torchvision.transforms as T
import yuvio
import glob
import glob2
import torch
import logging
import random
import numpy as np
import cv2
import re
import os.path as path
from utils.transforms import paralellCrop
from PIL import Image
# Base directory under which the original sequence files are looked up.
ORIGNAL_BASE_PATH = '/trinity/home/mangelini/data/datasets/OrignialRaw/Original_Sequencies'


def getOriginalSequencePathFromSequenceName(name):
    """Return the path of the original ``.yuv`` file for sequence *name*.

    The file name is looked up in a fixed table and appended to
    ``ORIGNAL_BASE_PATH``.

    Raises:
        KeyError: if *name* is not one of the sequences listed below.
    """
    filename_by_sequence = dict(
        BasketballDrive='BasketballDrive_1920x1080_50.yuv',
        FoodMarket4='FoodMarket4_3840x2160_60fps_10bit_420.yuv',
        BQTerrace='BQTerrace_1920x1080_60.yuv',
        MarketPlace='MarketPlace_1920x1080_60fps_10bit_420.yuv',
        Cactus='Cactus_1920x1080_50.yuv',
        ParkRunning3='ParkRunning3_3840x2160_50fps_10bit_420.yuv',
        Campfire='Campfire_3840x2160_30fps_10bit_bt709_420_videoRange.yuv',
        RitualDance='RitualDance_1920x1080_60fps_10bit_420.yuv',
        CatRobot='CatRobot_3840x2160_60fps_10bit_420_jvet.yuv',
        Tango2='Tango2_3840x2160_60fps_10bit_420.yuv',
        DaylightRoad2='DaylightRoad2_3840x2160_60fps_10bit_420.yuv',
    )
    filename = filename_by_sequence[name]
    return '/'.join((ORIGNAL_BASE_PATH, filename))
def getDepthFromFormat(format):
    """Return the bits-per-sample of a pixel-format name, or 0 if unknown.

    Args:
        format: pixel-format name such as ``'yuv420p10le'`` or ``'gray'``.

    Returns:
        One of 8, 9, 10, 12, 14 or 16 for the formats listed below,
        otherwise 0.
    """
    # Bit depth -> every format name that uses that depth.
    formats_by_depth = {
        8: (
            'gray', 'nv12', 'yuv420p', 'yuv422p', 'yuv444p',
            'yuyv422', 'uyvy422', 'yvyu422',
        ),
        9: (
            'gray9le', 'gray9be',
            'yuv420p9le', 'yuv420p9be',
            'yuv422p9le', 'yuv422p9be',
            'yuv444p9le', 'yuv444p9be',
        ),
        10: (
            'gray10le', 'gray10be', 'v210',
            'yuv420p10le', 'yuv420p10be',
            'yuv422p10le', 'yuv422p10be',
            'yuv444p10le', 'yuv444p10be',
        ),
        12: (
            'gray12le', 'gray12be',
            'yuv420p12le', 'yuv420p12be',
            'yuv422p12le', 'yuv422p12be',
            'yuv444p12le', 'yuv444p12be',
        ),
        14: (
            'gray14le', 'gray14be',
            'yuv420p14le', 'yuv420p14be',
            'yuv422p14le', 'yuv422p14be',
            'yuv444p14le', 'yuv444p14be',
        ),
        16: (
            'gray16le', 'gray16be',
            'yuv420p16le', 'yuv420p16be',
            'yuv422p16le', 'yuv422p16be',
            'yuv444p16le', 'yuv444p16be',
        ),
    }
    for depth, names in formats_by_depth.items():
        if format in names:
            return depth
    return 0
def extract_max_from_bitsize(bitsize):
    """Return the maximum unsigned value of the container holding *bitsize* bits.

    The container width is the smallest power of two, starting at 2, that is
    not smaller than *bitsize* (e.g. 10 bits -> 16-bit container -> 65535).
    """
    container_bits = 2
    # Keep doubling until the container is wide enough.
    while container_bits < bitsize:
        container_bits += container_bits
    return (1 << container_bits) - 1
def is_image_file(filename):
    """Tell whether *filename* ends in ``.png``, ``.jpg`` or ``.jpeg`` (case-sensitive)."""
    image_suffixes = (".png", ".jpg", ".jpeg")
    return filename.endswith(image_suffixes)
def load_img(filepath):
    """Open the image at *filepath*, convert it to YCbCr and return its first band (Y)."""
    ycbcr_image = Image.open(filepath).convert('YCbCr')
    bands = ycbcr_image.split()
    return bands[0]
def frame_to_tensor(input_image):
    """Convert a 3-D array-like frame into a float32 tensor with axes 0 and 2 swapped.

    Args:
        input_image: anything ``np.asarray`` accepts that yields a 3-D array.

    Returns:
        A ``torch.float32`` tensor whose element ``[k, j, i]`` equals
        ``input_image[i, j, k]``.
    """
    # ``np.float`` (an alias of the builtin ``float``) was removed in
    # NumPy 1.24; ``np.float64`` is the dtype it resolved to.  ``astype``
    # already returns a fresh array, so the input is never aliased.
    img_array = np.asarray(input_image).astype(np.float64)
    img_tensor = torch.tensor(img_array, dtype=torch.float32).permute((2, 1, 0))
    return img_tensor
def yuv_to_tensor(input_image):
    """Convert a 3-D array-like into a float32 tensor with its last two axes swapped.

    Args:
        input_image: anything ``np.asarray`` accepts that yields a 3-D array.

    Returns:
        A ``torch.float32`` tensor whose element ``[c, j, i]`` equals
        ``input_image[c, i, j]``.
    """
    # ``np.float`` (an alias of the builtin ``float``) was removed in
    # NumPy 1.24; ``np.float64`` is the dtype it resolved to.  ``astype``
    # already returns a fresh array, so the input is never aliased.
    img_array = np.asarray(input_image).astype(np.float64)
    img_tensor = torch.tensor(img_array, dtype=torch.float32).permute((0, 2, 1))
    return img_tensor
# support only 444
def load_yuv_img(path, bitsFormart='uint8', w=1920, h=1080):
    """Read the whole raw file at *path* and return it reshaped to ``(h, w, 3)``.

    Args:
        path: file to read with ``np.fromfile``.
        bitsFormart: NumPy dtype of the stored samples.
        w: frame width in samples.
        h: frame height in samples.

    Raises:
        ValueError: if the file does not hold exactly ``h * w * 3`` samples.
    """
    samples = np.fromfile(path, dtype=bitsFormart)
    frame = samples.reshape(h, w, 3)
    return frame
def getContainerType(depth):
    """Return the unsigned-integer type name for *depth* bits, e.g. ``'uint16'``."""
    return f'uint{depth}'
class DatasetFromFolderYUV(data.Dataset):
    """Paired low-res / high-res dataset of single-frame ``.yuv`` files.

    Every ``.yuv`` file found recursively under ``low_res_folder`` is one
    input sample.  The matching target path is derived by rewriting the
    input path (see :meth:`getTargetPath`).  Inputs are read as 128x128
    frames and targets as ``128*scale`` x ``128*scale`` frames.
    """

    def __init__(self, low_res_folder, high_res_folder, size, scale,
                 yuvFormat='yuv444p10le', take_channels=3, transform=None, limit=None, return_names=False, filter_only=None, y_only=False,
                 prefix_low_res=None, low_res_filter=None, prefix_high_res=None, high_res_filter=None, normalize=False, random_frame=False, seed=888):
        """Index the input files.

        Args:
            low_res_folder: root folder scanned recursively for ``*.yuv`` inputs.
            high_res_folder: root folder substituted for ``low_res_folder``
                when building target paths.
            size: accepted for interface compatibility; not used (the frame
                size is fixed at 128x128).
            scale: integer upscaling factor between input and target.
            yuvFormat: pixel format passed to ``yuvio``.
            take_channels: number of channels kept when ``y_only`` is set,
                clamped to the range 1..3.
            transform: optional callable applied to both input and target
                under the same random seed.
            limit: if truthy, keep only the first ``limit`` inputs.
            return_names: if true, samples also carry their file paths.
            filter_only: if truthy, keep only inputs whose path contains it.
            y_only: replicate the Y plane instead of using U and V.
            prefix_low_res / prefix_high_res: optional substring pair; when
                both are given, the first is replaced by the second in the
                target path.
            low_res_filter / high_res_filter: stored, not used by this class.
            normalize: only affects :meth:`getDataRange`.
            random_frame: stored, not used by this class.
            seed: seed of the private NumPy generator that drives the
                transform seeds.
        """
        super(DatasetFromFolderYUV, self).__init__()
        self.transform = transform
        self.return_names = return_names
        # Frame geometry is fixed; ``size`` is not consulted.
        self.width = 128
        self.heigth = 128
        self.yuvFormat = yuvFormat
        self.depth = getDepthFromFormat(self.yuvFormat)
        # Smallest listed size strictly greater than the bit depth.
        self.bucketSize = min([x for x in [8, 16, 32, 64, 128] if x > self.depth])
        self.yOnly = y_only
        # Clamp to 1..3.  The previous ``max(min(take_channels, 1), 3)``
        # had the bounds swapped and always evaluated to 3.
        self.takeChannels = max(1, min(take_channels, 3))
        self.scale = int(scale)
        self.normalize = normalize
        self.random_frame = random_frame
        self.random = np.random.RandomState(seed)
        self.low_res_folder = low_res_folder
        self.high_res_folder = high_res_folder
        self.prefix_low_res = prefix_low_res
        self.prefix_high_res = prefix_high_res
        self.low_res_filter = low_res_filter
        self.high_res_filter = high_res_filter
        self.low_res_filenames = glob.glob("{}/**/*.yuv".format(low_res_folder), recursive=True)
        if filter_only:
            self.low_res_filenames = [x for x in self.low_res_filenames if filter_only in x]
        if limit:
            self.low_res_filenames = self.low_res_filenames[:limit]

    def _getMax(self):
        """Return the largest value representable at the format's bit depth."""
        return (2 ** self.depth) - 1

    def _removeBroken(self):
        """Return the inputs for which both the input and its target can be read.

        The target is resolved with :meth:`getTargetPath`, i.e. exactly the
        path ``__getitem__`` would load.
        """
        not_broken = []
        for input_path in self.low_res_filenames:
            try:
                self.getYUV(input_path, 1)
                self.getYUV(self.getTargetPath(input_path), self.scale)
            except Exception:
                logging.warning(f'Skipping ({input_path}) impossiple to find HD counterpart of broken data')
            else:
                # ``+=`` with a string would add one list item per character.
                not_broken.append(input_path)
        return not_broken

    def getDataRange(self):
        """Return 1.0 when ``normalize`` is set, else the bit-depth maximum.

        NOTE(review): ``__getitem__`` always scales samples by 255/1023 and
        does not read ``normalize``; confirm this range is what callers expect.
        """
        if self.normalize:
            return 1.0
        return self._getMax()

    def removeClasses(self, text):
        """Rewrite the leading class letter of a ``bvi-dvc`` path.

        For paths containing ``bvi-dvc`` the first ``B`` (when the path
        mentions ``1920x1088``) or the first ``C`` (otherwise) of the last
        two path components is replaced by ``A`` / ``B`` respectively.
        Other paths are returned unchanged.
        """
        if 'bvi-dvc' not in text:
            return text
        sliced = text.split('/')
        old, new = ("B", "A") if '1920x1088' in text else ("C", "B")
        sliced[-1] = sliced[-1].replace(old, new, 1)
        sliced[-2] = sliced[-2].replace(old, new, 1)
        return '/'.join(sliced)

    def swapFpsDepth(self, text):
        """Swap two ``_``-separated fields in the last two components of an ``ultravideo`` path.

        In the file name the second- and third-from-last fields are swapped;
        in the parent folder name the last two fields are swapped.  Paths
        without ``ultravideo`` or containing ``3840x2176`` are returned as is.
        """
        if 'ultravideo' not in text or "3840x2176" in text:
            return text
        sliced = text.split('/')
        name_fields = sliced[-1].split("_")
        name_fields[-2], name_fields[-3] = name_fields[-3], name_fields[-2]
        sliced[-1] = "_".join(name_fields)
        folder_fields = sliced[-2].split("_")
        folder_fields[-1], folder_fields[-2] = folder_fields[-2], folder_fields[-1]
        sliced[-2] = "_".join(folder_fields)
        return '/'.join(sliced)

    def getYUV(self, path, scale):
        """Read one frame and return its planes stacked along a new first axis.

        Args:
            path: file to read.
            scale: multiplier applied to the base 128x128 frame size.

        Returns:
            ``[Y, U, V]`` stacked, or - when ``y_only`` is set - ``[Y, Y, Y]``
            truncated to ``takeChannels`` planes.
        """
        frame = yuvio.imread(path, self.width * scale, self.heigth * scale, self.yuvFormat)
        Y, U, V = frame.split()
        if self.yOnly:
            return np.stack([Y, Y, Y])[0:self.takeChannels]
        return np.stack([Y, U, V])

    def getTargetPath(self, in_path):
        """Derive the high-resolution counterpart path of *in_path*.

        The prefix pair (when configured) and the root folder are swapped,
        ``-qp22`` is dropped and the resolution tag is upgraded.  If the
        result does not exist on disk, :meth:`removeClasses` and then
        :meth:`swapFpsDepth` are tried.
        """
        target_path = in_path
        # Both prefixes default to None, which ``str.replace`` rejects.
        if self.prefix_low_res is not None and self.prefix_high_res is not None:
            target_path = target_path.replace(self.prefix_low_res, self.prefix_high_res)
        target_path = target_path.replace(self.low_res_folder, self.high_res_folder).replace('-qp22', '')
        # Order matters: upgrading SD first would let the HD rule fire on it too.
        target_path = target_path.replace('1920x1088', '3840x2176')  # HD-4K
        target_path = target_path.replace('960x544', '1920x1088')  # SD-HD
        if not path.exists(target_path):
            target_path = self.removeClasses(target_path)
        if not path.exists(target_path):
            target_path = self.swapFpsDepth(target_path)
        return target_path

    def __getitem__(self, index):
        """Load, scale and optionally transform the sample at *index*.

        Returns:
            ``(input, target)`` tensors, or
            ``((input_path, input), (target_path, target))`` when
            ``return_names`` is set.
        """
        input_path = self.low_res_filenames[index]
        low_res = self.getYUV(input_path, 1)
        target_path = self.getTargetPath(input_path)
        high_res = self.getYUV(target_path, self.scale)
        # Convert yuv in pytorch shape
        low_res = yuv_to_tensor(low_res)
        high_res = yuv_to_tensor(high_res)
        # Map the 10-bit range [0, 1023] onto [0, 255]; the factor is fixed
        # and does not depend on ``self.depth``.
        low_res = low_res.mul(255 / 1023)
        high_res = high_res.mul(255 / 1023)
        if self.transform:
            seed = self.random.randint(2147483647)  # make a seed with numpy generator
            random.seed(seed)  # apply this seed to img transforms
            torch.manual_seed(seed)  # needed for torchvision 0.7
            low_res = self.transform(low_res)
            # Re-seed so the target receives the same random transform.
            random.seed(seed)
            torch.manual_seed(seed)
            high_res = self.transform(high_res)
        if self.return_names:
            return (input_path, low_res), (target_path, high_res)
        return (low_res, high_res)

    def __len__(self):
        """Return the number of indexed input files."""
        return len(self.low_res_filenames)
class DatasetFromYUVSequence(data.Dataset):
    """Paired low-res / high-res dataset built from multi-frame ``.yuv`` sequences.

    Each ``.yuv`` file under ``low_res_folder`` whose path contains ``size``
    and whose high-resolution counterpart exists is one sample.  A sample
    is a single frame (the first one, or a random one) taken from both
    sequences.
    """

    def __init__(self, low_res_folder, high_res_folder, size, scale,
                 yuvFormat='yuv444p10le', transform=None, limit=None, return_names=False, filter_only=None, y_only=False,
                 prefix_low_res=None, prefix_high_res=None, random_frame=False, crop=None, seed=888):
        """Index the input sequences and drop those without a target.

        Args:
            low_res_folder: root folder scanned for ``*.yuv`` inputs.
            high_res_folder: root folder substituted for ``low_res_folder``
                when building target paths.
            size: ``"<width>x<height>"`` of the low-res frames; also used
                as a substring filter on the input paths.
            scale: integer upscaling factor between input and target.
            yuvFormat: pixel format passed to ``yuvio``.
            transform: optional callable applied to input and target under
                the same random seed.
            limit: if truthy, keep only the first ``limit`` inputs.
            return_names: if true, samples also carry their file paths.
            filter_only: if truthy, keep only inputs whose path contains it.
            y_only: replicate the Y plane instead of using U and V.
            prefix_low_res / prefix_high_res: optional substring pair; when
                both are given, the first is replaced by the second in the
                target file name.
            random_frame: pick a random frame instead of the first one.
            crop: crop size handed to ``paralellCrop``; cropping is applied
                only when it is greater than 10.
            seed: seed of the private NumPy generator that drives the
                transform seeds.
        """
        super(DatasetFromYUVSequence, self).__init__()
        self.transform = transform
        self.return_names = return_names
        width_text, height_text = size.split("x")[0], size.split("x")[1]
        self.width = int(width_text)
        self.heigth = int(height_text)
        self.yuvFormat = yuvFormat
        self.depth = getDepthFromFormat(self.yuvFormat)
        # Smallest listed size strictly greater than the bit depth.
        self.bucketSize = min([x for x in [8, 16, 32, 64, 128] if x > self.depth])
        self.yOnly = y_only
        self.scale = int(scale)
        self.random_frame = random_frame
        # -1 stands for "no crop requested".
        self.crop = int(crop) if crop else -1
        self.random = np.random.RandomState(seed=seed)
        self.low_res_folder = low_res_folder
        self.high_res_folder = high_res_folder
        self.prefix_low_res = prefix_low_res
        self.prefix_high_res = prefix_high_res
        self.low_res_filenames = glob2.glob("{}/**/*.yuv".format(low_res_folder))
        if filter_only:
            self.low_res_filenames = [x for x in self.low_res_filenames if filter_only in x]
        self.low_res_filenames = [x for x in self.low_res_filenames if size in x]
        self.low_res_filenames = self.removeBrokenPairs()
        if limit:
            self.low_res_filenames = self.low_res_filenames[:limit]

    def _getMax(self):
        """Return the largest value representable at the format's bit depth."""
        return (2 ** self.depth) - 1

    def getDataRange(self):
        """Return the value range of the produced tensors (always 255)."""
        return 255

    def loadSequence(self, path_img, scale):
        """Read every frame of *path_img* at ``scale`` times the base frame size."""
        return yuvio.mimread(path_img, self.width * scale, self.heigth * scale, self.yuvFormat)

    def toNumpy(self, yuvIOimg):
        """Stack the planes of a frame as ``[Y, U, V]`` (or ``[Y, Y, Y]`` when ``y_only``)."""
        Y, U, V = yuvIOimg.split()
        if self.yOnly:
            return np.stack([Y, Y, Y])
        return np.stack([Y, U, V])

    def prepare(self, data, seed):
        """Turn stacked planes into a tensor, apply the transform under *seed*, then scale."""
        data = yuv_to_tensor(data)
        if self.transform:
            random.seed(seed)  # apply this seed to img transforms
            torch.manual_seed(seed)  # needed for torchvision 0.7
            data = self.transform(data)
        # Normalize for transfer learning: maps [0, 1023] onto [0, 255].
        data = data.mul(255 / 1023)
        return data

    def removeBrokenPairs(self):
        """Return the inputs whose high-resolution counterpart exists on disk."""
        new_low_res = []
        for low_res in self.low_res_filenames:
            high_res_path = self.getHighResPath(low_res)
            if path.exists(high_res_path):
                new_low_res.append(low_res)
            else:
                logging.warning(f'Could not fint the HD counterpart for {low_res} at -> {high_res_path}')
        return new_low_res

    def sequenceSepcificIssuesFix(self, path, file):
        """Adapt a low-res file name to the naming used by its dataset's targets.

        Args:
            path: folder path, used only to detect the dataset family
                (``/bvi`` or ``/ultravideo``).
            file: file name to rewrite.

        Returns:
            For ``/bvi`` paths, *file* with its first ``B`` replaced by ``A``
            (when it mentions ``1920x1088``) or its first ``C`` replaced by
            ``B``.  For ``/ultravideo`` paths, *file* with its last two
            ``_``-separated fields swapped.  Otherwise *file* unchanged.
        """
        if path.count('/bvi') > 0:
            if '1920x1088' in file:
                return file.replace("B", "A", 1)
            return file.replace("C", "B", 1)
        if path.count('/ultravideo') > 0:
            file_segment = file.replace(".yuv", "").split("_")
            file_segment[-1], file_segment[-2] = file_segment[-2], file_segment[-1]
            new_filename = "_".join(file_segment)
            return f"{new_filename}.yuv"
        return file

    def getHighResPath(self, img_path):
        """Derive the high-resolution counterpart path of *img_path*."""
        target_folder = path.dirname(img_path).replace(self.low_res_folder, self.high_res_folder)
        target_name = path.basename(img_path)
        target_name = re.sub(r"-qp[0-9][0-9]", "", target_name)  # Remove the -qp## from name
        target_name = self.sequenceSepcificIssuesFix(target_folder, target_name)
        # Both prefixes default to None, which ``str.replace`` rejects.
        if self.prefix_low_res is not None and self.prefix_high_res is not None:
            target_name = target_name.replace(self.prefix_low_res, self.prefix_high_res)
        # Order matters: upgrading SD first would let the HD rule fire on it too.
        target_name = target_name.replace('1920x1088', '3840x2176')  # HD-4K
        target_name = target_name.replace('960x544', '1920x1088')  # SD-HD
        return path.join(target_folder, target_name)

    def __getitem__(self, index):
        """Load one frame pair, optionally crop it, and convert it to tensors.

        Returns:
            ``(input, target)`` tensors, or
            ``((input_path, input), (target_path, target))`` when
            ``return_names`` is set.
        """
        input_path = self.low_res_filenames[index]
        inputs = self.loadSequence(input_path, 1)
        target_path = self.getHighResPath(input_path)
        targets = self.loadSequence(target_path, self.scale)
        random_frame_index = 0
        if self.random_frame:
            # Only indices valid in both sequences are eligible.  The former
            # ``random.random * frame_count`` multiplied the function object
            # itself and could never produce an index.
            frame_count = min(len(targets), len(inputs))
            random_frame_index = random.randrange(frame_count)
        # Extract frame
        low_res = self.toNumpy(inputs[random_frame_index])
        high_res = self.toNumpy(targets[random_frame_index])
        if self.crop and self.crop > 10:
            low_res, high_res, _ = paralellCrop(low_res, high_res, crop=self.crop, scale=self.scale, random=random)
        # One seed for both calls so input and target get the same transform.
        seed = self.random.randint(1000000)
        low_res = self.prepare(low_res, seed)
        high_res = self.prepare(high_res, seed)
        if self.return_names:
            return (input_path, low_res), (target_path, high_res)
        return (low_res, high_res)

    def __len__(self):
        """Return the number of usable input sequences."""
        return len(self.low_res_filenames)
class TestSequencies(data.Dataset):
    """Evaluation dataset of single-frame ``.yuv`` files with per-sample metadata.

    Every ``.yuv`` file under ``lowres_path`` is one input; samples are
    served in sorted path order.  Each item is ``(input, target, info)``
    where ``info`` holds the sequence name, the QP and the frame name.
    """

    loaded = {}

    def __init__(self, lowres_path, target_path, scale, width, heigth, yuvFormat, yOnly=True):
        """Store the configuration and index the inputs.

        Args:
            lowres_path: root folder scanned for ``*.yuv`` inputs.
            target_path: root folder holding the target sequences.
            scale: upscaling factor between input and target.
            width: input frame width.
            heigth: input frame height.
            yuvFormat: pixel format passed to ``yuvio``.
            yOnly: replicate the Y plane instead of using U and V.
        """
        super(TestSequencies, self).__init__()
        self.yOnly = yOnly
        self.scale = scale
        self.width = width
        self.heigth = heigth
        self.yuvFormat = yuvFormat
        self.data_path = lowres_path
        self.target_path = target_path
        self.low_res_filenames = glob2.glob(f'{self.data_path}/**/*.yuv')
        self.low_res_filenames.sort()

    def _getMax(self):
        """Return the maximum sample value (always 1023)."""
        return 1023

    def getDataRange(self):
        """Return the value range of the input tensors (always 255)."""
        return 255

    def load(self, path_img, scale):
        """Read a single frame at ``scale`` times the base frame size."""
        return yuvio.imread(path_img, self.width * scale, self.heigth * scale, self.yuvFormat)

    def toNumpy(self, yuvIOimg):
        """Stack the planes of a frame as ``[Y, U, V]`` (or ``[Y, Y, Y]`` when ``yOnly``)."""
        Y, U, V = yuvIOimg.split()
        if self.yOnly:
            return np.stack([Y, Y, Y])
        return np.stack([Y, U, V])

    def prepare(self, data):
        """Turn stacked planes into a tensor scaled from [0, 1023] to [0, 255]."""
        data = yuv_to_tensor(data)
        # Normalize for transfer learning
        data = data.mul(255 / 1023)
        return data

    def getSequenceName(self, img_path):
        """Return the parent folder name of *img_path* without its ``_prop_qp-NN`` part."""
        folder_name = path.dirname(img_path).split("/")[-1]
        return re.sub(r"_prop_qp-[0-9][0-9]", "", folder_name)

    def getQP(self, img_path):
        """Return the two-digit QP encoded in the parent folder name, or 0.

        ``_qpNN`` is searched first, trying NN from 10 to 99 in ascending
        order.  If nothing matches, the hyphenated ``_qp-NN`` spelling -
        the one :meth:`getSequenceName` strips - is tried the same way.
        """
        base_sequence = path.dirname(img_path).split("/")[-1]
        for template in ('_qp{}', '_qp-{}'):
            for qp in range(10, 100):
                if template.format(qp) in base_sequence:
                    return qp
        return 0

    def getTargetPath(self, img_path):
        """Return the target file path: same file name inside the sequence's target folder."""
        sequenceName = self.getSequenceName(img_path)
        target_folder = path.join(self.target_path, f'{sequenceName}_3840x2176_50fps_10bit_420p')
        target_name = path.basename(img_path)
        return path.join(target_folder, target_name)

    def __getitem__(self, index):
        """Return ``(input, target, info)`` for the sample at *index*.

        The input goes through :meth:`prepare` (scaled by 255/1023); the
        target is only converted to a tensor and keeps its stored values.
        """
        input_path = self.low_res_filenames[index]
        low_res = self.load(input_path, 1)
        target_path = self.getTargetPath(input_path)
        high_res = self.load(target_path, self.scale)
        ## to numpy
        low_res = self.toNumpy(low_res)
        high_res = self.toNumpy(high_res)
        low_res = self.prepare(low_res)
        high_res = yuv_to_tensor(high_res)
        return (low_res, high_res, {
            "sequence_name": self.getSequenceName(input_path),
            "qp": self.getQP(input_path),
            "frame": path.basename(input_path).replace(".yuv", "")
        })

    def __len__(self):
        """Return the number of indexed input files."""
        return len(self.low_res_filenames)
class TestSequenciesEVCIntra(TestSequencies):
    """``TestSequencies`` preset: 1920x1088 ``yuv420p10le`` inputs, scale 2, fixed folders."""

    def __init__(self, yOnly=True):
        """Configure the parent class with the hard-coded input and target folders."""
        low_res_root = '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/low_res_420'
        target_root = '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/target'
        super().__init__(
            lowres_path=low_res_root,
            target_path=target_root,
            scale=2,
            width=1920,
            heigth=1088,
            yuvFormat='yuv420p10le',
            yOnly=yOnly,
        )
class TestSequenciesEVCHD4K(TestSequencies):
    """``TestSequencies`` preset whose targets are frames of the original sequences.

    Inputs are 1920x1088 ``yuv420p10le`` frames; the target is the frame
    with the same number read from the original sequence file.
    """

    def __init__(self, yOnly=True):
        """Configure the parent class with the hard-coded input and target folders."""
        low_res_root = '/trinity/home/mangelini/data/datasets/MPAI-TestSets/HR/EVC_420'
        target_root = '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/target'
        super().__init__(
            lowres_path=low_res_root,
            target_path=target_root,
            scale=2,
            width=1920,
            heigth=1088,
            yuvFormat='yuv420p10le',
            yOnly=yOnly,
        )

    def __getitem__(self, index):
        """Return ``(input, target, info)`` for the sample at *index*."""
        input_path = self.low_res_filenames[index]
        low_res = self.load(input_path, 1)
        # The file name is parsed as "frame<N>_<sequence>_...".
        name_fields = input_path.split("/")[-1].split("_")
        frame_no = int(name_fields[0].replace("frame", ""))
        sequence = name_fields[1]
        original_path = getOriginalSequencePathFromSequenceName(sequence)
        # NOTE(review): the reference is read as 3840x2176 (1920*2 x 1088*2)
        # while the file names returned by the lookup mention 3840x2160 or
        # 1920x1080 - confirm the geometry matches the files on disk.
        reference_frames = yuvio.mimread(
            original_path, self.width * self.scale, self.heigth * self.scale,
            'yuv420p10le', index=frame_no, count=1)
        high_res = reference_frames[0]
        ## to numpy
        low_res = self.toNumpy(low_res)
        high_res = self.toNumpy(high_res)
        low_res = self.prepare(low_res)
        high_res = yuv_to_tensor(high_res)
        info = {
            "sequence_name": self.getSequenceName(input_path),
            "qp": self.getQP(input_path),
            "frame": path.basename(input_path).replace(".yuv", ""),
        }
        return (low_res, high_res, info)
class TestSequenciesVVCHD4K(TestSequencies):
    """``TestSequencies`` preset whose targets are frames of the original sequences.

    Inputs are 1920x1080 ``gray10le`` frames; the sequence name comes from
    the third-from-last path component and the frame number from the file
    name.
    """

    def __init__(self, yOnly=True):
        """Configure the parent class with the hard-coded input and target folders."""
        low_res_root = '/trinity/home/mangelini/data/datasets/MPAI-TestSets/HR/VVC'
        target_root = '/trinity/home/mangelini/data/datasets/encodings_3x3_twice_4x4_large_invstride_cs-64_ps-32-16-8-4_jvet-B_e-1000-686cont_ds_randomaccess/data/prepared/target'
        super().__init__(
            lowres_path=low_res_root,
            target_path=target_root,
            scale=2,
            width=1920,
            heigth=1080,
            yuvFormat='gray10le',
            yOnly=yOnly,
        )

    def __getitem__(self, index):
        """Return ``(input, target, info)`` for the sample at *index*."""
        input_path = self.low_res_filenames[index]
        low_res = self.load(input_path, 1)
        path_parts = input_path.split("/")
        # The folder two levels above the file is parsed as "<x>_<sequence>_...".
        sequence_folder = path_parts[-3]
        # The file name is parsed as "frame<N>_...".
        frame_no = int(path_parts[-1].split("_")[0].replace("frame", ""))
        sequence = sequence_folder.split("_")[1]
        original_path = getOriginalSequencePathFromSequenceName(sequence)
        reference_frames = yuvio.mimread(
            original_path, self.width * self.scale, self.heigth * self.scale,
            'yuv420p10le', index=frame_no, count=1)
        high_res = reference_frames[0]
        ## to numpy
        low_res = self.toNumpy(low_res)
        high_res = self.toNumpy(high_res)
        low_res = self.prepare(low_res)
        high_res = yuv_to_tensor(high_res)
        info = {
            "sequence_name": self.getSequenceName(input_path),
            "qp": self.getQP(input_path),
            "frame": path.basename(input_path).replace(".yuv", ""),
        }
        return (low_res, high_res, info)
class TestSequenciesEVCSDHD(TestSequencies):
    """``TestSequencies`` preset whose targets are frames of the original sequences.

    Inputs are 960x540 ``yuv420p10le`` frames; the target is the frame with
    the same number read from the original sequence file at twice that size.
    """

    def __init__(self, yOnly=True):
        """Configure the parent class with the hard-coded input folder and an empty target folder."""
        low_res_root = '/trinity/home/mangelini/data/datasets/MPAI-TestSets/SR/EVC_420'
        super().__init__(
            lowres_path=low_res_root,
            target_path='',
            scale=2,
            width=1920 // 2,
            heigth=1080 // 2,
            yuvFormat='yuv420p10le',
            yOnly=yOnly,
        )

    def __getitem__(self, index):
        """Return ``(input, target, info)`` for the sample at *index*."""
        input_path = self.low_res_filenames[index]
        low_res = self.load(input_path, 1)
        # The file name is parsed as "frame<N>_<sequence>_...".
        name_fields = input_path.split("/")[-1].split("_")
        frame_no = int(name_fields[0].replace("frame", ""))
        sequence = name_fields[1]
        original_path = getOriginalSequencePathFromSequenceName(sequence)
        reference_frames = yuvio.mimread(
            original_path, self.width * self.scale, self.heigth * self.scale,
            'yuv420p10le', index=frame_no, count=1)
        high_res = reference_frames[0]
        ## to numpy
        low_res = self.toNumpy(low_res)
        high_res = self.toNumpy(high_res)
        low_res = self.prepare(low_res)
        high_res = yuv_to_tensor(high_res)
        info = {
            "sequence_name": self.getSequenceName(input_path),
            "qp": self.getQP(input_path),
            "frame": path.basename(input_path).replace(".yuv", ""),
        }
        return (low_res, high_res, info)
class TestSequenciesVVCSDHD(TestSequencies):
    """``TestSequencies`` preset whose targets are frames of the original sequences.

    Inputs are 960x540 ``yuv420p10le`` frames; the target is the frame with
    the same number read from the original sequence file at twice that size.
    """

    def __init__(self, yOnly=True):
        """Configure the parent class with the hard-coded input folder and an empty target folder."""
        low_res_root = '/trinity/home/mangelini/data/datasets/MPAI-TestSets/SR/VVC'
        super().__init__(
            lowres_path=low_res_root,
            target_path='',
            scale=2,
            width=1920 // 2,
            heigth=1080 // 2,
            yuvFormat='yuv420p10le',
            yOnly=yOnly,
        )

    def __getitem__(self, index):
        """Return ``(input, target, info)`` for the sample at *index*."""
        input_path = self.low_res_filenames[index]
        low_res = self.load(input_path, 1)
        # The file name is parsed as "frame<N>_<sequence>_...".
        name_fields = input_path.split("/")[-1].split("_")
        frame_no = int(name_fields[0].replace("frame", ""))
        sequence = name_fields[1]
        original_path = getOriginalSequencePathFromSequenceName(sequence)
        reference_frames = yuvio.mimread(
            original_path, self.width * self.scale, self.heigth * self.scale,
            'yuv420p10le', index=frame_no, count=1)
        high_res = reference_frames[0]
        ## to numpy
        low_res = self.toNumpy(low_res)
        high_res = self.toNumpy(high_res)
        low_res = self.prepare(low_res)
        high_res = yuv_to_tensor(high_res)
        info = {
            "sequence_name": self.getSequenceName(input_path),
            "qp": self.getQP(input_path),
            "frame": path.basename(input_path).replace(".yuv", ""),
        }
        return (low_res, high_res, info)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment