import os
import yaml
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim

from dataloader import SRDataset
from settings import Settings, load_settings
from model.dummy_sr_model import DummySRModel  # Use your model instead of the dummy one

# Load settings from config.yaml
config_path = os.path.join(os.path.dirname(__file__), '../config.yaml')
settings = load_settings(config_path)

# Create dataset
full_dataset = SRDataset(settings)

# Split into train and validation (80/20)
dataset_size = len(full_dataset)
val_size = int(0.2 * dataset_size)
train_size = dataset_size - val_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create DataLoaders using settings.batch_size
train_loader = DataLoader(train_dataset, batch_size=settings.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=settings.batch_size, shuffle=False, num_workers=0)

# Determine input/output channels
if settings.luminance_only:
    in_channels = 1 if not settings.clone_luminance_as_rgb else 3
else:
    in_channels = 3

model = DummySRModel(in_channels=in_channels, out_channels=in_channels, super_resolution=settings.scaling)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Scheduler configuration (falls back to defaults if the settings are absent)
use_plateau_scheduler = getattr(settings, 'use_plateau_scheduler', False)
step_size = getattr(settings, 'lr_step_size', 10)
gamma = getattr(settings, 'lr_gamma', 0.5)

if use_plateau_scheduler:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=gamma, patience=5)
else:
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

epochs = 2  # For demonstration
best_val_mse = float('inf')

for epoch in range(epochs):
    # Training
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        low_sr = batch['low_sr'].to(device)
        high_sr = batch['high_sr'].to(device)

        optimizer.zero_grad()
        output = model(low_sr)
        loss = criterion(output, high_sr)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * low_sr.size(0)
    avg_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            low_sr = batch['low_sr'].to(device)
            high_sr = batch['high_sr'].to(device)
            output = model(low_sr)
            loss = criterion(output, high_sr)
            val_loss += loss.item() * low_sr.size(0)
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}")

    # Step scheduler (ReduceLROnPlateau needs the validation metric)
    if use_plateau_scheduler:
        scheduler.step(avg_val_loss)
    else:
        scheduler.step()
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")

    # Save model if it achieves the best validation MSE so far
    if avg_val_loss < best_val_mse:
        best_val_mse = avg_val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Best model saved at epoch {epoch+1} with Val MSE: {best_val_mse:.4f}")

print("Training complete.")
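# -----------------------------------------------------------------------------
# For reference, a minimal sketch of what ../config.yaml might contain for this
# script. The exact schema is defined by settings.load_settings (not shown
# here), so every key and value below is an assumption inferred from the
# attributes this script reads; adjust the names to match your Settings class.
#
#   batch_size: 16
#   luminance_only: false          # train on the luminance channel only
#   clone_luminance_as_rgb: false  # if true, replicate luminance into 3 channels
#   scaling: 2                     # super-resolution upscaling factor
#   use_plateau_scheduler: false   # ReduceLROnPlateau instead of StepLR
#   lr_step_size: 10               # epochs between StepLR decays
#   lr_gamma: 0.5                  # learning-rate decay factor
# -----------------------------------------------------------------------------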