# controller.py https://www.bogotobogo.com/python/python_network_programming_server_client.php
import socket
import time
from Attacks import *
from UCHIDA import Uchi_tools
from ADI import Adi_tools
from multiprocessing import Process
from APIs import *
import psutil
import os
import ast
import tkinter as tk
from tkinter import filedialog
import wget
import config


# Listening endpoint of the controller: an IPv4 (AF_INET) stream (SOCK_STREAM) socket.
serversocket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)

# The controller listens on a fixed port under this machine's own host name.
port = 12345
host = socket.gethostname()
serversocket.bind((host, port))

# Accept connections, with a backlog of 5 as the listen() argument.
serversocket.listen(5)
print("Controller Initialized")

# Module-level switch read by the command loop below to decide whether
# execution times are printed; it starts disabled.
CompCostFlag = False
def _has_args(tokens, count):
    """Return True when *tokens* holds a command plus at least *count* arguments.

    Otherwise print a short diagnostic and return False, so the caller can
    skip the command instead of raising IndexError on ``tokens[n]``.
    """
    if len(tokens) > count:
        return True
    print(tokens[0], "needs at least", count, "argument(s)")
    return False


# Command loop.
#
# Each accepted connection carries exactly one command: at most 1024 bytes are
# read, decoded and split on whitespace (so arguments must not contain spaces).
# The first token selects the handler by *substring* match, tested in the order
# of the if/elif chain below; the remaining tokens are the arguments.
# The loop ends on an "exit" command or when a peer connects and sends nothing.
while True:
    # establish a connection
    clientsocket, addr = serversocket.accept()
    # The context manager closes the client socket on every way out of the
    # iteration (normal end, `continue`, `break`, or an exception).
    with clientsocket:
        data = clientsocket.recv(1024)
        if not data:
            # Peer closed the connection without sending anything: end the session.
            break
        message = data.decode().split()
        if not message:
            # Whitespace-only payload: there is no command token to dispatch on.
            print("input not implemented")
            continue
        command = message[0].lower()

        if "help" in command:
            ### to be updated
            print(" ----------------------------------------")
            print(" this program is the implementation of NNW in the AIF")
            print(" you can run AIM/AIW by sending 'run XX' ")
            print(" you can pause AIM/AIW by sending 'pause XX' ")
            print(" you can resume AIM/AIW by sending 'resume XX' ")
            print(" you can stop AIM/AIW by sending 'stop XX' ")
            print(" you can obtain the status of AIM/AIW by sending 'status XX' ")
            print(" you can end the program by typing 'exit'")
            print(" ----------------------------------------")

        elif "wget" in command:
            if not _has_args(message, 1):
                continue
            # NOTE(security): downloads whatever URL the client sent, unvalidated.
            downloaded = wget.download(message[1])
            print(type(downloaded))

        elif "getparse" in command:
            # Ask the operator for the archive through a file dialog; the hidden
            # Tk root only exists to host the dialog and is destroyed right after.
            root = tk.Tk()
            root.withdraw()
            filename = filedialog.askopenfilename(
                title='Select the parameter file',
                filetypes=(("Text files", "*.zip"), ("all files", "*.*")))
            root.destroy()
            if not filename:
                # Dialog was cancelled: nothing to parse.
                print("no archive selected")
                continue
            json_dict = MPAI_AIFS_GetAndParseArchive(filename)
            time.sleep(.5)
            # Imported only after the archive has been parsed
            # (AIMs file should be in the .zip).
            import AIW.AIMs_files as AIMs_file
            config.AIM_dict = json_dict['SubAIMs']
            config.Topology = json_dict['Topology']  ### topology

            # Instantiate, for every sub-AIM entry, the attribute of AIMs_file
            # that carries the entry's "Name".
            for i in range(len(config.AIM_dict)):
                aim_name = config.AIM_dict[i]["Name"]
                config.AIMs[aim_name] = getattr(AIMs_file, aim_name)()
            print(".json parsed")

        elif "write" in command:
            ## message[1] AIM_name, message[2] port_name, message[3] what to write
            if not _has_args(message, 3):
                continue
            MPAI_AIFM_Port_Input_Write(message[1], message[2], message[3])

        elif "read" in command:
            ## message[1] AIM_name, message[2] port_name
            if not _has_args(message, 2):
                continue
            result = MPAI_AIFM_Port_Output_Read(message[1], message[2])
            print(message[2], "of", message[1], ":", result, type(result))

        elif "reset" in command:
            ## message[1] AIM_name, message[2] port_name
            if not _has_args(message, 2):
                continue
            MPAI_AIFM_Port_Reset(message[1], message[2])

        elif "computationalcost" in command:
            # The command token is lowercased, so the keyword must be lowercase
            # too; a mixed-case keyword can never match.
            if not _has_args(message, 1):
                continue
            try:
                # e.g. "ComputationalCost True" / "ComputationalCost False"
                CompCostFlag = ast.literal_eval(message[1])
            except (ValueError, SyntaxError):
                print(message[1], "is not a valid Python literal")

        elif "run" in command:
            # NOTE: the type of run could instead be read from a follow-up message.
            if not _has_args(message, 1):
                continue
            run_type = message[1].lower()

            if "robustness" in run_type:
                # run robustness 1 {"P":.1}
                if not _has_args(message, 3):
                    continue
                # Both values are parsed (and thereby validated) but are not
                # used further in this branch.
                param = ast.literal_eval(message[3])
                M_ID = int(message[2])

                # One hidden Tk root hosts both dialogs and is destroyed afterwards.
                root = tk.Tk()
                root.withdraw()
                # show an "Open" dialog box and return the path to the selected file
                params_path = filedialog.askopenfilename(title='Select the parameter file')
                dict_path = filedialog.askopenfilename(title='Select the watermarking_dict')
                root.destroy()
                if not params_path or not dict_path:
                    # At least one dialog was cancelled.
                    print("file selection cancelled")
                    continue

                # NOTE(security): torch.load and np.load(allow_pickle=True) unpickle
                # the selected files; only open files from a trusted source.
                watermarked_parameters = torch.load(params_path, map_location=torch.device('cpu'))
                config.message["WatermarkedParameter"] = watermarked_parameters
                tools = Uchi_tools()
                watermarking_dict = np.load(dict_path, allow_pickle=True).item()
                MPAI_AIFM_Port_Input_Write("WatermarkDecoder", "tools", tools)
                MPAI_AIFM_Port_Input_Write("WatermarkDecoder", "watermarking_dict", watermarking_dict)
                MPAI_AIFM_Port_Input_Write("Comparator", "Payload", watermarking_dict["watermark"])

                ### Automatized: walk the topology links in order.
                # NOTE(review): only the duration of the last AIM start is
                # reported, as before; None means no AIM was started by the loop.
                elapsed = None
                for link in config.Topology:
                    source = link["Output"]
                    sink = link["Input"]
                    if source["AIMName"] == "":
                        # No producing AIM: the value comes from config.message.
                        MPAI_AIFM_Port_Input_Write(sink["AIMName"], sink["PortName"],
                                                   config.message[source["PortName"]])
                        continue

                    start = time.time()
                    MPAI_AIFM_AIM_Start(source["AIMName"])
                    elapsed = time.time() - start

                    if sink["AIMName"] == "":
                        # No consuming AIM: print the produced value. The value
                        # is read from the producer's own output port, matching
                        # the label printed here and the forwarding case below.
                        print("Output of", source["AIMName"], "port", source["PortName"])
                        print(MPAI_AIFM_Port_Output_Read(source["AIMName"], source["PortName"]))
                    else:
                        # Forward the producer's output to the consumer's input port.
                        MPAI_AIFM_Port_Input_Write(
                            sink["AIMName"], sink["PortName"],
                            MPAI_AIFM_Port_Output_Read(source["AIMName"], source["PortName"]))

                MPAI_AIFM_AIM_Start("Comparator")
                print('BER : %s' % (MPAI_AIFM_Port_Output_Read("Comparator", "output_0")))
                if CompCostFlag and elapsed is not None:
                    print("time of execution: %.5f sec" % elapsed)

            elif "imperceptibility" in run_type:
                ## run imperceptibility vgg16 cifar10
                if not _has_args(message, 3):
                    continue
                # vgg16 is the only available model and also the fallback.
                if "vgg16" not in message[2].lower():
                    print(message[2], "not found - default loading vgg16")
                model = tv.models.vgg16()
                model.classifier = nn.Linear(25088, 10)

                # CIFAR10 is the only available dataset and also the fallback.
                if "cifar10" not in message[3].lower():
                    print(message[3], "not found - default loading CIFAR10")
                trainset, testset, tfm = CIFAR10_dataset()

                # Watermarked model: embed, then evaluate on the test set.
                MPAI_AIFM_Port_Input_Write("WatermarkEmbedder", "AIM", model)
                start = time.time()
                MPAI_AIFM_AIM_Start("WatermarkEmbedder")
                elapsed = time.time() - start

                MPAI_AIFM_Port_Input_Write("AIM", "model", model)
                MPAI_AIFM_Port_Input_Write("AIM", "parameters",
                                           MPAI_AIFM_Port_Output_Read("WatermarkEmbedder", "output_0"))
                MPAI_AIFM_Port_Input_Write("AIM", "testingDataset", testset)

                MPAI_AIFM_AIM_Start("AIM")
                print('AIM_watermarked_result : %s' % (MPAI_AIFM_Port_Output_Read("AIM", "output_0")))
                if CompCostFlag:
                    # Duration of the WatermarkEmbedder start only.
                    print("time of execution: %.5f sec" % elapsed)

                # Unwatermarked reference: train, then evaluate on the same test set.
                MPAI_AIFM_Port_Input_Write("AIMtrainer", "AIM", model)
                MPAI_AIFM_AIM_Start("AIMtrainer")
                MPAI_AIFM_Port_Input_Write("AIM", "model", model)
                MPAI_AIFM_Port_Input_Write("AIM", "parameters",
                                           MPAI_AIFM_Port_Output_Read("AIMtrainer", "output_0"))
                MPAI_AIFM_Port_Input_Write("AIM", "testingDataset", testset)

                MPAI_AIFM_AIM_Start("AIM")
                print('AIM_unwatermarked_result : %s' % (MPAI_AIFM_Port_Output_Read("AIM", "output_0")))

            else:
                print(run_type, "not implemented")

        elif "status" in command:
            if not _has_args(message, 1):
                continue
            print(config.dict_process)
            MPAI_AIFM_AIM_GetStatus(message[1])

        elif "pause" in command:
            if not _has_args(message, 1):
                continue
            MPAI_AIFM_AIM_Pause(message[1])

        elif "resume" in command:
            if not _has_args(message, 1):
                continue
            MPAI_AIFM_AIM_Resume(message[1])

        elif "stop" in command:
            if not _has_args(message, 1):
                continue
            name = message[1]
            # Look the process up under the same key that is then used to index
            # the table: the name as sent first, its lowercase form second.
            if name in config.dict_process:
                key = name
            elif name.lower() in config.dict_process:
                key = name.lower()
            else:
                key = None

            if key is None:
                print(name, "isn't running")
            else:
                config.dict_process[key].terminate()
                print(name, "stopped")

        elif "exit" in command:
            print("ending session...")
            break

        else:
            print("input not implemented")

# Release the listening socket once the command loop has ended.
serversocket.close()
print("session ended")

# TODO: see https://docs.python.org/3/library/multiprocessing.html
