Commit 268c7f57 authored by valentini

Upload a new file

parent 3cc017c3
# All arguments used in this file are standard or passed in via function parameters.
import torch
import logging
from typing import Sequence
import torch.nn as nn
import torch_pruning as tp
from model.swin import Upsample, UpsampleOneStep
class UpsamplePruner(tp.BasePruningFunc):
    '''
    Used to prune an upsample block: each internal convolution requires that its
    output channels be grouped by (self.scale**2), so the outputs must be pruned
    in groups of that size to stay aligned with the pruned input channels.
    '''
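
    # Example of the index expansion done in prune_out_channels (hypothetical
    # numbers, assuming scale = 2 and therefore channel_group_size = 4): pruning
    # feature channel idx = 2 removes conv output channels [8, 9, 10, 11], i.e.
    # the four sub-pixel copies that PixelShuffle would rearrange into the
    # spatial positions of channel 2.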
    def prune_out_channels(self, layer: nn.Module, idxs: list):
        channel_group_size = 9 if layer.scale == 3 else 4
        if isinstance(layer, UpsampleOneStep):
            # Swin Transformer (single SR step)
            channel_group_size = (layer.scale)**2
            for key in range(len(layer)):
                module = layer[key]
                if isinstance(module, nn.Conv2d):
                    module = tp.prune_conv_in_channels(module, idxs)
                    # Group the output channels by the size of the self.scale value.
                    # This keeps the output channels aligned with the input channels,
                    # enabling pixelshuffle to work.
                    to_prune_out = []
                    for idx in idxs:
                        to_prune_out += [(idx*channel_group_size)+i for i in range(channel_group_size)]
                    module = tp.prune_conv_out_channels(module, to_prune_out)
                layer[key] = module
            layer.in_channels = layer.in_channels - len(idxs)
            return layer
        if isinstance(layer, Upsample):
            # Swin Transformer
            for key in range(len(layer)):
                module = layer[key]
                if isinstance(module, nn.Conv2d):
                    print(module.in_channels, module.out_channels)
                    module = tp.prune_conv_in_channels(module, idxs)
                    # Group the output channels by the size of the self.scale value.
                    # This keeps the output channels aligned with the input channels,
                    # enabling pixelshuffle to work.
                    to_prune_out = []
                    for idx in idxs:
                        to_prune_out += [(idx*channel_group_size)+i for i in range(channel_group_size)]
                    module = tp.prune_conv_out_channels(module, to_prune_out)
                layer[key] = module
            layer.in_channels = layer.in_channels - len(idxs)
            print(module.in_channels - len(idxs), module.out_channels - len(to_prune_out))
            return layer
        else:
            # DRLN
            modules = layer.body._modules
            for key in modules:
                module = modules[key]
                if isinstance(module, nn.Conv2d):
                    print(module.in_channels, module.out_channels)
                    module = tp.prune_conv_in_channels(module, idxs)
                    # Group the output channels by the size of the self.scale value.
                    # This keeps the output channels aligned with the input channels,
                    # enabling pixelshuffle to work.
                    to_prune_out = []
                    for idx in idxs:
                        to_prune_out += [(idx*channel_group_size)+i for i in range(channel_group_size)]
                    module = tp.prune_conv_out_channels(module, to_prune_out)
                modules[key] = module
            layer.in_channels = layer.in_channels - len(idxs)
            print(module.in_channels - len(idxs), module.out_channels - len(to_prune_out))
            return layer
    # The two directions are inter-dependent: if we prune the output channels we
    # also need to remove the corresponding input channels, and vice versa.
    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        # After each shuffle operation we are back to the initial number of channels.
        return layer.in_channels

    get_in_channels = get_out_channels
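
# A minimal registration sketch, assuming torch_pruning's
# DependencyGraph.register_customized_layer API (as in the library's
# customization examples); build_model() is a hypothetical placeholder for the
# actual SwinIR/DRLN network.
if __name__ == "__main__":
    def build_model() -> nn.Module:
        # Placeholder: construct and return the super-resolution model here.
        raise NotImplementedError("plug in the actual SwinIR/DRLN model")

    model = build_model()
    example_inputs = torch.randn(1, 3, 64, 64)

    DG = tp.DependencyGraph()
    # Route Upsample / UpsampleOneStep modules through the custom pruner above
    # instead of torch_pruning's built-in per-layer pruning functions.
    DG.register_customized_layer(Upsample, UpsamplePruner())
    DG.register_customized_layer(UpsampleOneStep, UpsamplePruner())
    DG.build_dependency(model, example_inputs=example_inputs)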