# knowledge_distillation.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv

from utils import *  # expected to provide device, CIFAR10_dataset and dataloader

def distill_unlabeled(y, teacher_scores, T):
    # Hinton-style soft-target loss: KL divergence between the temperature-softened
    # student and teacher distributions, scaled by T*T so gradient magnitudes stay
    # comparable to the hard-label loss. reduction='batchmean' matches the
    # mathematical definition of KL divergence.
    return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y / T, dim=1),
                                               F.softmax(teacher_scores / T, dim=1)) * T * T
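
# Usage sketch (assumption): distill_unlabeled is not called elsewhere in this file.
# A typical training step mixes the soft-target loss with a hard-label term; the
# helper name, alpha and T below are illustrative choices, not original code.
def distillation_step_example(student_out, teacher_out, labels, T=4.0, alpha=0.9):
    soft_loss = distill_unlabeled(student_out, teacher_out, T)   # softened KL term
    hard_loss = nn.CrossEntropyLoss()(student_out, labels)       # ground-truth term
    return alpha * soft_loss + (1 - alpha) * hard_loss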

def test_knowledge_dist(net, water_loss, file_weights, file_watermark, dataset='CIFAR10'):
    # Trains a fresh VGG16 student on the teacher's hard predictions for 10 epochs.
    # The lists below are returned but left empty in this implementation.
    epochs_list, test_list, water_test_list = [], [], []

    trainset, testset, _ = CIFAR10_dataset()

    trainloader, testloader = dataloader(trainset, testset, 100)
    student_net = tv.models.vgg16()
    student_net.classifier = nn.Linear(25088, 10)
    student_net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(student_net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    watermarking_dict = np.load(file_watermark, allow_pickle=True).item()  # watermark data; not used in the loop below
    # freeze the teacher: it only provides labels and must stay in eval mode
    net.eval()
    for param in net.parameters():
        param.requires_grad = False
    student_net.train()
    for epoch in range(10):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # split data into the image and its label
            inputs, labels = data
            if dataset == 'MNIST':
                inputs.squeeze_(1)
                inputs = torch.stack([inputs, inputs, inputs], 1)
            inputs = inputs.to(device)
            labels = labels.to(device)

            teacher_output = net(inputs)
            teacher_output = teacher_output.detach()
            # the teacher's hard predictions act as training labels for the student
            _, labels_teacher = torch.max(F.log_softmax(teacher_output, dim=1), dim=1)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward pass through the student
            outputs = student_net(inputs)
            # cross-entropy against the teacher labels, then backward pass
            loss = criterion(outputs, labels_teacher)
            loss.backward()
            # update the student parameters
            optimizer.step()
            # accumulate the running loss
            running_loss += loss.item()
        print('epoch %d  running loss: %.5f' % (epoch + 1, running_loss))
    return epochs_list, test_list, water_test_list
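
# Evaluation sketch (assumption): the lists returned above are never filled in this
# version. A minimal accuracy helper such as the hypothetical one below could
# populate test_list; filling water_test_list would additionally need the trigger
# inputs stored in watermarking_dict, whose layout is not shown in this file.
def _accuracy_example(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total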

def knowledge_distillation(net, epochs, trainloader, student_net):
    # Distils the frozen teacher `net` into `student_net` using its hard predictions.
    student_net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(student_net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
    # freeze the teacher: it only provides labels and must stay in eval mode
    net.eval()
    for param in net.parameters():
        param.requires_grad = False
    student_net.train()
    for epoch in range(epochs):
        print('doing epoch', str(epoch + 1), ".....")
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # split data into the image and its label
            inputs, labels = data
            inputs = inputs.to(device)

            teacher_output = net(inputs)
            teacher_output = teacher_output.detach()
            # the teacher's hard predictions act as training labels for the student
            _, labels_teacher = torch.max(F.log_softmax(teacher_output, dim=1), dim=1)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward pass through the student
            outputs = student_net(inputs)
            # cross-entropy against the teacher labels, then backward pass
            loss = criterion(outputs, labels_teacher)
            loss.backward()
            # update the student parameters
            optimizer.step()
            # accumulate the running loss
            running_loss += loss.item()
        # average loss per batch over the epoch
        epoch_loss = running_loss / len(trainloader)
        print(' loss  : %.5f   ' % epoch_loss)


'''
Example setup for knowledge_distillation:

    M_ID = 5
    trainset, testset, inference_transform = CIFAR10_dataset()
    trainloader, testloader = dataloader(trainset, testset, 128)
    student = tv.models.vgg16()
    student.classifier = nn.Linear(25088, 10)
    param = {"E": 5, "trainloader": trainloader, "student": student}
'''
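
# Soft-label variant (sketch, assumption): the loops above train the student on the
# teacher's hard argmax labels. A Hinton-style alternative would use the soft targets
# through distill_unlabeled instead; the function name and T=4.0 below are
# illustrative choices, not original code.
def soft_distillation_step_example(net, student_net, optimizer, inputs, T=4.0):
    optimizer.zero_grad()
    with torch.no_grad():
        teacher_output = net(inputs)       # frozen teacher logits
    outputs = student_net(inputs)          # student logits
    loss = distill_unlabeled(outputs, teacher_output, T)
    loss.backward()
    optimizer.step()
    return loss.item()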