from functools import partial

import torch.nn as nn
import torch_pruning as tp

# Special Pruners
from pruners.UpsamplePruner import UpsamplePruner


def pruner_constructor(args, model, train_data, device, pruning_rateo=None):
    """Build a ``torch_pruning`` pruner for ``model`` from the run configuration.

    The pruning strategy is selected by ``args.pruning_method``:

    * ``"random"``      -> ``tp.importance.RandomImportance`` with
      ``tp.pruner.MagnitudePruner``.
    * ``"growing_reg"`` -> ``tp.importance.GroupNormImportance(p=2,
      group_reduction=False)`` with ``tp.pruner.GrowingRegPruner`` configured
      with ``args.reg`` and ``args.delta_reg``.

    Args:
        args: Configuration object. Attributes read here: ``pruning_method``,
            ``pruning_steps``, ``pruning_target_ratio`` (only when
            ``pruning_rateo`` is ``None``), ``max_pruning_ratio`` and, for
            ``"growing_reg"`` only, ``reg`` and ``delta_reg``.
        model: The model handed to the pruner.
        train_data: Object exposing ``.to(device)``; the moved result is passed
            to the pruner as ``example_inputs``.
        device: Target passed to ``train_data.to``.
        pruning_rateo: Optional override for the pruning ratio. When ``None``
            (the default) ``args.pruning_target_ratio`` is used. An explicit
            ``0`` / ``0.0`` is honoured rather than treated as "not given".

    Returns:
        The pruner instance created by the selected pruner class.

    Raises:
        ValueError: If ``args.pruning_method`` is not one of the supported
            methods.

    Note:
        A previous revision set a local ``sparsity_learning`` flag (``True``
        for ``"growing_reg"``) that was never read or returned. It has been
        dropped; callers that need it must derive it from
        ``args.pruning_method`` themselves.
    """
    # Select the importance criterion and the pruner class for the method.
    if args.pruning_method == "random":
        imp = tp.importance.RandomImportance()
        pruner_entry = tp.pruner.MagnitudePruner
    elif args.pruning_method == "growing_reg":
        imp = tp.importance.GroupNormImportance(p=2, group_reduction=False)
        pruner_entry = partial(
            tp.pruner.GrowingRegPruner,
            reg=args.reg,
            delta_reg=args.delta_reg,
        )
    else:
        # ValueError is still caught by callers using ``except Exception``.
        raise ValueError(f"Invalid Pruning Method: {args.pruning_method}")

    num_heads = {}
    ignored_layers = []

    # TODO: Walk the model to collect layers that must not be pruned (usually
    # the first and last layers), e.g.:
    #
    #     for m in model.modules():
    #         # Ignore-by-parameters
    #         ignored_layers.append(m)
    #         # Upsample blocks: ignore here; they will be handled later by a
    #         # separate, dedicated pruning function.
    #         ignored_layers.append(m)

    # Map of special custom pruners, keyed by module type.
    # TODO: Upsample and UpsampleOneStep are not imported.
    custom_pruners = {
        # E.g., UpsampleBlockTypeOfYourModel: UpsamplePruner(),
    }

    pruning_ratio_dict = {}

    # Only ``None`` means "no override"; an explicit 0 must not fall back to
    # the configured target ratio.
    if pruning_rateo is not None:
        pruning_ratio = pruning_rateo
    else:
        pruning_ratio = args.pruning_target_ratio

    pruner = pruner_entry(
        model,
        importance=imp,
        example_inputs=train_data.to(device),
        iterative_steps=args.pruning_steps,
        pruning_ratio=pruning_ratio,
        pruning_ratio_dict=pruning_ratio_dict,
        max_pruning_ratio=args.max_pruning_ratio,
        global_pruning=False,
        num_heads=num_heads,
        ignored_layers=ignored_layers,
        customized_pruners=custom_pruners,
        root_module_types=[nn.modules.conv._ConvNd, nn.Linear],
    )
    return pruner