Commit bca5eca1 authored by Carldst's avatar Carldst
Browse files

initial push

parents
Pipeline #34 failed with stages
in 0 seconds
from utils import *
import os
from PIL import Image
class Adi_tools():
    """Backdoor ("Adi"-style) watermarking toolbox.

    The watermark is a set of trigger images with fixed (arbitrary) labels
    that is mixed into the training set; ownership is later verified by
    checking the network's answers on those trigger images.
    """

    def __init__(self) -> None:
        super(Adi_tools, self).__init__()

    def Embedder_one_step(self, net, trainloader, optimizer, criterion, watermarking_dict):
        '''Run one training epoch over the watermark-augmented loader.

        :param net: network being trained / watermarked
        :param trainloader: unused — kept for interface compatibility; the
            loader is read from watermarking_dict["trainloader"] instead
        :param optimizer: optimizer updating net's parameters
        :param criterion: task loss (e.g. cross-entropy)
        :param watermarking_dict: dictionary with all watermarking elements
        :return: the different losses (global loss, task loss, watermark loss);
            the watermark is embedded through the data itself, so the task
            loss doubles as the global loss and the watermark loss is 0
        '''
        running_loss = 0.0
        for data in watermarking_dict["trainloader"]:
            # split data into the image and its label
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # grayscale batch (1 channel): fake RGB by repeating the channel
            if inputs.size()[1] == 1:
                inputs.squeeze_(1)
                inputs = torch.stack([inputs, inputs, inputs], 1)
            # initialise the optimiser
            optimizer.zero_grad()
            # forward
            outputs = net(inputs)
            # backward
            loss = criterion(outputs, labels)
            loss.backward()
            # update the optimizer
            optimizer.step()
            # accumulate the epoch loss
            running_loss += loss.item()
        return running_loss, running_loss, 0

    def Detector(self, net, watermarking_dict):
        """Query the trigger set and measure how much watermark survives.

        :param net: network under inspection
        :param watermarking_dict: dictionary holding the watermark elements
            ('watermark' maps image file name -> label; also 'folder' and
            'transform')
        :return: the extracted watermark score as a 'hits/total' string, and
            True when fewer than 10% of the trigger images are misclassified
            (i.e. the watermark is considered present)
        """
        keys = watermarking_dict['watermark']
        res = 0
        for img_file, label in keys.items():
            # os.path.join also copes with a 'folder' missing its trailing '/'
            img = self.get_image(os.path.join(watermarking_dict['folder'], img_file))
            net_guess = self.inference(net, img, watermarking_dict['transform'])
            if net_guess == label:
                res += 1
        return '%i/%i' % (res, len(keys)), len(keys) - res < .1 * len(keys)

    def init(self, net, watermarking_dict, save=None):
        '''Build the trigger set and inject it into the training dataset.

        :param net: network (unused here; kept for interface compatibility)
        :param watermarking_dict: dictionary with all watermarking elements
            ('folder', 'num_class', 'power', 'dataset', 'batch_size')
        :param save: file's name to save the watermark (currently unused)
        :return: watermarking_dict with new entries: 'trainloader' (loader
            over the augmented dataset) and 'watermark' (file -> label map)
        '''
        folder = watermarking_dict["folder"]
        # assign trigger labels round-robin over the classes
        keys = {f: i % watermarking_dict["num_class"]
                for i, f in enumerate(self.list_image(folder))}
        for img_file, label in keys.items():
            img = self.get_image(os.path.join(folder, img_file))
            # 'power' copies of each trigger strengthen the embedding
            for _ in range(watermarking_dict["power"]):
                self.add_images(watermarking_dict["dataset"], img, label)
        trainloader = torch.utils.data.DataLoader(
            watermarking_dict["dataset"],
            batch_size=watermarking_dict["batch_size"],
            shuffle=True,
            num_workers=2)
        watermarking_dict["trainloader"] = trainloader
        watermarking_dict["watermark"] = keys
        return watermarking_dict

    def list_image(self, main_dir):
        """Return all non-hidden file names in the directory *main_dir*."""
        return [f for f in os.listdir(main_dir) if not f.startswith('.')]

    def add_images(self, dataset, image, label):
        """Add an image with its label to the dataset.

        :param dataset: aimed dataset to be modified (must expose numpy
            .data of shape (N, H, W, C) and a list .targets)
        :param image: image to be added (H*W*C elements)
        :param label: label of this image
        :return: 0
        """
        (taille, height, width, channel) = np.shape(dataset.data)
        # np.append flattens; the reshape below restores (N+1, H, W, C)
        dataset.data = np.append(dataset.data, image)
        dataset.targets.append(label)
        dataset.data = np.reshape(dataset.data, (taille + 1, height, width, channel))
        return 0

    def get_image(self, name):
        """
        :param name: file (including the path) of an image
        :return: a numpy array of this image
        """
        # context manager closes the underlying file handle deterministically
        with Image.open(name) as image:
            return np.array(image)

    def inference(self, net, img, transform):
        """Make the inference for one image and a given transform.

        :return: the predicted class index (as a tensor)
        """
        img_tensor = transform(img).unsqueeze(0)
        net.eval()
        with torch.no_grad():
            # call the module (not .forward) so registered hooks still run
            logits = net(img_tensor.to(device))
        _, predicted = torch.max(logits, 1)  # argmax over the last layer
        return predicted
