UpsamplePruner.py 3.94 KB
Newer Older
valentini's avatar
valentini committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# All arguments used in this file are standard or passed in via function parameters.
import torch
import logging
from typing import Sequence
import torch.nn as nn
import torch_pruning as tp
from model.swin import Upsample, UpsampleOneStep

class UpsamplePruner(tp.BasePruningFunc):
    '''
    Used to prune the upsample block each intenal convolution requires that the output and are grouped by values of (self.scale^2)
    '''
    def prune_out_channels(self, layer: nn.Module, idxs: list):
        channel_group_size = 9 if layer.scale == 3 else 4
        if isinstance(layer, UpsampleOneStep):
            # Swin Transformer (single SR step)
            channel_group_size = (layer.scale)**2
            for key in range(len(layer)):
                module = layer[key]
                if isinstance(module, nn.Conv2d):                
                    module = tp.prune_conv_in_channels(module, idxs)
                    
                    # Group Output Channels y the size of the self.scale value
                    # This will keep aligned the output channels with the input channel enabling pixleshuffle to work
                    to_prune_out = []
                    for idx in idxs:
                        to_prune_out += [(idx*channel_group_size)+i for i in range(channel_group_size)]
                    
                    module = tp.prune_conv_out_channels(module, to_prune_out)
                    layer[key] = module
                    layer.in_channels = layer.in_channels-len(idxs)

            return layer   
        if isinstance(layer, Upsample):
            # Swin Transformer
            for key in range(len(layer)):
                module = layer[key]
                if isinstance(module, nn.Conv2d):
                    print(module.in_channels, module.out_channels)      
                    module = tp.prune_conv_in_channels(module, idxs)
                    
                    # Group Output Channels y the size of the self.scale value
                    # This will keep aligned the output channels with the input channel enabling pixleshuffle to work
                    to_prune_out = []
                    for idx in idxs:
                        to_prune_out += [(idx*channel_group_size)+i for i in range(channel_group_size)]
                    
                    module = tp.prune_conv_out_channels(module, to_prune_out)
                    layer[key] = module
                    layer.in_channels = layer.in_channels-len(idxs)
                    print(module.in_channels-len(idxs), module.out_channels-len(to_prune_out))
            return layer   
        else:
            # DRLN 
            modules = layer.body._modules
            for key in modules:
                module = modules[key]
                if isinstance(module, nn.Conv2d):
                    print(module.in_channels, module.out_channels)      
                    module = tp.prune_conv_in_channels(module, idxs)
                    
                    # Group Output Channels y the size of the self.scale value
                    # This will keep aligned the output channels with the input channel enabling pixleshuffle to work
                    to_prune_out = []
                    for idx in idxs:
                        to_prune_out += [(idx*channel_group_size)+i for i in range(channel_group_size)]
                    
                    module = tp.prune_conv_out_channels(module, to_prune_out)
                    
                    modules[key] = module
                    layer.in_channels = layer.in_channels-len(idxs)
                    print(module.in_channels-len(idxs), module.out_channels-len(to_prune_out))
            return layer   

    # Means we have an inter_dipendency, if we prune the out we need to remove the input and viceversa
    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        # After the each shuffle operation we get the initial number of channels
        return layer.in_channels 

    get_in_channels = get_out_channels