# train.py
import os
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
from dataloader import SRDataset
from settings import load_settings
from model.dummy_sr_model import DummySRModel  # Use your model instead of the dummy one

# Load settings from config.yaml
config_path = os.path.join(os.path.dirname(__file__), '../config.yaml')
settings = load_settings(config_path)
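
# The code below assumes the loaded settings expose at least: batch_size,
# luminance_only, clone_luminance_as_rgb, and scaling (plus the optional
# scheduler fields read further down). A minimal config.yaml sketch under
# that assumption; the values shown are placeholders, not project defaults:
#
#   batch_size: 16
#   luminance_only: false
#   clone_luminance_as_rgb: false
#   scaling: 2
#   use_plateau_scheduler: false
#   lr_step_size: 10
#   lr_gamma: 0.5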

# Create dataset
full_dataset = SRDataset(settings)

# Split into train and validation (80/20)
dataset_size = len(full_dataset)
val_size = int(0.2 * dataset_size)
train_size = dataset_size - val_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
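
# For a split that is reproducible across runs, a seeded generator can be
# passed (optional; the seed value is an arbitrary choice):
# train_dataset, val_dataset = random_split(
#     full_dataset, [train_size, val_size],
#     generator=torch.Generator().manual_seed(42))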

# Create DataLoaders using settings.batch_size
train_loader = DataLoader(train_dataset, batch_size=settings.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=settings.batch_size, shuffle=False, num_workers=0)
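
# On a CUDA machine, parallel workers and pinned memory usually speed up
# loading (optional; a reasonable worker count depends on the host):
# train_loader = DataLoader(train_dataset, batch_size=settings.batch_size,
#                           shuffle=True, num_workers=4, pin_memory=True)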

# Determine input/output channels
if settings.luminance_only:
    in_channels = 3 if settings.clone_luminance_as_rgb else 1  # cloned luminance uses three channels
else:
    in_channels = 3

model = DummySRModel(in_channels=in_channels, out_channels=in_channels, super_resolution=settings.scaling)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
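
# The learning rate is hard-coded here; it could instead follow the same
# getattr-on-settings pattern used for the scheduler options below, e.g.:
# optimizer = optim.Adam(model.parameters(),
#                        lr=getattr(settings, 'learning_rate', 1e-3))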

# Scheduler configuration
use_plateau_scheduler = getattr(settings, 'use_plateau_scheduler', False)
step_size = getattr(settings, 'lr_step_size', 10)
gamma = getattr(settings, 'lr_gamma', 0.5)
if use_plateau_scheduler:
    # Scale the LR by `gamma` once the validation loss stops improving for 5 epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=gamma, patience=5)
else:
    # Scale the LR by `gamma` every `step_size` epochs. The current LR is printed
    # explicitly each epoch below rather than via the schedulers' deprecated `verbose` flag.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

epochs = getattr(settings, 'epochs', 2)  # short default run for demonstration
best_val_mse = float('inf')

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        low_sr = batch['low_sr'].to(device)
        high_sr = batch['high_sr'].to(device)

        optimizer.zero_grad()
        output = model(low_sr)
        loss = criterion(output, high_sr)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * low_sr.size(0)
    avg_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            low_sr = batch['low_sr'].to(device)
            high_sr = batch['high_sr'].to(device)
            output = model(low_sr)
            loss = criterion(output, high_sr)
            val_loss += loss.item() * low_sr.size(0)
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}")

    # Step scheduler
    if use_plateau_scheduler:
        scheduler.step(avg_val_loss)
    else:
        scheduler.step()
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")

    # Save model if best validation MSE
    if avg_val_loss < best_val_mse:
        best_val_mse = avg_val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Best model saved at epoch {epoch+1} with Val MSE: {best_val_mse:.4f}")

print("Training complete.")