'''
tools = Adi_tools()
folder = 'adi/'
power = 10
watermarking_dict = {'folder': folder, 'power': power, 'dataset': trainset, 'num_class': 10,
'batch_size':batch_size,'transform': inference_transform, "types":1}
'''
\ No newline at end of file
import os
import json
from zipfile import ZipFile
import psutil
from multiprocessing import Process
import config

# Module-wide status flag returned by the API functions below.
# NOTE(review): a `global error_t` statement at module scope is a no-op in
# Python and has been removed; the assignment alone defines the module global.
error_t = True
def MPAI_AIFS_GetAndParseArchive(filename):
    '''Extract a zip archive and parse the first ".json" file found inside.

    :param filename: a zip file with at least a ".json" member
    :return: the parsed JSON data structure, or error_t (True) when the
        archive contains no JSON file
    '''
    # extraction directory = filename up to its first dot ("a.zip" -> "a")
    i = filename.find(".")
    target_dir = filename[:i]
    os.makedirs(target_dir, exist_ok=True)
    with ZipFile(filename, 'r') as zObject:
        # Extracting all the members of the zip into a specific location.
        zObject.extractall(path=target_dir)
    for files in os.listdir(target_dir):
        if '.json' in files:
            # context manager closes the handle (the original leaked it)
            with open(os.path.join(target_dir, files)) as json_file:
                return json.load(json_file)
    return error_t
def MPAI_AIFU_Controller_Initialize():
    '''Initialize the controller and switch it on (placeholder: no-op).'''
    return None
def MPAI_AIFU_Controller_Destroy():
    '''Switch off the controller (placeholder: no-op).'''
    return None
def MPAI_AIFM_AIM_Start(name):
    '''Start the AIM named *name* (after parse) in its own process.

    :param name: key into config.AIMs
    :return: None (the docstring's promised AIW_ID is not produced here)
    '''
    # Pass the bound method itself: the original wrote `target=...run()`,
    # which executed run() eagerly in the parent process instead of handing
    # the callable to the child.
    p1 = Process(target=config.AIMs[name].run)
    p1.start()
    # processes are registered under the lower-cased name
    config.dict_process[name.lower()] = p1
    return
def MPAI_AIFM_AIM_Pause(name):
    '''Pause the AIM named *name* by suspending its OS process.

    :param name: AIM name (case-insensitive)
    :return: error_t (True)
    '''
    # processes are registered under the lower-cased name (see AIM_Start);
    # the original checked the raw name and missed mixed-case callers
    key = name.lower()
    if key in config.dict_process:
        temp_p = psutil.Process(config.dict_process[key].pid)
        temp_p.suspend()
        print(name, "paused")
    else:
        print(name, "isn't running")
    return error_t
def MPAI_AIFM_AIM_Resume(name):
    '''Resume the previously paused AIM named *name*.

    :param name: AIM name (case-insensitive)
    :return: error_t (True)
    '''
    # index with the same lower-cased key used for the membership test;
    # the original indexed with the raw name and raised KeyError for
    # mixed-case input
    key = name.lower()
    if key in config.dict_process:
        temp_p = psutil.Process(config.dict_process[key].pid)
        temp_p.resume()
        print(name, "resumed")
    else:
        print(name, "isn't running")
    return error_t
def MPAI_AIFM_AIM_Stop(name):
    '''Stop (terminate) the AIM named *name*.

    :param name: AIM name (case-insensitive)
    :return: None
    '''
    # index with the same lower-cased key used for the membership test;
    # the original indexed with the raw name and raised KeyError for
    # mixed-case input
    key = name.lower()
    if key in config.dict_process:
        config.dict_process[key].terminate()
        print(name, "stopped")
    else:
        print(name, "isn't running")
    return
def MPAI_AIFM_AIM_GetStatus(name):
    '''Print the current state of the AIM named *name*.

    :param name: AIM name (case-insensitive)
    :return: error_t (True) — the status itself is only printed, mapping
        is_alive() onto [MPAI_AIM_ALIVE, MPAI_AIM_DEAD]
    '''
    # index with the same lower-cased key used for the membership test;
    # the original indexed with the raw name and raised KeyError for
    # mixed-case input
    key = name.lower()
    if key in config.dict_process:
        print("status of %s: %s" % (name, str(config.dict_process[key].is_alive())))
    else:
        print(name, "was never initiated")
    return error_t
def MPAI_AIFM_Port_Input_Write(AIM_name, port_name, message):
    '''Write *message* onto the input port *port_name* of the AIM *AIM_name*.'''
    aim = config.AIMs[AIM_name]
    setattr(aim, port_name, message)
    return error_t
def MPAI_AIFM_Port_Output_Read(AIM_name, port_name):
    '''Read and return the value on the output port *port_name* of *AIM_name*.'''
    return getattr(config.AIMs[AIM_name], port_name)
def MPAI_AIFM_Port_Reset(AIM_name, port_name):
    '''Reset the port *port_name* of the AIM *AIM_name* by clearing it to None.'''
    aim = config.AIMs[AIM_name]
    setattr(aim, port_name, None)
    return error_t
\ No newline at end of file
from .gaussian import *
from .fine_tuning import *
from .pruning import *
from .quantization import *
from .knowledge_distillation import *
from .watermark_overwriting import *
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment