Commit 3a603161 authored by valentini

Upload a new file

parent 305950ef
import torch.nn as nn
import torch_pruning as tp
from functools import partial
# Special Pruners
from pruners.UpsamplePruner import UpsamplePruner
def pruner_constructor(args, model, train_data, device, pruning_ratio=None):
    """Build a torch-pruning pruner according to args.pruning_method."""
    # Flag for methods that require sparsity (regularization) training before pruning
    sparsity_learning = False
    # Random
    if args.pruning_method == "random":
        imp = tp.importance.RandomImportance()
        pruner_entry = partial(tp.pruner.MagnitudePruner)
    # GReg (growing regularization)
    elif args.pruning_method == "growing_reg":
        sparsity_learning = True
        imp = tp.importance.GroupNormImportance(p=2, group_reduction=False)
        pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg)
    else:
        raise ValueError(f"Invalid pruning method: {args.pruning_method}")
    num_heads = {}
    ignored_layers = []
    # Loop over model.modules() to collect key modules.
    # TODO: add layers to be ignored for pruning (usually the first and last
    # layers), either by module type or by parameter name, and the Upsample
    # blocks (we will add them later to a separate pruning function):
    # for m in model.modules():
    #     ignored_layers.append(m)
    # Map of special custom pruners
    # TODO: Upsample and UpsampleOneStep are not imported
    custom_pruners = {
        # e.g., UpsampleBlockTypeOfYourModel: UpsamplePruner(),
    }
    pruning_ratio_dict = {}
    pruner = pruner_entry(
        model,
        importance=imp,
        example_inputs=train_data.to(device),
        iterative_steps=args.pruning_steps,
        pruning_ratio=pruning_ratio if pruning_ratio is not None else args.pruning_target_ratio,
        pruning_ratio_dict=pruning_ratio_dict,
        max_pruning_ratio=args.max_pruning_ratio,
        global_pruning=False,
        num_heads=num_heads,
        ignored_layers=ignored_layers,
        customized_pruners=custom_pruners,
        root_module_types=[nn.modules.conv._ConvNd, nn.Linear],
    )
    return pruner
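

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the toy model, input shape,
    # and args values below are assumptions, not part of this repository;
    # the args fields simply mirror the attributes read by pruner_constructor.
    from argparse import Namespace
    import torch

    args = Namespace(
        pruning_method="random",
        pruning_steps=5,
        pruning_target_ratio=0.5,
        max_pruning_ratio=0.9,
        reg=1e-4,        # only used by the growing_reg method
        delta_reg=1e-5,  # only used by the growing_reg method
    )
    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.Conv2d(16, 3, 3, padding=1),
    )
    example_batch = torch.randn(1, 3, 32, 32)

    pruner = pruner_constructor(args, model, example_batch, device="cpu")
    for _ in range(args.pruning_steps):
        pruner.step()  # removes one iterative slice of channels per call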