Robustness.py 3.23 KB
Newer Older
Carldst's avatar
Carldst committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from utils import *
from NNW import Uchi_tools, Adi_tools
from Attacks import *

def Modifications(Modification_ID,net,parameters):
    '''
    Apply a modification based on the ID and parameters
    :param Modification_ID: ID of the modification
    :param net: network to be altered
    :param parameters: parameters of the modification
    :return: altered NN
    '''
    if Modification_ID==0:
        if parameters["name"]=="all":
            return adding_noise_global(net,parameters["S"])
        for module in parameters["name"]:
            net=adding_noise(net,parameters["S"],module)
        return net
    elif Modification_ID==1:
        return prune_model_l1_unstructured(net, parameters["P"])
    elif Modification_ID==2:
        return prune_model_random_unstructured(net,parameters["R"])
    elif Modification_ID==3:
        return quantization(net,parameters["B"])
    elif Modification_ID==4:
        return finetuning(net,parameters["E"])
    elif Modification_ID==5:
        return knowledge_distillation(net,parameters["E"],parameters["trainloader"],parameters["student"])
    elif Modification_ID==6:
        return overwriting(net, parameters["NNWmethods"], parameters["W"], parameters["watermarking_dict"])
    else:
        print("NotImplemented")
        return net



if __name__ == '__main__':
    # Robustness check: load a watermarked VGG16 checkpoint, apply ONE
    # modification selected by M_ID / param (see Modifications), then ask the
    # watermarking tool whether the watermark is still detected.

    ###### Reproducibility
    torch.manual_seed(0)
    np.random.seed(0)

    model = tv.models.vgg16()
    # Replace the VGG16 classifier with a single linear layer with 10 outputs;
    # the architecture must match the checkpoint loaded below (strict load).
    model.classifier = nn.Linear(25088, 10)
    model.to(device)

    # watermarking section (change here to test another method) #######################################
    tools = Uchi_tools()
    reload = 'Resources/vgg16_Uchi'
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # trusted, locally generated files here.
    watermarking_dict = np.load(reload + '_watermarking_dict.npy', allow_pickle=True).item()
    # watermarking section (END change here to test another method) ###################################

    # take model: the checkpoint is mapped to CPU first; load_state_dict then
    # copies the tensors into the parameters already placed on `device`.
    checkpoint = torch.load(reload + "_weights", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state_dict"])

    # Select the modification: keep exactly one M_ID / param pair uncommented.
    # 0: adding_noise_global / adding_noise
    M_ID = 0
    param = {"name": "all", "S": .1}
    # 1: prune_model_l1_unstructured
    # M_ID = 1
    # param = {"P": .5}
    # 2: prune_model_random_unstructured
    # M_ID = 2
    # param = {"R": .2}
    # 3: quantization
    # M_ID = 3
    # param = {"B": 5}
    # 4: finetuning
    # M_ID = 4
    # trainset, testset, _ = CIFAR10_dataset()
    # trainloader, testloader = dataloader(trainset, testset, 100)
    # param = {"E": 5, "trainloader": trainloader}
    # 5: knowledge_distillation
    # M_ID = 5
    # trainset, testset, inference_transform = CIFAR10_dataset()
    # trainloader, testloader = dataloader(trainset, testset, 128)
    # student = tv.models.vgg16()
    # student.classifier = nn.Linear(25088, 10)
    # param = {"E":5,"trainloader":trainloader,"student":student}
    # 6: overwriting
    # M_ID = 6
    # param = {"NNWmethods":tools,"W":2,"watermarking_dict":watermarking_dict}

    model = Modifications(M_ID, model, param)
    model.eval()

    # watermark,retrieve_res=tools.Decoder(model, watermarking_dict)
    # print('Modification %s - %s - Percentage of erred bits : %.2f ' % (str(M_ID),str(param), retrieve_res))
    # Only the detection decision is reported; the first returned value is unused here.
    _, decision = tools.Detector(model, watermarking_dict)
    print('Modification %s - %s - Presence of the watermark : %s' % (str(M_ID), str(param), decision))

    # Requires `testloader` (created in the M_ID 4 / 5 setups above).
    # val_score = fulltest(model, testloader)
    # print('Validation error : %.2f' % val_score)