util.py 3.34 KB
Newer Older
valentini's avatar
valentini committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#
#Copyright (C) 2020-2021 ISTI-CNR
#Licensed under the BSD 3-Clause Clear License (see license.txt)
#

import os
import matplotlib.pyplot as plt
import torch
from torchvision.transforms.functional import to_tensor
from scipy.io import loadmat
from PIL import Image
import numpy as np
import math
from torchvision.utils import make_grid

#read an 8-bit image
def read_img(fname, grayscale=True):
    """Load an image file with PIL and return it as a tensor.

    Args:
        fname: path of the image file, passed to ``PIL.Image.open``.
        grayscale: when True the image is converted to mode 'L',
            otherwise to mode 'RGB'.

    Returns:
        The result of torchvision's ``to_tensor`` applied to the
        converted image.
    """
    # The context manager releases the file handle deterministically once
    # the pixel data has been converted, instead of relying on garbage
    # collection of the lazily-loaded image object.
    with Image.open(fname) as img:
        converted = img.convert('L' if grayscale else 'RGB')
    return to_tensor(converted)

#read an 8-bit/32-bit image in MATLAB format
def read_mat(fname,  grayscale=True, log_range=True):
    """Load the 'image' variable of a MATLAB file as a float tensor.

    Args:
        fname: .mat file handed to ``scipy.io.loadmat``.
        grayscale: when True and the data has three dimensions, dims 0 and 2
            are swapped, the new leading dimension is summed and divided by
            3, and a singleton leading dimension is added back.
        log_range: when True the returned values are log10(1 + image).

    Returns:
        A ``torch.FloatTensor`` built from the stored array.

    NOTE(review): a 2-D input only gains a leading singleton dimension when
    ``log_range`` is False; with ``log_range=True`` it is returned 2-D.
    Likewise 3-D data with ``grayscale=False`` keeps its stored layout.
    Confirm both are intended before relying on the output shape.
    """
    raw = loadmat(fname, verify_compressed_data_integrity=False)['image']
    tensor = torch.FloatTensor(raw)

    has_three_dims = tensor.ndimension() == 3
    if has_three_dims and grayscale:
        # swap first and last dims, then average over the new leading dim
        # (the divisor 3 assumes three channels -- TODO confirm)
        swapped = tensor.transpose(2, 0)
        tensor = (torch.sum(swapped, dim=0) / 3).unsqueeze(0)

    if log_range:
        # log10(1 + image)
        return torch.log10(tensor + 1)
    if tensor.ndimension() == 2:
        return tensor.unsqueeze(0)
    return tensor

#read an image
def load_image(fname,  grayscale=True, log_range=True):
    """Read an image, choosing the reader from the file extension.

    Files whose extension is exactly '.mat' go through ``read_mat`` (which
    also receives ``log_range``); every other file goes through
    ``read_img``, for which ``log_range`` is ignored.
    """
    extension = os.path.splitext(fname)[1]
    if extension != '.mat':
        return read_img(fname, grayscale)
    return read_mat(fname, grayscale, log_range)

#plot a graph with train, validation, and test
def plotGraph(array_train, array_val, array_test, folder):
    """Plot the three loss curves and save them as 'plot.png' in folder.

    The curves are truncated to the length of the shortest sequence and
    drawn against epoch numbers starting at 1. The figure is closed after
    saving.
    """
    figure = plt.figure(figsize=(10, 4))

    # only the epochs present in all three histories are drawn
    n_epochs = min(len(array_train), len(array_val), len(array_test))
    epochs = np.arange(1, n_epochs + 1)

    # order matters: it has to match the legend entries below
    for losses in (array_train, array_val, array_test):
        plt.plot(epochs, losses[:n_epochs])

    plt.title("model loss")
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend(['train', 'validation','test'], loc="upper left")

    out_path = os.path.join(folder, "plot.png")
    plt.savefig(out_path, dpi=600)
    plt.close(figure)


def get_luminance(output):
    """Scale the channels of a (prediction, target) pair by fixed weights.

    Args:
        output: a pair ``(y_pred, y)`` of tensors whose dimension 1 has
            size 3 (or is broadcastable against a (1, 3, 1, 1) tensor).

    Returns:
        ``(y_pred * w, y * w)`` where ``w`` holds the per-channel weights
        65.738, 129.057 and 25.064 shaped (1, 3, 1, 1). The channels are
        weighted but not summed.

    The weights presumably are the RGB-to-luma coefficients also used in
    calc_psnr -- TODO confirm.

    Both tensors are scaled exactly once and the inputs are left
    untouched. The previous implementation multiplied ``y`` by the weights
    twice and modified both inputs in place, so the target was scaled
    differently from the prediction and repeated calls compounded the
    scaling.
    """
    y_pred, y = output
    # same dtype/device as y, like the original y.new(...) allocation
    convert = y.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
    return y_pred * convert, y * convert

def quantize(img, rgb_range):
    """Round img to the nearest of the 0..255 levels, keeping its scale.

    The values are rescaled by 255 / rgb_range, clamped to [0, 255],
    rounded, and scaled back. A new tensor is returned; img is not
    modified.
    """
    levels_per_unit = 255 / rgb_range
    on_8bit_grid = img.mul(levels_per_unit).clamp(0, 255).round()
    return on_8bit_grid.div(levels_per_unit)

# From DRLN code
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False, skip_luminance=False):
    """Return the PSNR in dB between two image batches.

    Args:
        sr: predicted tensor; indexed with four dimensions, channels on
            dimension 1.
        hr: reference tensor with the same shape as sr.
        scale: number of border pixels removed on every side of the last
            two dimensions before the error is measured. 0 disables the
            crop.
        rgb_range: value range of the inputs; the difference is divided
            by it.
        benchmark: unused, kept so existing callers keep working.
        skip_luminance: when True the error is measured on all channels
            as they are. When False and there is more than one channel,
            the channels are first combined with the weights 65.738,
            129.057 and 25.064.

    Returns:
        ``-10 * log10(mse)`` as a Python float. ``float('inf')`` is
        returned when the compared regions are identical (mse == 0),
        where a plain log10 would raise ValueError.
    """
    # (sr - hr) and div() both allocate, so the in-place ops below never
    # touch the caller's tensors.
    diff = (sr - hr).detach().div(rgb_range)

    if not skip_luminance and diff.size(1) > 1:
        convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
        # NOTE(review): an earlier, disabled variant of this code divided
        # by 256 here instead of rgb_range -- confirm which is intended.
        diff.mul_(convert).div_(rgb_range)
        diff = diff.sum(dim=1, keepdim=True)

    shave = scale
    if shave:
        valid = diff[:, :, shave:-shave, shave:-shave]
    else:
        # a 0:-0 slice would be empty and make the mean NaN
        valid = diff

    mse = valid.pow(2).mean().item()
    if mse == 0:
        return float('inf')
    return -10 * math.log10(mse)