Commit 899316af authored by valentini's avatar valentini

Upload a new file

parent f94338e8
from pprint import pprint

# See https://github.com/MingSun-Tse/Regularization-Pruning/blob/master/utils.py#L286
class PresetLRScheduler(object):
    """Use a manually designed learning rate schedule."""
    def __init__(self, decay_schedule):
        # decay_schedule is a dictionary mapping iteration -> lr
        self.decay_schedule = decay_schedule
        print('=> Using a preset learning rate schedule:')
        print(decay_schedule)
        self.for_once = True

    def __call__(self, optimizer, iteration):
        # Set the lr of every param group; iterations absent from the
        # schedule keep the group's current lr.
        for param_group in optimizer.param_groups:
            lr = self.decay_schedule.get(iteration, param_group['lr'])
            param_group['lr'] = lr

    @staticmethod
    def get_lr(optimizer):
        # Return the lr of the first param group.
        for param_group in optimizer.param_groups:
            return param_group['lr']
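A minimal usage sketch, assuming PyTorch; the model, schedule boundaries, and lr values below are placeholders, not from the original file:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Map epoch (or iteration) -> learning rate; boundaries are illustrative.
lr_schedule = {0: 0.1, 30: 0.01, 60: 0.001}
scheduler = PresetLRScheduler(lr_schedule)

for epoch in range(90):
    scheduler(optimizer, epoch)  # epochs not in the dict keep the current lr
    # ... train for one epoch ...

print(PresetLRScheduler.get_lr(optimizer))  # lr of the first param group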
# Renamed from PresetLRScheduler so it does not shadow the class above.
class CyclicalPresetLRScheduler(object):
    """
    Custom implementation of the learning rate schedule used for cyclical pruning.
    See: https://openaccess.thecvf.com/content/CVPR2022W/ECV/papers/Srinivas_Cyclical_Pruning_for_Sparse_Neural_Networks_CVPRW_2022_paper.pdf
    """
    def __init__(self, decay_schedule):
        # decay_schedule is a dictionary mapping iteration -> lr
        self.decay_schedule = decay_schedule
        print('=> Using a preset learning rate schedule:')
        pprint(decay_schedule)
        self.for_once = True

    def __call__(self, optimizer, iteration):
        # Set the lr of every param group; iterations absent from the
        # schedule keep the group's current lr.
        for param_group in optimizer.param_groups:
            lr = self.decay_schedule.get(iteration, param_group['lr'])
            param_group['lr'] = lr

    @staticmethod
    def get_lr(optimizer):
        # Return the lr of the first param group.
        for param_group in optimizer.param_groups:
            return param_group['lr']
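Since the referenced paper retrains in cycles between pruning steps, a cyclical schedule can be encoded as a plain dict that restarts the lr at each cycle. A hypothetical sketch (cycle boundaries and lr values are invented for illustration, reusing the optimizer from the example above):

cyclical_schedule = {
    0: 0.1,  10: 0.01,   # cycle 1: high lr, then decay
    20: 0.1, 30: 0.01,   # cycle 2: lr restarts after pruning
    40: 0.1, 50: 0.01,   # cycle 3
}
scheduler = CyclicalPresetLRScheduler(cyclical_schedule)

for epoch in range(60):
    scheduler(optimizer, epoch)
    # ... prune, then retrain for one epoch ...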