Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
EVC
Ufv1 0
Commits
268c7f57
Commit
268c7f57
authored
Nov 23, 2025
by
valentini
Browse files
Carica un nuovo file
parent
3cc017c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Reference Software/UFV1.0-Pruning/src/pruners/UpsamplePruner.py
0 → 100644
View file @
268c7f57
# All arguments used in this file are standard or passed in via function parameters.
import
torch
import
logging
from
typing
import
Sequence
import
torch.nn
as
nn
import
torch_pruning
as
tp
from
model.swin
import
Upsample
,
UpsampleOneStep
class
UpsamplePruner
(
tp
.
BasePruningFunc
):
'''
Used to prune the upsample block each intenal convolution requires that the output and are grouped by values of (self.scale^2)
'''
def
prune_out_channels
(
self
,
layer
:
nn
.
Module
,
idxs
:
list
):
channel_group_size
=
9
if
layer
.
scale
==
3
else
4
if
isinstance
(
layer
,
UpsampleOneStep
):
# Swin Transformer (single SR step)
channel_group_size
=
(
layer
.
scale
)
**
2
for
key
in
range
(
len
(
layer
)):
module
=
layer
[
key
]
if
isinstance
(
module
,
nn
.
Conv2d
):
module
=
tp
.
prune_conv_in_channels
(
module
,
idxs
)
# Group Output Channels y the size of the self.scale value
# This will keep aligned the output channels with the input channel enabling pixleshuffle to work
to_prune_out
=
[]
for
idx
in
idxs
:
to_prune_out
+=
[(
idx
*
channel_group_size
)
+
i
for
i
in
range
(
channel_group_size
)]
module
=
tp
.
prune_conv_out_channels
(
module
,
to_prune_out
)
layer
[
key
]
=
module
layer
.
in_channels
=
layer
.
in_channels
-
len
(
idxs
)
return
layer
if
isinstance
(
layer
,
Upsample
):
# Swin Transformer
for
key
in
range
(
len
(
layer
)):
module
=
layer
[
key
]
if
isinstance
(
module
,
nn
.
Conv2d
):
print
(
module
.
in_channels
,
module
.
out_channels
)
module
=
tp
.
prune_conv_in_channels
(
module
,
idxs
)
# Group Output Channels y the size of the self.scale value
# This will keep aligned the output channels with the input channel enabling pixleshuffle to work
to_prune_out
=
[]
for
idx
in
idxs
:
to_prune_out
+=
[(
idx
*
channel_group_size
)
+
i
for
i
in
range
(
channel_group_size
)]
module
=
tp
.
prune_conv_out_channels
(
module
,
to_prune_out
)
layer
[
key
]
=
module
layer
.
in_channels
=
layer
.
in_channels
-
len
(
idxs
)
print
(
module
.
in_channels
-
len
(
idxs
),
module
.
out_channels
-
len
(
to_prune_out
))
return
layer
else
:
# DRLN
modules
=
layer
.
body
.
_modules
for
key
in
modules
:
module
=
modules
[
key
]
if
isinstance
(
module
,
nn
.
Conv2d
):
print
(
module
.
in_channels
,
module
.
out_channels
)
module
=
tp
.
prune_conv_in_channels
(
module
,
idxs
)
# Group Output Channels y the size of the self.scale value
# This will keep aligned the output channels with the input channel enabling pixleshuffle to work
to_prune_out
=
[]
for
idx
in
idxs
:
to_prune_out
+=
[(
idx
*
channel_group_size
)
+
i
for
i
in
range
(
channel_group_size
)]
module
=
tp
.
prune_conv_out_channels
(
module
,
to_prune_out
)
modules
[
key
]
=
module
layer
.
in_channels
=
layer
.
in_channels
-
len
(
idxs
)
print
(
module
.
in_channels
-
len
(
idxs
),
module
.
out_channels
-
len
(
to_prune_out
))
return
layer
# Means we have an inter_dipendency, if we prune the out we need to remove the input and viceversa
prune_in_channels
=
prune_out_channels
def
get_out_channels
(
self
,
layer
):
# After the each shuffle operation we get the initial number of channels
return
layer
.
in_channels
get_in_channels
=
get_out_channels
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment