Robustness.py 3.23 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from utils import *
from NNW import Uchi_tools, Adi_tools
from Attacks import *

def Modifications(Modification_ID,net,parameters):
    '''
    Apply a robustness modification (attack) to a network, selected by ID.

    :param Modification_ID: which modification to apply:
        0 = additive noise, 1 = L1 unstructured pruning,
        2 = random unstructured pruning, 3 = quantization,
        4 = fine-tuning, 5 = knowledge distillation, 6 = overwriting
    :param net: network to be altered
    :param parameters: dict holding the keys required by the chosen
        modification (e.g. "S" for noise strength, "P"/"R" for pruning
        rates, "B" for bit width, "E" for epochs, ...)
    :return: the altered network (returned unchanged for unknown IDs)
    '''
    if Modification_ID == 0:
        # "all" means perturb the whole model at once; otherwise the
        # parameters carry an explicit list of module names to perturb.
        if parameters["name"] == "all":
            return adding_noise_global(net, parameters["S"])
        perturbed = net
        for module_name in parameters["name"]:
            perturbed = adding_noise(perturbed, parameters["S"], module_name)
        return perturbed
    if Modification_ID == 1:
        return prune_model_l1_unstructured(net, parameters["P"])
    if Modification_ID == 2:
        return prune_model_random_unstructured(net, parameters["R"])
    if Modification_ID == 3:
        return quantization(net, parameters["B"])
    if Modification_ID == 4:
        return finetuning(net, parameters["E"])
    if Modification_ID == 5:
        return knowledge_distillation(
            net, parameters["E"], parameters["trainloader"], parameters["student"]
        )
    if Modification_ID == 6:
        return overwriting(
            net, parameters["NNWmethods"], parameters["W"], parameters["watermarking_dict"]
        )
    # Unknown ID: report and hand the network back untouched (best-effort).
    print("NotImplemented")
    return net



if __name__ == '__main__':
    ###### Reproducibility: fix the RNG seeds so noise/pruning runs repeat.
    torch.manual_seed(0)
    np.random.seed(0)


    # Build a VGG16 with a 10-class head (CIFAR-10 sized output) and move it
    # to `device` (provided by the star-import from utils — TODO confirm).
    model = tv.models.vgg16()
    model.classifier = nn.Linear(25088, 10)
    model.to(device)
    # watermarking section (change here to test another method) #######################################
    # Uchida-style watermarking toolbox and its saved embedding metadata.
    tools = Uchi_tools()
    reload = 'Resources/vgg16_Uchi'
    watermarking_dict = np.load(reload+'_watermarking_dict.npy', allow_pickle=True).item()
    # watermarking section (END change here to test another method) ###################################
    name = '_quantization'

    time_detect=[]

    # take model: restore the watermarked weights from the saved checkpoint.
    checkpoint = torch.load(reload + "_weights", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state_dict"])

    # Select ONE modification to test by uncommenting its M_ID/param pair.
    # Active setting: ID 0 = global additive noise with strength S=0.1.
    M_ID = 0
    param = {"name": "all","S":.1}
    # M_ID = 1
    # param = {"P": .5}
    # M_ID = 2
    # param = {"R": .2}
    # M_ID = 3
    # param = {"B": 5}
    # M_ID = 4
    # trainset, testset, _ = CIFAR10_dataset()
    # trainloader, testloader = dataloader(trainset, testset, 100)
    # param = {"E": 5, "trainloader": trainloader}
    # M_ID = 5
    # trainset, testset, inference_transform = CIFAR10_dataset()
    # trainloader, testloader = dataloader(trainset, testset, 128)
    # student = tv.models.vgg16()
    # student.classifier = nn.Linear(25088, 10)
    # param = {"E":5,"trainloader":trainloader,"student":student}
    # M_ID = 6
    # param = {"NNWmethods":tools,"W":2,"watermarking_dict":watermarking_dict}

    # Apply the selected attack to the watermarked model.
    model = Modifications(M_ID,model,param)

    model.eval()

    # Check whether the watermark survived the modification.
    # watermark,retrieve_res=tools.Decoder(model, watermarking_dict)
    # print('Modification %s - %s - Percentage of erred bits : %2f ' % (str(M_ID),str(param), retrieve_res))
    retrieve, decision=tools.Detector(model, watermarking_dict)
    print('Modification %s - %s - Presence of the watermark : %s' % (str(M_ID),str(param), decision))

    # val_score= fulltest(new_model, testloader)
    # print('Validation error : %.2f' % val_score)