ckpt.py 1.08 KB
Newer Older
valentini's avatar
valentini committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import re
import os
import torch
import glob2

def get_epoch_from_name(ckpt_url):
            s = re.findall("ckpt_e(\d+).pth", ckpt_url)
            epoch = int(s[0]) if s else -1
            return epoch, ckpt_url

def load_checkpoint(ckpt_path):
    if os.path.isdir(ckpt_path):
        ckpts = glob2.glob(os.path.join(ckpt_path, '/**/*.pth'))
        assert ckpts, "No checkpoints to resume from!"

        # load checkpoint with highest epoch
        start_epoch, ckpt = max(get_epoch_from_name(c) for c in ckpts)
        ckpt_path = ckpt

    ckpt = torch.load(ckpt_path)

    if 'params' in ckpt.keys():
        return (ckpt['params'], 0, 10000000000.0)

    if not hasattr(ckpt, 'model'):
        return (ckpt, 0, 10000000000.0)
        
    state_dict = ckpt['model']
    if hasattr(state_dict, 'model'): state_dict = state_dict['model']
    if hasattr(state_dict, 'model'): state_dict = state_dict['model']
    start_epoch = ckpt['epoch'] if hasattr(ckpt, 'epoch') else 0
    best_mse = ckpt['mse_val'] if hasattr(ckpt, 'mse_val') else 10000000000.0

    return (state_dict, start_epoch, best_mse)