pruning.py 2.38 KB
Newer Older
Carl De Sousa Trias's avatar
Carl De Sousa Trias committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# pruning

import matplotlib.pyplot as plt
import torch.nn.utils.prune as prune
from utils import *



def prune_model_l1_unstructured(new_model, proportion):
    """Magnitude-prune every Conv2d / Linear layer of ``new_model`` in place.

    For each such layer, the ``proportion`` of ``weight`` entries with the
    smallest absolute value is zeroed (``prune.l1_unstructured``). The
    pruning reparametrization is then made permanent with ``prune.remove``,
    so the layer is left with an ordinary ``weight`` parameter and no extra
    ``weight_orig`` / ``weight_mask`` attributes. Biases are not touched.
    Prints "pruned" once per pruned layer.

    Args:
        new_model: ``torch.nn.Module`` to prune; it is modified in place.
        proportion: amount of each layer's weights to prune, forwarded as
            ``amount`` to ``prune.l1_unstructured``.

    Returns:
        The same ``new_model`` object.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for layer in new_model.modules():
        if not isinstance(layer, prunable_types):
            continue
        print("pruned")
        prune.l1_unstructured(layer, name='weight', amount=proportion)
        # Bake the mask into the weight and drop the pruning hook/buffers.
        prune.remove(layer, 'weight')
    return new_model

def random_mask(new_model, proportion):
    """Attach a random unstructured pruning mask to every Conv2d / Linear weight.

    ``prune.random_unstructured`` is applied to the ``weight`` of each such
    layer and is NOT made permanent. Each pruned layer therefore keeps a
    ``weight_orig`` parameter, a ``weight_mask`` buffer and a forward
    pre-hook that recomputes ``weight`` from them.

    Args:
        new_model: ``torch.nn.Module`` to reparametrize in place.
        proportion: amount of each layer's weights to mask, forwarded as
            ``amount`` to ``prune.random_unstructured``.

    Returns:
        ``dict(new_model.named_buffers())``. This contains the
        ``"<layer>.weight_mask"`` entries plus any other buffers the model
        already had.
    """
    # TODO: consider structured pruning so that whole kernels are removed.
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    targets = [m for m in new_model.modules() if isinstance(m, prunable_types)]
    for layer in targets:
        prune.random_unstructured(layer, name='weight', amount=proportion)

    buffers = new_model.named_buffers()
    return dict(buffers)

def prune_model_random_unstructured(new_model, proportion):
    """Randomly zero a fraction of every Conv2d / Linear weight, in place.

    Each such layer's ``weight`` gets a random unstructured mask
    (``prune.random_unstructured``). The mask is immediately baked in with
    ``prune.remove``, so the layer ends up with a plain ``weight`` parameter
    holding the masked values. That parameter is the SAME Parameter object
    as before the call, so an optimizer created earlier keeps tracking it.
    Biases are not touched. Prints "pruned" once per pruned layer.

    Args:
        new_model: ``torch.nn.Module`` to prune; it is modified in place.
        proportion: amount of each layer's weights to prune, forwarded as
            ``amount`` to ``prune.random_unstructured``.

    Returns:
        The same ``new_model`` object.
    """
    # The previous version replaced ``weight`` with a new nn.Parameter while
    # the pruning forward pre-hook stayed registered. The next forward pass,
    # or a second call, then raised TypeError because the hook assigns a plain
    # tensor to a registered parameter. The optimizer also kept updating the
    # stale ``weight_orig``. Making the pruning permanent, as
    # prune_model_l1_unstructured does, avoids both problems.
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for name, module in new_model.named_modules():
        if isinstance(module, prunable_types):
            print("pruned")
            prune.random_unstructured(module, name='weight', amount=proportion)
            prune.remove(module, 'weight')
    return new_model





def train_pruning(net, optimizer, criterion, trainloader, number_epochs, value=None, mask=None):
    """Train ``net`` while re-pruning its Conv2d / Linear weights on every batch.

    Args:
        net: model to train; it is switched to train mode and modified in place.
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        trainloader: iterable yielding ``(inputs, labels)`` batches.
        number_epochs: number of passes over ``trainloader``.
        value: if not None, the proportion for ``prune_model_l1_unstructured``,
            applied on every batch.
        mask: used only when ``value`` is None. If not None, the proportion for
            ``prune_model_random_unstructured``, applied on every batch.

    Batches are moved to the module-level ``device``, which is presumably
    provided by ``from utils import *`` (TODO confirm). Inputs whose dim 1
    has size 1 are replicated three times along dim 1.
    """
    net.train()
    for _ in range(number_epochs):
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            if inputs.size(1) == 1:
                # Single-channel batch: repeat the channel three times.
                # Same result as squeeze + stack, without mutating the
                # loader's tensor in place.
                inputs = torch.cat([inputs, inputs, inputs], dim=1)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # NOTE(review): pruning runs after backward() but before
            # optimizer.step(). The step can therefore move just-zeroed
            # weights away from zero again; confirm this ordering is intended.
            if value is not None:
                net = prune_model_l1_unstructured(net, value)
            elif mask is not None:
                net = prune_model_random_unstructured(net, mask)

            optimizer.step()



# NOTE(review): the string literal below is never executed or read. It only
# records example parameter sets. M_ID=1 with "P" and M_ID=2 with "R"
# presumably select the L1 (``value``) and random (``mask``) pruning paths of
# train_pruning, with proportion 0.99. Confirm against the caller.
'''

    M_ID=1
    param={"P":.99}
    
    M_ID=2
    param={"R":.99}
'''