evc-ufv / V1-0 / Commits

Commit 8f77be3d, authored Dec 19, 2025 by valentini
Upload a new file (parent f36c8cff)

Reference Software/UFV1.0-Pruning/src/main.py (new file, mode 100644)
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
import numpy as np
import random
import wandb  # missing in the original file, but wandb.log is called below;
              # no wandb.init() appears here either, so it is presumably done elsewhere
from tqdm import tqdm
from ignite.metrics import PSNR
from ignite.metrics import SSIM

# Model Variants
from model.dummy_sr_model import DummySRModel
from utils.ckpt import load_checkpoint
from utils.util import get_luminance as get_luminance_from_rgb
from dataset import RGBDataset, YUVDataset
from utils.trainer import trainFor, testModel, sparsityLearning
from utils.loss import CharbonnierLoss
from torch.nn import MSELoss
from arguments import get_arguments
from pruners import pruner_constructor

def generateRunName(args):
    pruning_steps = args.pruning_steps
    pruning_target_ratio = args.pruning_target_ratio
    # Construct name
    name = f"Pruning(x{args.scale}({pruning_steps}))-RetrainedOn(e{args.epochs}-{args.crop}x{args.crop})"
    return name
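
# Example of the generated name (assuming scale=2, pruning_steps=10,
# epochs=100, crop=64): "Pruning(x2(10))-RetrainedOn(e100-64x64)"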

def loadModel(args):
    # Create the model
    model = DummySRModel(args)
    # Load pretrained weights
    if args.weigths:  # (sic) argument name as defined in arguments.py
        model_dict, epoch, mse = load_checkpoint(args.weigths)
        if model_dict is None:
            raise Exception("The ckpt does not have the model state_dict!")
        model.load_state_dict(model_dict['model'])
        # Save the original (unpruned) model; relies on the global
        # `checkpoints_path` created in __main__ before this function is called
        ckpt_path = os.path.join(checkpoints_path, 'unpruned_model.pth')
        if not os.path.exists(ckpt_path):
            torch.save({'model': model_dict['model']}, ckpt_path)
    return model

if __name__ == '__main__':
    args = get_arguments()
    print(args)

    # Set up random seeds for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    ### Prepare output dirs
    params = vars(args)
    params['dataset'] = os.path.basename(os.path.normpath(args.loader))
    run_name = generateRunName(args)
    run_dir = os.path.join(args.runs, run_name)
    print("Output folder root:", run_dir)
    log_metrics_path = os.path.join(run_dir, 'pruning_results.csv')
    original_test_path = os.path.join(run_dir, 'original')
    checkpoints_path = os.path.join(run_dir, 'checkpoints')
    param_file = os.path.join(run_dir, 'params.csv')
    os.makedirs(original_test_path, exist_ok=True)
    os.makedirs(checkpoints_path, exist_ok=True)
    pd.DataFrame(params, index=[0]).to_csv(param_file, index=False)
    # Load model
    model = loadModel(args)

    # Get dataset
    if args.loader.lower() == 'div2k_rgb':
        train_loader, val_loader, test_loader = RGBDataset(args)
        loss_function = CharbonnierLoss(args.loss_epsylon)
        get_luminance = get_luminance_from_rgb
    elif args.loader.lower() == 'custom_yuv':
        train_loader, val_loader, test_loader = YUVDataset(args)
        loss_function = MSELoss()
        # YUV outputs are scored as-is: use an identity transform so the
        # ignite metrics below never receive None as output_transform
        get_luminance = lambda output: output
    else:
        raise Exception("Unsupported dataset")
    # Define device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
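
    # `pruner_constructor` comes from the project's `pruners` module, which is
    # not part of this commit. For reference only, a minimal constructor
    # compatible with the call below might look like this sketch built on
    # torch_pruning's magnitude pruner (the names, arguments, and defaults
    # here are assumptions, not the actual implementation):
    def _example_pruner_constructor(args, model, example_inputs, device):
        importance = tp.importance.MagnitudeImportance(p=2)  # L2 channel importance
        return tp.pruner.MagnitudePruner(
            model,
            example_inputs.to(device),
            importance=importance,
            iterative_steps=args.pruning_steps,       # prune gradually over the steps
            pruning_ratio=args.pruning_target_ratio,  # target channel ratio
        )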
    # Define a test input to evaluate the model performance
    example_input_sd = torch.randn(1, 3, 964, 540).to(device)
    # Needed for some pruners
    pruner = pruner_constructor(args, model, torch.randn(1, 3, 10, 10), device)
    log_metrics = pd.DataFrame()
    base_macs_sd, base_nparams = tp.utils.count_ops_and_params(model, example_input_sd)
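    # `tp.utils.count_ops_and_params` returns (MACs, #parameters) for a single
    # forward pass; the 1x3x964x540 tensor above stands in for an SD frame.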
    # Evaluate the original model on the test set
    psnr_pretrain = PSNR(data_range=args.data_range, output_transform=get_luminance, device=device)
    ssim_pretrain = SSIM(data_range=args.data_range, output_transform=get_luminance, device=device)
    test_loss, test_table = testModel(loader=test_loader, model=model, args=args,
                                      psnr=psnr_pretrain, ssim=ssim_pretrain,
                                      data_range=args.data_range, device=device,
                                      loss_function=loss_function)
    prune_iter_metrics = {}
    prune_iter_metrics["pruning_step"] = 0
    prune_iter_metrics["pruning_ratio"] = 0
    prune_iter_metrics["parameters_(M)"] = base_nparams / 1e6
    prune_iter_metrics["inference_SD_HD_flops(G)"] = base_macs_sd / 1e9
    prune_iter_metrics["mse"] = test_loss
    prune_iter_metrics['ssim'] = float(ssim_pretrain.compute())
    prune_iter_metrics['psnr'] = float(psnr_pretrain.compute())
    test_table.to_csv(os.path.join(original_test_path, 'original_test.csv'), index=False)
    test_table.groupby('sequence').mean().reset_index().to_csv(
        os.path.join(original_test_path, 'original_test_by_sequence.csv'), index=False)
    log_metrics = log_metrics.append(prune_iter_metrics, ignore_index=True)
    log_metrics.to_csv(log_metrics_path, index=False)

    # Save dependency graph visualizations (pruner.DG is torch_pruning's
    # dependency graph traced over the model)
    tp.utils.draw_dependency_graph(pruner.DG, save_as=os.path.join(original_test_path, 'draw_dep_graph.png'), title=None)
    tp.utils.draw_groups(pruner.DG, save_as=os.path.join(original_test_path, 'draw_groups.png'), title=None)
    tp.utils.draw_computational_graph(pruner.DG, save_as=os.path.join(original_test_path, 'draw_comp_graph.png'), title=None)
    # Save original model structure
    macs_sd, nparams = tp.utils.count_ops_and_params(model, example_input_sd)
    with open(os.path.join(original_test_path, 'model_details.txt'), 'w') as model_details_file:
        model_details_file.write(f"{model}\n")
        wandb.log({"model_description": f"{model}"})
        model_details_file.write(" Iter %d/%d, Params: %.2f M => %.2f M\n"
                                 % (0, args.pruning_steps, base_nparams / 1e6, nparams / 1e6))
        model_details_file.write(" Iter %d/%d, MACs SD_Input: %.2f G => %.2f G\n"
                                 % (0, args.pruning_steps, base_macs_sd / 1e9, macs_sd / 1e9))

    ####################################
    # Pruning Cycles ###################
    ####################################
    training_iter = 0
    for i in tqdm(range(1, args.pruning_steps + 1)):
        step_path = os.path.join(run_dir, 'pruning_iter_{}'.format(i))
        os.makedirs(step_path, exist_ok=True)
        # Sparsity learning (some pruning techniques require a training step
        # to learn the sparsity)
        if args.pruning_method == "growing_reg":
            sparsityLearning(model=model, pruner=pruner, loader=train_loader,
                             args=args, loss_function=loss_function)
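        # Rough sketch of what `sparsityLearning` (utils/trainer.py, not shown
        # in this commit) is assumed to do for growing regularization, based on
        # torch_pruning's GrowingRegPruner interface:
        #   for lr_img, hr_img in loader:
        #       loss_function(model(lr_img), hr_img).backward()
        #       pruner.update_reg()       # grow the regularization strength
        #       pruner.regularize(model)  # add the sparsity gradients
        #       optimizer.step(); optimizer.zero_grad()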

        # Pruning step
        pruner.step()
        macs_sd, nparams = tp.utils.count_ops_and_params(model, example_input_sd)
        with open(os.path.join(step_path, 'model_details.txt'), 'w') as model_details_file:
            model_details_file.write(f"{model}\n")
            wandb.log({"model_description": f"{model}"})
            model_details_file.write(" Iter %d/%d, Params: %.2f M => %.2f M\n"
                                     % (i, args.pruning_steps, base_nparams / 1e6, nparams / 1e6))
            model_details_file.write(" Iter %d/%d, MACs SD_Input: %.2f G => %.2f G\n"
                                     % (i, args.pruning_steps, base_macs_sd / 1e9, macs_sd / 1e9))

        # Model fine-tuning to recover the lost performance
        best_model_current_pruning = model  # if nothing better is found, the initial model is the best
        best_mse_current_pruning = None
        best_optimizer_current_pruning = None
        if args.epochs and args.epochs > 0:
            print("Retraining for recovery!")
            (best_model, best_mse, last_epoch_model, last_epoch_mse,
             optimizer, run_logs) = trainFor(model=model,
                                             train_dataloader=train_loader,
                                             val_dataloader=val_loader,
                                             device=device, args=args,
                                             epochs=args.epochs,
                                             run_folder=step_path,
                                             pruner=pruner,
                                             loss_function=loss_function)
            # Load the weights of the best model!
            model.load_state_dict(best_model.state_dict())
            wandb.log({"train_logs": wandb.Table(dataframe=run_logs)})
            run_logs.to_csv(os.path.join(step_path, 'training_log.csv'), index=False)
            best_model_current_pruning = best_model
            best_mse_current_pruning = best_mse

        # Test the best retrained model and log results
        prune_iter_metrics = {}
        prune_iter_metrics["pruning_step"] = i
        prune_iter_metrics["pruning_ratio"] = 1 - (nparams / base_nparams)
        prune_iter_metrics["parameters_(M)"] = nparams / 1e6
        prune_iter_metrics["inference_SD_HD_flops(G)"] = macs_sd / 1e9
        # (unlike the pretrain metrics above, no luminance output_transform is applied here)
        psnr_test = PSNR(data_range=args.data_range, device=device)
        ssim_test = SSIM(data_range=args.data_range, device=device)
        test_loss, test_table = testModel(loader=test_loader, model=model, args=args,
                                          psnr=psnr_test, ssim=ssim_test,
                                          data_range=args.data_range, device=device,
                                          loss_function=loss_function)
        test_table.to_csv(os.path.join(step_path, 'test.csv'), index=False)
        test_table.groupby('sequence').mean().reset_index().to_csv(
            os.path.join(step_path, 'test_by_sequence.csv'), index=False)
        prune_iter_metrics["mse"] = test_loss
        prune_iter_metrics['ssim'] = float(ssim_test.compute())
        prune_iter_metrics['psnr'] = float(psnr_test.compute())
        log_metrics = log_metrics.append(prune_iter_metrics, ignore_index=True)
        log_metrics.to_csv(log_metrics_path, index=False)

        # Save the fine-tuned (best) model
        ckpt_path = os.path.join(checkpoints_path, 'pruned_iteration_{}.pth'.format(i))
        if not os.path.exists(ckpt_path):
            torch.save({'metrics': test_loss,
                        'model': tp.state_dict(best_model_current_pruning)},
                       ckpt_path)
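        # To reload a pruned checkpoint later, torch_pruning's state_dict
        # helpers restore the pruned topology before loading the weights
        # (sketch, assuming the same DummySRModel constructor arguments):
        #   loaded = torch.load(ckpt_path)
        #   restored = DummySRModel(args)
        #   tp.load_state_dict(restored, state_dict=loaded['model'])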