
import UCHIDA
import time
from utils import *
from Attacks import *

##### Level change: AIMs must be something that has ports
class ModificationModule():
    """Pipeline stage that applies one modification (attack) to the watermarked VGG16.

    Configure ``Modification_ID``, ``parameters`` and ``checkpoint`` (class defaults below,
    or override per instance), call ``run()``, then read the altered checkpoint from
    ``output_0``.
    """
    # Which modification to apply; see funcModificationModule for the ID table.
    Modification_ID = 1
    # Parameters of the selected modification. "P" is passed to prune_model_l1_unstructured
    # (presumably the fraction of weights to prune -- confirm against Attacks).
    parameters = {"P": .5}
    # Loaded once, when the class body executes (i.e. at module import time), onto the CPU.
    checkpoint = torch.load("./vgg16_Uchi_weights", map_location=torch.device('cpu'))

    ##
    output_0 = None

    def funcModificationModule(self, Modification_ID, parameters, checkpoint):
        '''
        Rebuild the VGG16 from ``checkpoint``, apply the selected modification, and return
        the altered weights.

        Modification IDs and the ``parameters`` keys each one reads:
            0: adding_noise_global (name == "all") or adding_noise per module -- "S", "name"
            1: prune_model_l1_unstructured      -- "P"
            2: prune_model_random_unstructured  -- "R"
            3: quantization                     -- "B"
            4: finetuning                       -- "E", "trainloader"
            5: knowledge_distillation           -- "E", "trainloader", "student"
            6: overwriting                      -- "NNWmethods", "W", "watermarking_dict"
        Any other ID prints "NotImplemented" and the network is saved unchanged.

        :param Modification_ID: ID of the modification (see table above)
        :param parameters: dict of parameters for the modification
        :param checkpoint: dict whose "model_state_dict" entry holds the weights of a VGG16
            whose classifier is a single 25088 -> 10 linear layer
        :return: dict ``{"model_state_dict": ...}`` of the altered network, as reloaded from
            'altered_weights.pt' onto the CPU
        '''
        net = tv.models.vgg16()
        # The stored model replaces VGG16's classifier head with a single linear layer.
        net.classifier = nn.Linear(25088, 10)
        net.load_state_dict(checkpoint["model_state_dict"])
        if Modification_ID == 0:
            if parameters["name"] == "all":
                net = adding_noise_global(net, parameters["S"])
            else:
                # Only walk explicit module names; iterating the string "all" would yield
                # its characters as bogus module names.
                for module in parameters["name"]:
                    net = adding_noise(net, parameters["S"], module)
        elif Modification_ID == 1:
            net = prune_model_l1_unstructured(net, parameters["P"])
        elif Modification_ID == 2:
            net = prune_model_random_unstructured(net, parameters["R"])
        elif Modification_ID == 3:
            net = quantization(net, parameters["B"])
        elif Modification_ID == 4:
            net = finetuning(net, parameters["E"], parameters["trainloader"])
        elif Modification_ID == 5:
            net = knowledge_distillation(net, parameters["E"], parameters["trainloader"], parameters["student"])
        elif Modification_ID == 6:
            net = overwriting(net, parameters["NNWmethods"], parameters["W"], parameters["watermarking_dict"])
        else:
            print("NotImplemented")
        # Persist the result, then reload it so the returned tensors are detached from
        # ``net`` and mapped to the CPU.
        torch.save({
            'model_state_dict': net.state_dict(),
        }, 'altered_weights.pt')
        altered_parameters = torch.load('altered_weights.pt', map_location=torch.device('cpu'))
        return altered_parameters

    def run(self):
        """Apply the configured modification and store the altered checkpoint in ``output_0``."""
        self.output_0 = self.funcModificationModule(self.Modification_ID, self.parameters, self.checkpoint)

class WatermarkDecoder():
    """Pipeline stage that reads a watermark back out of a (possibly altered) VGG16 checkpoint.

    Set the input attributes below, call ``run()``, and collect the decoder's result from
    ``output_0``.
    """
    # Checkpoint-style dict whose "model_state_dict" entry holds the weights to inspect.
    altered_watermarked_parameters=None
    # Configuration handed unchanged to ``tools.Decoder``.
    watermarking_dict=None
    # Object exposing ``Decoder(model, watermarking_dict)``; defaults to UCHIDA.Uchi_tools.
    tools=UCHIDA.Uchi_tools
    ##
    output_0 = None

    def funcWatermarkDecoder(self, altered_watermarked_parameters, tools, watermarking_dict):
        """Rebuild the VGG16 (25088 -> 10 linear head), load the weights, and decode.

        :param altered_watermarked_parameters: dict with a "model_state_dict" entry
        :param tools: object providing ``Decoder(model, watermarking_dict)``
        :param watermarking_dict: watermark configuration forwarded to the decoder
        :return: whatever ``tools.Decoder`` returns
        """
        vgg = tv.models.vgg16()
        vgg.classifier = nn.Linear(in_features=25088, out_features=10)
        # Move to the module-level ``device`` before the weights are copied in.
        vgg.to(device)
        weights = altered_watermarked_parameters["model_state_dict"]
        vgg.load_state_dict(weights)
        decode = tools.Decoder
        return decode(vgg, watermarking_dict)

    def run(self):
        """Decode with the configured attributes and store the result in ``output_0``."""
        inputs = (self.altered_watermarked_parameters, self.tools, self.watermarking_dict)
        self.output_0 = self.funcWatermarkDecoder(*inputs)

class Comparator():
    """Pipeline stage that measures how much two decoded payloads disagree.

    Set ``payload1`` and ``payload2``, call ``run()``, and read the result from ``output_0``.
    """
    # NOTE: these list defaults are class attributes shared by all instances; assign new
    # objects per instance instead of mutating them in place.
    payload1=[]
    payload2=[]
    ##
    output_0=None

    def funcComparator(self, payload1, payload2):
        """Return the fraction of positions at which two payloads differ.

        :param payload1: reference payload (torch tensor or any sized sequence)
        :param payload2: payload to compare; must have the same type and length
        :return: ``"missmatched payloads"`` if the types differ; ``"Not the same length"`` if
            the lengths (or, for tensors, the shapes) differ; otherwise the mismatch ratio in
            [0, 1] -- a 0-dim float tensor for tensors, a float for other sequences.
            Empty payloads yield 0.
        """
        if type(payload1) is not type(payload2):
            return "missmatched payloads"
        if len(payload1) != len(payload2):
            return "Not the same length"
        if torch.is_tensor(payload1):
            # Equal first dimensions are not enough: broadcasting would silently compare
            # the wrong elements.
            if payload1.shape != payload2.shape:
                return "Not the same length"
            # ``!=`` works for every dtype (``-`` raises on bool tensors); normalise by the
            # total element count so multi-dimensional payloads give a true ratio.
            mismatches = torch.count_nonzero(payload1 != payload2)
            return mismatches / max(payload1.numel(), 1)
        # Generic sequences (lists/tuples of bits, strings, ...): element-wise comparison.
        mismatches = sum(a != b for a, b in zip(payload1, payload2))
        return mismatches / max(len(payload1), 1)

    def run(self):
        """Compare the configured payloads and store the result in ``output_0``."""
        self.output_0 = self.funcComparator(self.payload1, self.payload2)

# class APItest():
#     time_set=2
#     ##
#     output_0=None
#
#     def func1(self,time_set):
#         print("func1: starting")
#         for i in range(10):
#             time.sleep(2)
#             pass
#             print("func1: finishing")
#
#     def run(self):
#         self.output_0=self.func1(self.time_set)