Commit bca5eca1 authored by Carldst's avatar Carldst
Browse files

initial push

parents
Pipeline #34 failed with stages
in 0 seconds
{
"$schema": "",
"$id": "",
"title": "NNW NNW v1 AIW/AIM",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-imperceptibility",
"AIM": "NNW-imperceptibility",
"Version": "1"
}
},
"APIProfile": "basic",
"Description": "This AIF is used to call the AIW of NNW imperceptibility evaluation ",
"Types": [
{
"Name": "dataset",
"Type": "dataset"
},
{
"Name": "parameters",
"Type": "tensor[]"
},
{
"Name": "inference",
"Type": "output"
},
{
"Name": "bitstring",
"Type": "uint8[]"
}
],
"Ports": [
{
"Name": "training dataset",
"Direction": "Input",
"RecordType": "dataset"
},
{
"Name": "watermarked parameter",
"Direction": "InputOutput",
"RecordType": "parameters"
},
{
"Name": "watermarked inference",
"Direction": "InputOutput",
"RecordType": "inference"
},
{
"Name": "testing dataset",
"Direction": "Input",
"RecordType": "dataset"
},
{
"Name": "unwatermarked parameter",
"Direction": "InputOutput",
"RecordType": "parameters"
},
{
"Name": "unwatermarked inference",
"Direction": "InputOutput",
"RecordType": "inference"
},
{
"Name": "payload",
"Direction": "Input",
"RecordType": " bitstring "
}
],
"SubAIMs": [
{
"Name": "AIM",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-imperceptibility",
"AIM": "AIM",
"Version": "1"
}
}
},
{
"Name": "WatermarkEmbedder",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-imperceptibility",
"AIM": "WatermarkEmbedder",
"Version": "1"
}
}
},
{
"Name": "AIMtrainer",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-imperceptibility",
"AIM": "AIMtrainer",
"Version": "1"
}
}
},
{
"Name": "Comparator",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-imperceptibility",
"AIM": "Comparator",
"Version": "1"
}
}
}
],
"Topology": [
{
"Output": {
"AIMName": "",
"PortName": "Training dataset_1"
},
"Input": {
"AIMName": "AIMtrainer",
"PortName": " Training dataset_1"
}
},
{
"Output": {
"AIMName": "AIMtrainer",
"PortName": "unwatermarked parameter"
},
"Input": {
"AIMName": "AIM",
"PortName": "unwatermarked parameter"
}
},
{
"Output": {
"AIMName": "",
"PortName": "Testing dataset_1"
},
"Input": {
"AIMName": "AIM",
"PortName": "Testing dataset_1"
}
},
{
"Output": {
"AIMName": "AIM",
"PortName": "unwatermarked inference"
},
"Input": {
"AIMName": "Measure",
"PortName": "unwatermarked inference"
}
},
{
"Output": {
"AIMName": "",
"PortName": "Training dataset_2"
},
"Input": {
"AIMName": "WatermarkEmbedder",
"PortName": " Training dataset_2"
}
},
{
"Output": {
"AIMName": "",
"PortName": "payload"
},
"Input": {
"AIMName": "WatermarkEmbedder",
"PortName": "payload"
}
},
{
"Output": {
"AIMName": "WatermarkEmbedder",
"PortName": "watermarked parameter"
},
"Input": {
"AIMName": "AIM",
"PortName": "watermarked parameter"
}
},
{
"Output": {
"AIMName": "",
"PortName": "Testing dataset_2"
},
"Input": {
"AIMName": "AIM",
"PortName": "Testing dataset_2"
}
},
{
"Output": {
"AIMName": "AIM",
"PortName": "watermarked inference"
},
"Input": {
"AIMName": "Measure",
"PortName": "watermarked inference"
}
}
]
}
import UCHIDA
import time
from utils import *
from Attacks import *
#####CHangement de niveau AIMs doit etre quelque chose qui possède des ports
class ModificationModule():
Modification_ID = 1
parameters = {"P": .5}
checkpoint = torch.load("./vgg16_Uchi_weights", map_location=torch.device('cpu'))
##
output_0 = None
def funcModificationModule(self, Modification_ID, parameters, checkpoint):
'''
Apply a modification based on the ID and parameters
:param Modification_ID: ID of the modification
:param net: network to be altered
:param parameters: parameters of the modification
:return: altered NN
'''
net = tv.models.vgg16()
net.classifier = nn.Linear(25088, 10)
net.load_state_dict(checkpoint["model_state_dict"])
if Modification_ID==0:
if parameters["name"]=="all":
net= adding_noise_global(net,parameters["S"])
for module in parameters["name"]:
net=adding_noise(net,parameters["S"],module)
elif Modification_ID==1:
net= prune_model_l1_unstructured(net, parameters["P"])
elif Modification_ID==2:
net= prune_model_random_unstructured(net,parameters["R"])
elif Modification_ID==3:
net= quantization(net,parameters["B"])
elif Modification_ID==4:
net= finetuning(net,parameters["E"],parameters["trainloader"])
elif Modification_ID==5:
net= knowledge_distillation(net,parameters["E"],parameters["trainloader"],parameters["student"])
elif Modification_ID==6:
net=overwriting(net, parameters["NNWmethods"], parameters["W"], parameters["watermarking_dict"])
else:
print("NotImplemented")
torch.save({
'model_state_dict': net.state_dict(),
}, 'altered_weights.pt')
altered_parameters = torch.load('altered_weights.pt', map_location=torch.device('cpu'))
return altered_parameters
def run(self):
self.output_0=self.funcModificationModule(self.Modification_ID, self.parameters, self.checkpoint)
class WatermarkDecoder():
altered_watermarked_parameters=None
watermarking_dict=None
tools=UCHIDA.Uchi_tools
##
output_0 = None
def funcWatermarkDecoder(self, altered_watermarked_parameters, tools, watermarking_dict):
model = tv.models.vgg16()
model.classifier = nn.Linear(25088, 10)
model.to(device)
model.load_state_dict(altered_watermarked_parameters["model_state_dict"])
return tools.Decoder(model, watermarking_dict)
def run(self):
self.output_0=self.funcWatermarkDecoder(self.altered_watermarked_parameters, self.tools, self.watermarking_dict)
class Comparator():
payload1=[]
payload2=[]
##
output_0=None
def funcComparator(self,payload1,payload2):
if not type(payload1)==type(payload2):
return "missmatched payloads"
if not len(payload1)==len(payload2):
return "Not the same length"
if torch.is_tensor(payload1):
return (torch.count_nonzero(payload1-payload2)/len(payload1))
def run(self):
self.output_0=self.funcComparator(self.payload1, self.payload2)
# class APItest():
# time_set=2
# ##
# output_0=None
#
# def func1(self,time_set):
# print("func1: starting")
# for i in range(10):
# time.sleep(2)
# pass
# print("func1: finishing")
#
# def run(self):
# self.output_0=self.func1(self.time_set)
\ No newline at end of file
{
"$schema": "",
"$id": "",
"title": "NNW NNW v1 AIW/AIM",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-robustness",
"AIM": "NNW-robustness",
"Version": "1"
}
},
"APIProfile": "robustness",
"Description": "This AIF is used to call the AIW of NNW robustness evaluation when the payload is retrieved in the parameters",
"Types": [
{
"Name":"parameters",
"Type":"tensor[]"
},
{
"Name":"bitstring",
"Type":"uint8[]"
}
],
"Ports": [
{
"Name":"WatermarkedParameter",
"Direction":"Input",
"RecordType":"parameters"
},
{
"Name":"AlteredWatermarkedParameter",
"Direction":"InputOutput",
"RecordType":"parameters"
},
{
"Name":"RetrievedPayload",
"Direction":"InputOutput",
"RecordType":"bitstring"
},
{
"Name":"Payload",
"Direction":"Input",
"RecordType":" bitstring",
"Technology":"Neural Network",
"Protocol":"",
"IsRemote": false
}
],
"SubAIMs": [
{
"Name": "ModificationModule",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-robustness",
"AIM": "ModificationModule",
"Version": "1"
}
}
},
{
"Name": "WatermarkDecoder",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-robustness",
"AIM": "ModificationModule",
"Version": "1"
}
}
},
{
"Name": "Comparator",
"Identifier": {
"ImplementerID": "/* String assigned by IIDRA */",
"Specification": {
"Standard": "MPAI-NNW",
"AIW": "NNW-robustness",
"AIM": "Comparator",
"Version": "1"
}
}
}
],
"Topology": [
{
"Output":{
"AIMName":"",
"PortName":"watermarked parameters"
},
"Input":{
"AIMName":"AttackModule",
"PortName":"watermarked parameters"
}
},
{
"Output":{
"AIMName":"AttackModule",
"PortName":"AlteredWatermarkedParameter"
},
"Input":{
"AIMName":"WatermarkDecoder",
"PortName":"AlteredWatermarkedParameter"
}
},
{
"Output":{
"AIMName":"WatermarkDecoder",
"PortName":"RetrievedPayload"
},
"Input":{
"AIMName":"Comparator",
"PortName":"RetrievedPayload"
}
},
{
"Output":{
"AIMName":"",
"PortName":"Payload"
},
"Input":{
"AIMName":"Measure",
"PortName":"Payload"
}
}
]
}
message=[]
dict_process={}
AIMs={}
AIM_dict={}
# controller.py https://www.bogotobogo.com/python/python_network_programming_server_client.php
import socket
import time
from Attacks import *
from UCHIDA import Uchi_tools
from ADI import Adi_tools
from multiprocessing import Process
from APIs import *
import psutil
import os
import ast
import tkinter as tk
from tkinter import filedialog
import wget
import config
# create a socket object
serversocket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM)
# get local machine name
host = socket.gethostname()
port = 12345
# bind to the port
serversocket.bind((host, port))
# queue up to 5 requests
serversocket.listen(5)
print("Controller Initialized")
CompCostFlag=False
while True:
# establish a connection
clientsocket, addr = serversocket.accept()
# print("Got a connection from %s" % str(addr))
# currentTime = time.ctime(time.time()) + "\r\n"
data=clientsocket.recv(1024)
message=data.decode()
message = message.split()
if not data: break
if "help" in message[0].lower():
### to be updated
print(" this program is the implementation of NNW in the AIF")
print(" you can run AIM/AIW by sending 'run XX' ")
print(" you can pause AIM/AIW by sending 'stop XX' ")
print(" you can resume AIM/AIW by sending 'resume XX' ")
print(" you can obtain the status of AIM/AIW by sending 'status XX' ")
print(" you can end the program by typing 'exit'")
print(" ----------------------------------------")
if "wget" in message[0].lower():
test=wget.download(message[1])
print(type(test))
elif "getparse" in message[0].lower():
#print(type(message[1]),message[1]) #always str
root = tk.Tk()
root.withdraw()
filename = filedialog.askopenfilename(title='Select the parameter file', filetypes=(("Text files",
"*.zip"),
("all files",
"*.*")))
json_dict = MPAI_AIFS_GetAndParseArchive(filename)
time.sleep(.5)
import AIW.AIMs_files as AIMs_file
config.AIM_dict = json_dict['SubAIMs']
for i in range(len(json_dict['SubAIMs'])):
config.AIMs[config.AIM_dict[i]["Name"]] = getattr(AIMs_file, config.AIM_dict[i]["Name"])()
### AIMs file should be in the .zip
# print(AIMs.keys())
elif 'write' in message[0].lower():
## message[1] AIM_name, message[2] port_name, message[3] what to write
MPAI_AIFM_Port_Input_Write(message[1],message[2],message[3])
elif "read" in message[0].lower():
## message[1] AIM_name, message[2] port_name
result=MPAI_AIFM_Port_Output_Read(message[1],message[2])
print(message[2], "of", message[1], ":", result, type(result))
elif "reset" in message[0].lower():
## message[1] AIM_name, message[2] port_name
MPAI_AIFM_Port_Reset(message[1],message[2])
elif 'ComputationalCost' in message[0].lower():
CompCostFlag=ast.literal_eval(str(message[1]))
elif 'run' in message[0].lower():
### here you can want for a next message instead
# print("waiting for the type of run")
# clientsocket, addr = serversocket.accept()
# data = clientsocket.recv(1024)
# message = data.decode()
# message = message.split()
if "robustness" in message[1].lower():
# run robustness 1 {"P":.1}
# try message[2]:
param = ast.literal_eval(str(message[3]))
M_ID = int(message[2])
root = tk.Tk()
root.withdraw()
reload = filedialog.askopenfilename(title='Select the parameter file') # show an "Open" dialog box and return the path to the selected file
watermarked_parameters = torch.load(reload, map_location=torch.device('cpu'))
MPAI_AIFM_Port_Input_Write("ModificationModule", "param",param)
MPAI_AIFM_Port_Input_Write("ModificationModule", "checkpoint", watermarked_parameters)
MPAI_AIFM_AIM_Start("ModificationModule")
print("watermarked AIM altered")
MPAI_AIFM_Port_Input_Write("WatermarkDecoder","altered_watermarked_parameters",
MPAI_AIFM_Port_Output_Read("ModificationModule","output_0"))
tools = Uchi_tools()
root = tk.Tk()
root.withdraw()
reload_npy = filedialog.askopenfilename(title='Select the watermarking_dict')
watermarking_dict = np.load(reload_npy, allow_pickle=True).item()
## to be placed
MPAI_AIFM_Port_Input_Write("WatermarkDecoder", "tools", tools)
MPAI_AIFM_Port_Input_Write("WatermarkDecoder", "watermarking_dict", watermarking_dict)
if CompCostFlag:
time1=time.time()
MPAI_AIFM_AIM_Start("WatermarkDecoder")
if CompCostFlag:
time2=time.time()
MPAI_AIFM_Port_Input_Write("Comparator","payload1",watermarking_dict["watermark"])
MPAI_AIFM_Port_Input_Write("Comparator", "payload2",
MPAI_AIFM_Port_Output_Read("WatermarkDecoder","output_0"))
MPAI_AIFM_AIM_Start("Comparator")
print('BER : %s' % (MPAI_AIFM_Port_Output_Read("Comparator","output_0")))
print('Retrieved watermark : %s' % (MPAI_AIFM_Port_Output_Read("WatermarkDecoder","output_0")))
if CompCostFlag:
print("time of execution: %.5f sec" %(time2-time1))
elif "imperceptibility" in message[1].lower():
## run imperceptibility vgg16 cifar10
if "vgg16" in message[2].lower():
model = tv.models.vgg16()
model.classifier = nn.Linear(25088, 10)
else:
print(message[2],"not found - default loading vgg16")
model = tv.models.vgg16()
model.classifier = nn.Linear(25088, 10)
if "cifar10" in message[3].lower():
trainset,testset,tfm=CIFAR10_dataset()
else:
print(message[3],"not found - default loading CIFAR10")
trainset,testset,tfm=CIFAR10_dataset()
MPAI_AIFM_Port_Input_Write("WatermarkEmbedder", "AIM", model)
if CompCostFlag:
time1=time.time()
MPAI_AIFM_AIM_Start("WatermarkEmbedder")
if CompCostFlag:
time2=time.time()
MPAI_AIFM_Port_Input_Write("AIM", "model", model)
MPAI_AIFM_Port_Input_Write("AIM", "parameters", MPAI_AIFM_Port_Output_Read("WatermarkEmbedder","output_0"))
MPAI_AIFM_Port_Input_Write("AIM", "testingDataset", testset)
MPAI_AIFM_AIM_Start("AIM")
print('AIM_watermarked_result : %s' % (MPAI_AIFM_Port_Output_Read("AIM","output_0")))
if CompCostFlag:
print("time of execution: %.5f sec" %(time2-time1))
MPAI_AIFM_Port_Input_Write("AIMtrainer", "AIM", model)
MPAI_AIFM_AIM_Start("AIMtrainer")
MPAI_AIFM_Port_Input_Write("AIM", "model", model)
MPAI_AIFM_Port_Input_Write("AIM", "parameters", MPAI_AIFM_Port_Output_Read("AIMtrainer","output_0"))
MPAI_AIFM_Port_Input_Write("AIM", "testingDataset", testset)
MPAI_AIFM_AIM_Start("AIM")
print('AIM_unwatermarked_result : %s' % (MPAI_AIFM_Port_Output_Read("AIM", "output_0")))
else:
print(message[1].lower(),"not implemented")
elif 'status' in message[0].lower():
print(config.dict_process)
MPAI_AIFM_AIM_GetStatus(message[1])
elif 'pause' in message[0].lower():
MPAI_AIFM_AIM_Pause(message[1])
elif 'resume' in message[0].lower():
MPAI_AIFM_AIM_Resume(message[1])
elif 'stop' in message[0].lower():
if message[1].lower() in config.dict_process:
config.dict_process[message[1]].terminate()
print( message[1], "stopped")
else:
print(message[1], "isn't running")
elif "exit" in message[0].lower():
print("ending session...")
break
else:
print("input not implemented")
clientsocket.close()
print("session ended")
### TO DO https://docs.python.org/3/library/multiprocessing.html
\ No newline at end of file
# client.py
import socket
import time
# create a socket object
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# get local machine name
host = socket.gethostname()
port = 12345
# connection to hostname on the port.
s.connect((host, port))
message = input('input:')
b_message=str.encode(message)
s.sendall(b_message)
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
blas=1.0=mkl
brotli=1.0.9=h5eee18b_7
brotli-bin=1.0.9=h5eee18b_7
brotlipy=0.7.0=py39h27cfd23_1003
bzip2=1.0.8=h7b6447c_0
ca-certificates=2023.08.22=h06a4308_0
certifi=2023.7.22=py39h06a4308_0
cffi=1.15.1=py39h5eee18b_3
charset-normalizer=2.0.4=pyhd3eb1b0_0
contourpy=1.0.5=py39hdb19cb5_0
cryptography=41.0.2=py39h22a60cf_0
cuda-cudart=11.8.89=0
cuda-cupti=11.8.87=0
cuda-libraries=11.8.0=0
cuda-nvrtc=11.8.89=0
cuda-nvtx=11.8.86=0
cuda-runtime=11.8.0=0
cycler=0.11.0=pyhd3eb1b0_0
cyrus-sasl=2.1.28=h52b45da_1
dbus=1.13.18=hb2f20db_0
expat=2.4.9=h6a678d5_0
ffmpeg=4.3=hf484d3e_0
filelock=3.9.0=py39h06a4308_0
fontconfig=2.14.1=h4c34cd2_2
fonttools=4.25.0=pyhd3eb1b0_0
freetype=2.12.1=h4a9f257_0
giflib=5.2.1=h5eee18b_3
glib=2.69.1=he621ea3_2
gmp=6.2.1=h295c915_3
gmpy2=2.1.2=py39heeb90bb_0
gnutls=3.6.15=he1e5248_0
gst-plugins-base=1.14.1=h6a678d5_1
gstreamer=1.14.1=h5eee18b_1
icu=58.2=he6710b0_3
idna=3.4=py39h06a4308_0
importlib_resources=5.2.0=pyhd3eb1b0_1
intel-openmp=2023.1.0=hdb19cb5_46305
jinja2=3.1.2=py39h06a4308_0
jpeg=9e=h5eee18b_1
kiwisolver=1.4.4=py39h6a678d5_0
krb5=1.20.1=h143b758_1
lame=3.100=h7b6447c_0
lcms2=2.12=h3be6417_0
ld_impl_linux-64=2.38=h1181459_1
lerc=3.0=h295c915_0
libbrotlicommon=1.0.9=h5eee18b_7
libbrotlidec=1.0.9=h5eee18b_7
libbrotlienc=1.0.9=h5eee18b_7
libclang=14.0.6=default_hc6dbbc7_1
libclang13=14.0.6=default_he11475f_1
libcublas=11.11.3.6=0
libcufft=10.9.0.58=0
libcufile=1.7.1.12=0
libcups=2.4.2=h2d74bed_1
libcurand=10.3.3.129=0
libcusolver=11.4.1.48=0
libcusparse=11.7.5.86=0
libdeflate=1.17=h5eee18b_0
libedit=3.1.20221030=h5eee18b_0
libevent=2.1.12=hdbd6064_1
libffi=3.4.4=h6a678d5_0
libgcc-ng=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libiconv=1.16=h7f8727e_2
libidn2=2.3.4=h5eee18b_0
libllvm14=14.0.6=hdb19cb5_3
libnpp=11.8.0.86=0
libnvjpeg=11.9.0.86=0
libpng=1.6.39=h5eee18b_0
libpq=12.15=hdbd6064_1
libstdcxx-ng=11.2.0=h1234567_1
libtasn1=4.19.0=h5eee18b_0
libtiff=4.5.0=h6a678d5_2
libunistring=0.9.10=h27cfd23_0
libuuid=1.41.5=h5eee18b_0
libwebp=1.2.4=h11a3e52_1
libwebp-base=1.2.4=h5eee18b_1
libxcb=1.15=h7f8727e_0
libxkbcommon=1.0.1=h5eee18b_1
libxml2=2.10.4=hcbfbd50_0
libxslt=1.1.37=h2085143_0
lz4-c=1.9.4=h6a678d5_0
markupsafe=2.1.1=py39h7f8727e_0
matplotlib=3.7.1=py39h06a4308_1
matplotlib-base=3.7.1=py39h417a72b_1
mkl=2023.1.0=h213fc3f_46343
mkl-service=2.4.0=py39h5eee18b_1
mkl_fft=1.3.6=py39h417a72b_1
mkl_random=1.2.2=py39h417a72b_1
mpc=1.1.0=h10f8cd9_1
mpfr=4.0.2=hb69a4c5_1
mpmath=1.3.0=py39h06a4308_0
munkres=1.1.4=py_0
mysql=5.7.24=h721c034_2
ncurses=6.4=h6a678d5_0
nettle=3.7.3=hbbd107a_1
networkx=3.1=py39h06a4308_0
nspr=4.35=h6a678d5_0
nss=3.89.1=h6a678d5_0
numpy=1.25.2=py39h5f9d8c6_0
numpy-base=1.25.2=py39hb5e798b_0
openh264=2.1.1=h4ff587b_0
openssl=3.0.10=h7f8727e_2
packaging=23.0=py39h06a4308_0
pcre=8.45=h295c915_0
pillow=9.4.0=py39h6a678d5_0
pip=23.2.1=py39h06a4308_0
ply=3.11=py39h06a4308_0
psutil=5.9.0=py39h5eee18b_0
pycparser=2.21=pyhd3eb1b0_0
pyopenssl=23.2.0=py39h06a4308_0
pyparsing=3.0.9=py39h06a4308_0
pyqt=5.15.7=py39h6a678d5_1
pyqt5-sip=12.11.0=py39h6a678d5_1
pysocks=1.7.1=py39h06a4308_0
python=3.9.17=h955ad1f_0
python-dateutil=2.8.2=pyhd3eb1b0_0
pytorch=2.0.1=py3.9_cuda11.8_cudnn8.7.0_0
pytorch-cuda=11.8=h7e8668a_5
pytorch-mutex=1.0=cuda
qt-main=5.15.2=h7358343_9
qt-webengine=5.15.9=h9ab4d14_7
qtwebkit=5.212=h3fafdc1_5
readline=8.2=h5eee18b_0
requests=2.31.0=py39h06a4308_0
setuptools=68.0.0=py39h06a4308_0
sip=6.6.2=py39h6a678d5_0
six=1.16.0=pyhd3eb1b0_1
sqlite=3.41.2=h5eee18b_0
sympy=1.11.1=py39h06a4308_0
tbb=2021.8.0=hdb19cb5_0
tk=8.6.12=h1ccaba5_0
tmux=3.3a=h5eee18b_1
toml=0.10.2=pyhd3eb1b0_0
torchaudio=2.0.2=py39_cu118
torchtriton=2.0.0=py39
torchvision=0.15.2=py39_cu118
tornado=6.3.2=py39h5eee18b_0
tqdm=4.65.0=py39hb070fc8_0
typing_extensions=4.7.1=py39h06a4308_0
tzdata=2023c=h04d1e81_0
urllib3=1.26.16=py39h06a4308_0
wget=3.2=pypi_0
wheel=0.38.4=py39h06a4308_0
xz=5.4.2=h5eee18b_0
zipp=3.11.0=py39h06a4308_0
zlib=1.2.13=h5eee18b_0
zstd=1.5.5=hc292b87_0
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import torchvision as tv
import torchvision.transforms as transforms
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def dataloader(trainset,testset,batch_size=100):
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=batch_size,
shuffle=True,
num_workers=2)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=batch_size,
shuffle=False,
num_workers=2)
return trainloader,testloader
def CIFAR10_dataset():
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# datasets
trainset = tv.datasets.CIFAR10(
root='./data/',
train=True,
download=True,
transform=transform_train)
testset = tv.datasets.CIFAR10(
'./data/',
train=False,
download=True,
transform=transform_test)
return trainset, testset, transform_test
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment