import torch.nn as nn
import torch_pruning as tp
from functools import partial

# Special Pruners
from pruners.UpsamplePruner import UpsamplePruner

def pruner_constructor(args, model, train_data, device, pruning_rateo=None):
    # Flag for pruners that require sparsity learning (regularization during
    # training) before pruning; currently informational only.
    sparsity_learning = False

    # Random importance
    if args.pruning_method == "random":
        imp = tp.importance.RandomImportance()
        pruner_entry = partial(tp.pruner.MagnitudePruner)
    # Growing Regularization (GReg)
    elif args.pruning_method == "growing_reg":
        sparsity_learning = True
        imp = tp.importance.GroupNormImportance(p=2, group_reduction="mean")
        pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg)

    else:
        raise ValueError(f"Invalid pruning method: {args.pruning_method}")

    num_heads = {}       # attention modules -> head counts, for attention-aware pruning
    ignored_layers = []

    # Loop over the model to extract key modules.
    # TODO: Add layers to be ignored during pruning (usually the first and last
    # layers), e.g.:
    # for m in model.modules():
    #     ignored_layers.append(m)
    #
    # Upsample blocks should also be ignored here; we will add them later to a
    # separate pruning function. See the illustrative sketch below.
    
    # Map of special custom pruners, keyed by module type.
    # TODO: Upsample and UpsampleOneStep are not imported yet.
    custom_pruners = {
        # e.g., UpsampleBlockTypeOfYourModel: UpsamplePruner(),
    }
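
    # Illustrative sketch (assumption): once Upsample / UpsampleOneStep are
    # importable, they would be registered so torch_pruning delegates their
    # pruning to UpsamplePruner:
    #
    #     from models.network import Upsample, UpsampleOneStep  # hypothetical import path
    #     custom_pruners[Upsample] = UpsamplePruner()
    #     custom_pruners[UpsampleOneStep] = UpsamplePruner()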

    pruning_ratio_dict = {}  # optional per-layer ratios; empty means the global ratio applies everywhere
    pruner = pruner_entry(
        model,
        importance=imp,
        example_inputs=train_data.to(device),
        iterative_steps=args.pruning_steps,
        pruning_ratio=pruning_rateo if pruning_rateo is not None else args.pruning_target_ratio,
        pruning_ratio_dict=pruning_ratio_dict,
        max_pruning_ratio=args.max_pruning_ratio,
        global_pruning=False,
        num_heads=num_heads,
        ignored_layers=ignored_layers,
        customized_pruners=custom_pruners,
        root_module_types=[nn.modules.conv._ConvNd, nn.Linear]
    )
    return pruner
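
# Example usage (illustrative sketch; assumes `args` exposes the fields read
# above: pruning_method, pruning_steps, pruning_target_ratio, max_pruning_ratio,
# reg, delta_reg):
#
#     pruner = pruner_constructor(args, model, train_data, device)
#     for _ in range(args.pruning_steps):
#         pruner.step()   # prune one iterative step, then fine-tune
#
# For "growing_reg", torch_pruning expects a sparsity-learning phase first:
# call pruner.update_reg() periodically and pruner.regularize(model) right
# after loss.backward() in the training loop.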