# fine_tuning.py — last committed by Carldst.
# fine tuning

import matplotlib.pyplot as plt
from tqdm import tqdm
from utils import *

def finetuning(net, epochs, trainloader, *, lr=0.01, momentum=0.9, weight_decay=5e-4):
    """Fine-tune ``net`` in place using SGD and cross-entropy loss.

    Args:
        net: model to fine-tune; its parameters are updated in place.
        epochs: number of full passes over ``trainloader``.
        trainloader: iterable yielding ``(inputs, labels)`` batches whose
            elements support ``.to(device)``.
        lr: SGD learning rate (default 0.01, the previously hard-coded value).
        momentum: SGD momentum (default 0.9).
        weight_decay: weight-decay coefficient passed to SGD (default 5e-4).

    Returns:
        The same ``net`` object. It is left in training mode when
        ``epochs > 0``.
    """
    # `nn`, `optim` and `device` are presumably provided by `from utils import *`.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    for _ in tqdm(range(epochs)):
        net.train()
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Clear gradients left over from the previous step.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # Apply the SGD update to the model parameters.
            optimizer.step()
    return net

# Example settings for invoking this fine-tuning step (a note only; not code).
# NOTE(review): presumably `M_ID` identifies this method in the caller and
# `param["E"]` is the `epochs` value passed to `finetuning` — confirm.
#     M_ID=4
#     param={"E":1}