import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import torchvision as tv
import torchvision.transforms as transforms
# Compute device shared by this module: CUDA when PyTorch reports an available
# GPU, otherwise the CPU. Evaluated once, when the module is imported.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def dataloader(trainset, testset, batch_size=100, num_workers=2):
    """Wrap a train/test dataset pair in ``torch.utils.data.DataLoader`` objects.

    The training loader shuffles (DataLoader reshuffles on every pass over
    it). The test loader keeps the dataset's order so evaluation is
    deterministic.

    Args:
        trainset: Dataset used for training (shuffled).
        testset: Dataset used for evaluation (not shuffled).
        batch_size: Number of samples per batch, used for both loaders.
        num_workers: Number of worker subprocesses per loader. ``0`` loads
            data in the main process, which is useful for debugging or where
            spawning workers is expensive. Defaults to 2, the previously
            hard-coded value.

    Returns:
        A ``(trainloader, testloader)`` tuple.
    """
    def _make_loader(dataset, shuffle):
        # Single construction path so both loaders always share batch size and
        # worker settings; only the shuffle flag differs between them.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    trainloader = _make_loader(trainset, shuffle=True)
    testloader = _make_loader(testset, shuffle=False)
    return trainloader, testloader

def CIFAR10_dataset(*, root='./data/', download=True):
    """Build the CIFAR-10 train and test datasets with their transforms.

    Training images are augmented with a random 32x32 crop (4-pixel padding)
    and a random horizontal flip before tensor conversion and normalization.
    Test images are only converted to tensors and normalized. Both splits use
    the same per-channel normalization.

    Args:
        root: Directory the dataset is read from / downloaded into.
            Defaults to the previously hard-coded ``'./data/'``.
        download: If True, download the dataset into ``root`` when it is not
            already present there.

    Returns:
        ``(trainset, testset, transform_test)``: the two datasets plus the
        test-time transform (presumably returned so callers can preprocess
        other images consistently; verify against callers).
    """
    # Per-channel (R, G, B) mean and std for normalization, shared by both splits.
    # NOTE(review): this std tuple is widely copied from CIFAR-10 examples but is
    # reportedly not the dataset's per-pixel std (~(0.2470, 0.2435, 0.2616)).
    # Kept unchanged so existing trained models stay compatible; confirm before changing.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    normalize = transforms.Normalize(mean, std)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # datasets
    trainset = tv.datasets.CIFAR10(
        root=root,
        train=True,
        download=download,
        transform=transform_train)

    testset = tv.datasets.CIFAR10(
        root=root,
        train=False,
        download=download,
        transform=transform_test)

    return trainset, testset, transform_test