Commit 5d6523f6 authored by Carl De Sousa Trias's avatar Carl De Sousa Trias
Browse files

initial push

parent 46c44bb3
from utils import *
from NNW import Uchi_tools, Adi_tools
from Attacks import *
def Modifications(Modification_ID,net,parameters):
'''
Apply a modification based on the ID and parameters
:param Modification_ID: ID of the modification
:param net: network to be altered
:param parameters: parameters of the modification
:return: altered NN
'''
if Modification_ID==0:
if parameters["name"]=="all":
return adding_noise_global(net,parameters["S"])
for module in parameters["name"]:
net=adding_noise(net,parameters["S"],module)
return net
elif Modification_ID==1:
return prune_model_l1_unstructured(net, parameters["P"])
elif Modification_ID==2:
return prune_model_random_unstructured(net,parameters["R"])
elif Modification_ID==3:
return quantization(net,parameters["B"])
elif Modification_ID==4:
return finetuning(net,parameters["E"])
elif Modification_ID==5:
return knowledge_distillation(net,parameters["E"],parameters["trainloader"],parameters["student"])
elif Modification_ID==6:
return overwriting(net, parameters["NNWmethods"], parameters["W"], parameters["watermarking_dict"])
else:
print("NotImplemented")
return net
if __name__ == '__main__':
###### Reproductibility
torch.manual_seed(0)
np.random.seed(0)
model = tv.models.vgg16()
model.classifier = nn.Linear(25088, 10)
model.to(device)
# watermarking section (change here to test another method) #######################################
tools = Uchi_tools()
reload = 'Resources/vgg16_Uchi'
watermarking_dict = np.load(reload+'_watermarking_dict.npy', allow_pickle=True).item()
# watermarking section (END change here to test another method) ###################################
name = '_quantization'
time_detect=[]
# take model
checkpoint = torch.load(reload + "_weights", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["model_state_dict"])
M_ID = 0
param = {"name": "all","S":.1}
# M_ID = 1
# param = {"P": .5}
# M_ID = 2
# param = {"R": .2}
# M_ID = 3
# param = {"B": 5}
# M_ID = 4
# trainset, testset, _ = CIFAR10_dataset()
# trainloader, testloader = dataloader(trainset, testset, 100)
# param = {"E": 5, "trainloader": trainloader}
# M_ID = 5
# trainset, testset, inference_transform = CIFAR10_dataset()
# trainloader, testloader = dataloader(trainset, testset, 128)
# student = tv.models.vgg16()
# student.classifier = nn.Linear(25088, 10)
# param = {"E":5,"trainloader":trainloader,"student":student}
# M_ID = 6
# param = {"NNWmethods":tools,"W":2,"watermarking_dict":watermarking_dict}
model = Modifications(M_ID,model,param)
model.eval()
# watermark,retrieve_res=tools.Decoder(model, watermarking_dict)
# print('Modification %s - %s - Percentage of erred bits : %2f ' % (str(M_ID),str(param), retrieve_res))
retrieve, decision=tools.Detector(model, watermarking_dict)
print('Modification %s - %s - Presence of the watermark : %s' % (str(M_ID),str(param), decision))
# val_score= fulltest(new_model, testloader)
# print('Validation error : %.2f' % val_score)
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
blas=1.0=mkl
brotli=1.0.9=h5eee18b_7
brotli-bin=1.0.9=h5eee18b_7
brotlipy=0.7.0=py39h27cfd23_1003
bzip2=1.0.8=h7b6447c_0
ca-certificates=2023.08.22=h06a4308_0
certifi=2023.7.22=py39h06a4308_0
cffi=1.15.1=py39h5eee18b_3
charset-normalizer=2.0.4=pyhd3eb1b0_0
contourpy=1.0.5=py39hdb19cb5_0
cryptography=41.0.2=py39h22a60cf_0
cuda-cudart=11.8.89=0
cuda-cupti=11.8.87=0
cuda-libraries=11.8.0=0
cuda-nvrtc=11.8.89=0
cuda-nvtx=11.8.86=0
cuda-runtime=11.8.0=0
cycler=0.11.0=pyhd3eb1b0_0
cyrus-sasl=2.1.28=h52b45da_1
dbus=1.13.18=hb2f20db_0
expat=2.4.9=h6a678d5_0
ffmpeg=4.3=hf484d3e_0
filelock=3.9.0=py39h06a4308_0
fontconfig=2.14.1=h4c34cd2_2
fonttools=4.25.0=pyhd3eb1b0_0
freetype=2.12.1=h4a9f257_0
giflib=5.2.1=h5eee18b_3
glib=2.69.1=he621ea3_2
gmp=6.2.1=h295c915_3
gmpy2=2.1.2=py39heeb90bb_0
gnutls=3.6.15=he1e5248_0
gst-plugins-base=1.14.1=h6a678d5_1
gstreamer=1.14.1=h5eee18b_1
icu=58.2=he6710b0_3
idna=3.4=py39h06a4308_0
importlib_resources=5.2.0=pyhd3eb1b0_1
intel-openmp=2023.1.0=hdb19cb5_46305
jinja2=3.1.2=py39h06a4308_0
jpeg=9e=h5eee18b_1
kiwisolver=1.4.4=py39h6a678d5_0
krb5=1.20.1=h143b758_1
lame=3.100=h7b6447c_0
lcms2=2.12=h3be6417_0
ld_impl_linux-64=2.38=h1181459_1
lerc=3.0=h295c915_0
libbrotlicommon=1.0.9=h5eee18b_7
libbrotlidec=1.0.9=h5eee18b_7
libbrotlienc=1.0.9=h5eee18b_7
libclang=14.0.6=default_hc6dbbc7_1
libclang13=14.0.6=default_he11475f_1
libcublas=11.11.3.6=0
libcufft=10.9.0.58=0
libcufile=1.7.1.12=0
libcups=2.4.2=h2d74bed_1
libcurand=10.3.3.129=0
libcusolver=11.4.1.48=0
libcusparse=11.7.5.86=0
libdeflate=1.17=h5eee18b_0
libedit=3.1.20221030=h5eee18b_0
libevent=2.1.12=hdbd6064_1
libffi=3.4.4=h6a678d5_0
libgcc-ng=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libiconv=1.16=h7f8727e_2
libidn2=2.3.4=h5eee18b_0
libllvm14=14.0.6=hdb19cb5_3
libnpp=11.8.0.86=0
libnvjpeg=11.9.0.86=0
libpng=1.6.39=h5eee18b_0
libpq=12.15=hdbd6064_1
libstdcxx-ng=11.2.0=h1234567_1
libtasn1=4.19.0=h5eee18b_0
libtiff=4.5.0=h6a678d5_2
libunistring=0.9.10=h27cfd23_0
libuuid=1.41.5=h5eee18b_0
libwebp=1.2.4=h11a3e52_1
libwebp-base=1.2.4=h5eee18b_1
libxcb=1.15=h7f8727e_0
libxkbcommon=1.0.1=h5eee18b_1
libxml2=2.10.4=hcbfbd50_0
libxslt=1.1.37=h2085143_0
lz4-c=1.9.4=h6a678d5_0
markupsafe=2.1.1=py39h7f8727e_0
matplotlib=3.7.1=py39h06a4308_1
matplotlib-base=3.7.1=py39h417a72b_1
mkl=2023.1.0=h213fc3f_46343
mkl-service=2.4.0=py39h5eee18b_1
mkl_fft=1.3.6=py39h417a72b_1
mkl_random=1.2.2=py39h417a72b_1
mpc=1.1.0=h10f8cd9_1
mpfr=4.0.2=hb69a4c5_1
mpmath=1.3.0=py39h06a4308_0
munkres=1.1.4=py_0
mysql=5.7.24=h721c034_2
ncurses=6.4=h6a678d5_0
nettle=3.7.3=hbbd107a_1
networkx=3.1=py39h06a4308_0
nspr=4.35=h6a678d5_0
nss=3.89.1=h6a678d5_0
numpy=1.25.2=py39h5f9d8c6_0
numpy-base=1.25.2=py39hb5e798b_0
openh264=2.1.1=h4ff587b_0
openssl=3.0.10=h7f8727e_2
packaging=23.0=py39h06a4308_0
pcre=8.45=h295c915_0
pillow=9.4.0=py39h6a678d5_0
pip=23.2.1=py39h06a4308_0
ply=3.11=py39h06a4308_0
psutil=5.9.0=py39h5eee18b_0
pycparser=2.21=pyhd3eb1b0_0
pyopenssl=23.2.0=py39h06a4308_0
pyparsing=3.0.9=py39h06a4308_0
pyqt=5.15.7=py39h6a678d5_1
pyqt5-sip=12.11.0=py39h6a678d5_1
pysocks=1.7.1=py39h06a4308_0
python=3.9.17=h955ad1f_0
python-dateutil=2.8.2=pyhd3eb1b0_0
pytorch=2.0.1=py3.9_cuda11.8_cudnn8.7.0_0
pytorch-cuda=11.8=h7e8668a_5
pytorch-mutex=1.0=cuda
qt-main=5.15.2=h7358343_9
qt-webengine=5.15.9=h9ab4d14_7
qtwebkit=5.212=h3fafdc1_5
readline=8.2=h5eee18b_0
requests=2.31.0=py39h06a4308_0
setuptools=68.0.0=py39h06a4308_0
sip=6.6.2=py39h6a678d5_0
six=1.16.0=pyhd3eb1b0_1
sqlite=3.41.2=h5eee18b_0
sympy=1.11.1=py39h06a4308_0
tbb=2021.8.0=hdb19cb5_0
tk=8.6.12=h1ccaba5_0
tmux=3.3a=h5eee18b_1
toml=0.10.2=pyhd3eb1b0_0
torchaudio=2.0.2=py39_cu118
torchtriton=2.0.0=py39
torchvision=0.15.2=py39_cu118
tornado=6.3.2=py39h5eee18b_0
tqdm=4.65.0=py39hb070fc8_0
typing_extensions=4.7.1=py39h06a4308_0
tzdata=2023c=h04d1e81_0
urllib3=1.26.16=py39h06a4308_0
wget=3.2=pypi_0
wheel=0.38.4=py39h06a4308_0
xz=5.4.2=h5eee18b_0
zipp=3.11.0=py39h06a4308_0
zlib=1.2.13=h5eee18b_0
zstd=1.5.5=hc292b87_0
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import torchvision as tv
import torchvision.transforms as transforms
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def dataloader(trainset,testset,batch_size=100):
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=batch_size,
shuffle=True,
num_workers=2)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=batch_size,
shuffle=False,
num_workers=2)
return trainloader,testloader
def CIFAR10_dataset():
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# datasets
trainset = tv.datasets.CIFAR10(
root='./data/',
train=True,
download=True,
transform=transform_train)
testset = tv.datasets.CIFAR10(
'./data/',
train=False,
download=True,
transform=transform_test)
return trainset, testset, transform_test
\ No newline at end of file
from utils import *
import os
from PIL import Image
class Adi_tools():
def __init__(self)-> None:
super(Adi_tools, self).__init__()
def Embedder_one_step(self, net, trainloader, optimizer, criterion, watermarking_dict):
'''
:param watermarking_dict: dictionary with all watermarking elements
:return: the different losses ( global loss, task loss, watermark loss)
'''
running_loss = 0
for i, data in enumerate(watermarking_dict["trainloader"], 0):
# split data into the image and its label
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
if inputs.size()[1] == 1:
inputs.squeeze_(1)
inputs = torch.stack([inputs, inputs, inputs], 1)
# initialise the optimiser
optimizer.zero_grad()
# forward
outputs = net(inputs)
# backward
loss = criterion(outputs, labels)
loss.backward()
# update the optimizer
optimizer.step()
# loss
running_loss += loss.item()
return running_loss, running_loss, 0
def Detector(self, net, watermarking_dict):
"""
:param file_watermark: file that contain our saved watermark elements
:return: the extracted watermark, the hamming distance compared to the original watermark
"""
# watermarking_dict = np.load(file_watermark, allow_pickle='TRUE').item() #retrieve the dictionary
keys = watermarking_dict['watermark']
res = 0
for img_file, label in keys.items():
img = self.get_image(watermarking_dict['folder'] + img_file)
net_guess = self.inference(net, img, watermarking_dict['transform'])
if net_guess == label:
res += 1
return '%i/%i' %(res,len(keys)), len(keys)-res<.1*len(keys)
def init(self, net, watermarking_dict, save=None):
'''
:param net: network
:param watermarking_dict: dictionary with all watermarking elements
:param save: file's name to save the watermark
:return: watermark_dict with a new entry: the secret key matrix X
'''
folder=watermarking_dict["folder"]
list_i = self.list_image(folder)
keys = {}
for i in range(len(list_i)):
keys[list_i[i]] = i % watermarking_dict["num_class"]
for img_file, label in keys.items():
img = self.get_image(folder + img_file)
for k in range(watermarking_dict["power"]):
self.add_images(watermarking_dict["dataset"], img, label)
trainloader = torch.utils.data.DataLoader(watermarking_dict["dataset"], batch_size=watermarking_dict["batch_size"],shuffle=True,num_workers=2)
watermarking_dict["trainloader"]=trainloader
watermarking_dict["watermark"]=keys
return watermarking_dict
def list_image(self, main_dir):
"""return all file in the directory"""
res = []
for f in os.listdir(main_dir):
if not f.startswith('.'):
res.append(f)
return res
def add_images(self, dataset, image, label):
"""add an image with its label to the dataset
:param dataset: aimed dataset to be modified
:param image: image to be added
:param label: label of this image
:return: 0
"""
(taille, height, width, channel) = np.shape(dataset.data)
dataset.data = np.append(dataset.data, image)
dataset.targets.append(label)
dataset.data = np.reshape(dataset.data, (taille + 1, height, width, channel))
return 0
def get_image(self, name):
"""
:param name: file (including the path) of an image
:return: a numpy of this image"""
image = Image.open(name)
return np.array(image)
def inference(self, net, img, transform):
"""make the inference for one image and a given transform"""
img_tensor = transform(img).unsqueeze(0)
net.eval()
with torch.no_grad():
logits = net.forward(img_tensor.to(device))
_, predicted = torch.max(logits, 1) # take the maximum value of the last layer
return predicted
'''
tools = Adi_tools()
folder = 'adi/'
power = 10
watermarking_dict = {'folder': folder, 'power': power, 'dataset': trainset, 'num_class': 10,
'batch_size':batch_size,'transform': inference_transform, "types":1}
'''
\ No newline at end of file
import os
import json
from zipfile import ZipFile
import psutil
global error_t
from multiprocessing import Process
import config
error_t=True
def MPAI_AIFS_GetAndParseArchive(filename):
'''filename is a zipfld with at least a ".json"
:return the data structure'''
i = filename.find(".")
os.makedirs(filename[:i],exist_ok=True)
with ZipFile(filename, 'r') as zObject:
# Extracting all the members of the zip
# into a specific location.
zObject.extractall(
path=filename[:i])
for files in os.listdir(filename[:i]):
if '.json' in files:
json_file = open(filename[:i]+"/"+files)
return json.load(json_file)
return error_t
def MPAI_AIFU_Controller_Initialize():
'''initialize the controller and switch it on'''
return
def MPAI_AIFU_Controller_Destroy():
'''switch of controller'''
return
def MPAI_AIFM_AIM_Start(name):
'''start AIW named name (after parse),
:return AIW_ID (int)'''
p1 = Process(target=config.AIMs[name].run())
p1.start() ### run it somewhere
config.dict_process[name.lower()] = p1
return
def MPAI_AIFM_AIM_Pause(name):
'''Pause AIW named name with AIW_ID'''
if name in config.dict_process:
temp_p = psutil.Process(config.dict_process[name].pid)
temp_p.suspend()
print(name, "paused")
else:
print(name, "isn't running")
return error_t
def MPAI_AIFM_AIM_Resume(name):
'''Resume AIW named name with AIW_ID'''
if name.lower() in config.dict_process:
temp_p = psutil.Process(config.dict_process[name].pid)
temp_p.resume()
print(name, "resumed")
else:
print(name, "isn't running")
return error_t
def MPAI_AIFM_AIM_Stop(name):
'''Stop AIW named name with AIW_ID'''
if name.lower() in config.dict_process:
config.dict_process[name].terminate()
print(name, "stopped")
else:
print(name, "isn't running")
return
def MPAI_AIFM_AIM_GetStatus(name):
'''current state of the AIM named name in AIW_ID
:return status(int) [MPAI_AIM_ALIVE, MPAI_AIM_DEAD]'''
if name.lower() in config.dict_process:
print("status of %s: %s" % (name, str(config.dict_process[name].is_alive())))
else:
print(name, "was never initiated")
return error_t
def MPAI_AIFM_Port_Input_Write(AIM_name,port_name, message):
setattr(config.AIMs[AIM_name],port_name,message)
return error_t
def MPAI_AIFM_Port_Output_Read(AIM_name,port_name):
result=getattr(config.AIMs[AIM_name],port_name)
return result
def MPAI_AIFM_Port_Reset(AIM_name,port_name):
setattr(config.AIMs[AIM_name],port_name,None)
return error_t
\ No newline at end of file
from .gaussian import *
from .fine_tuning import *
from .pruning import *
from .quantization import *
from .knowledge_distillation import *
from .watermark_overwriting import *
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment