controller.py 8.35 KB
Newer Older
Carldst's avatar
Carldst committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# controller.py https://www.bogotobogo.com/python/python_network_programming_server_client.php
import socket
import time
from Attacks import *
from UCHIDA import Uchi_tools
from ADI import Adi_tools
from multiprocessing import Process
from APIs import *
import psutil
import os
import ast
import tkinter as tk
from tkinter import filedialog
import wget
import config


# Create the controller's TCP/IPv4 listening socket.
serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# The controller closes each client connection itself, so those sockets may
# linger in TIME_WAIT after it exits; without SO_REUSEADDR a quick restart can
# fail with "Address already in use". Only enabled on POSIX: on Windows the
# option lets another process bind the same port.
if os.name == "posix":
    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Listen on this machine's hostname at a fixed port.
host = socket.gethostname()
port = 12345
serversocket.bind((host, port))

# queue up to 5 pending connections
serversocket.listen(5)
print("Controller Initialized")
def _ask_open_filename(**dialog_options):
    """Show a file-open dialog and return the path the user picked.

    A hidden Tk root window hosts the dialog and is destroyed afterwards, so
    repeated commands do not pile up Tk interpreters. ``dialog_options`` are
    passed straight to ``filedialog.askopenfilename``; a falsy result means
    the dialog was cancelled.
    """
    root = tk.Tk()
    root.withdraw()
    try:
        return filedialog.askopenfilename(**dialog_options)
    finally:
        root.destroy()


def _print_help():
    """Print the commands this controller understands."""
    print(" this program is the implementation of NNW in the AIF")
    print(" you can run AIM/AIW by sending 'run XX' ")
    print(" you can pause AIM/AIW by sending 'pause XX' ")
    print(" you can resume AIM/AIW by sending 'resume XX' ")
    print(" you can stop AIM/AIW by sending 'stop XX' ")
    print(" you can obtain the status of AIM/AIW by sending 'status XX' ")
    print(" you can end the program by sending 'exit'")
    print(" ----------------------------------------")


def _run_robustness(message, comp_cost_flag):
    """Alter a watermarked model and report how well its watermark survives.

    Command form: ``run robustness <model_id> <param>``, e.g.
    ``run robustness 1 {"P":.1}`` (``<param>`` must contain no spaces, since
    the command is split on whitespace). ``<param>`` is parsed as a Python
    literal and written to the ModificationModule's "param" port. Two file
    dialogs then ask for the watermarked parameter file and the saved
    watermarking dict.

    Prints the BER from the Comparator and the decoded watermark; when
    ``comp_cost_flag`` is true, also prints the WatermarkDecoder run time.
    Raises IndexError/ValueError/SyntaxError on a malformed command.
    """
    param = ast.literal_eval(message[3])
    int(message[2])  # model id: checked to be an integer, not otherwise used yet

    checkpoint_path = _ask_open_filename(title='Select the parameter file')
    if not checkpoint_path:
        print("no file selected")
        return
    # NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
    watermarked_parameters = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    MPAI_AIFM_Port_Input_Write("ModificationModule", "param", param)
    MPAI_AIFM_Port_Input_Write("ModificationModule", "checkpoint", watermarked_parameters)
    MPAI_AIFM_AIM_Start("ModificationModule")
    print("watermarked AIM altered")

    MPAI_AIFM_Port_Input_Write("WatermarkDecoder", "altered_watermarked_parameters",
                               MPAI_AIFM_Port_Output_Read("ModificationModule", "output_0"))
    tools = Uchi_tools()
    dict_path = _ask_open_filename(title='Select the watermarking_dict')
    if not dict_path:
        print("no file selected")
        return
    # The .npy holds a pickled object (hence allow_pickle) unwrapped by .item().
    watermarking_dict = np.load(dict_path, allow_pickle=True).item()

    MPAI_AIFM_Port_Input_Write("WatermarkDecoder", "tools", tools)
    MPAI_AIFM_Port_Input_Write("WatermarkDecoder", "watermarking_dict", watermarking_dict)
    start = time.perf_counter()
    MPAI_AIFM_AIM_Start("WatermarkDecoder")
    elapsed = time.perf_counter() - start

    # Compare the embedded watermark with the one decoded from the altered model.
    MPAI_AIFM_Port_Input_Write("Comparator", "payload1", watermarking_dict["watermark"])
    MPAI_AIFM_Port_Input_Write("Comparator", "payload2",
                               MPAI_AIFM_Port_Output_Read("WatermarkDecoder", "output_0"))
    MPAI_AIFM_AIM_Start("Comparator")

    print('BER : %s' % (MPAI_AIFM_Port_Output_Read("Comparator", "output_0")))
    print('Retrieved watermark : %s' % (MPAI_AIFM_Port_Output_Read("WatermarkDecoder", "output_0")))
    if comp_cost_flag:
        print("time of execution: %.5f sec" % elapsed)


def _run_imperceptibility(message, comp_cost_flag):
    """Print a model's AIM result with and without the watermark.

    Command form: ``run imperceptibility vgg16 cifar10``. Only VGG16 and
    CIFAR10 are wired up; any other model/dataset name prints a notice and
    falls back to them. When ``comp_cost_flag`` is true, the
    WatermarkEmbedder run time is printed too.
    """
    if "vgg16" not in message[2].lower():
        print(message[2], "not found - default loading vgg16")
    model = tv.models.vgg16()
    # single 10-output linear classifier on the 25088 flattened VGG features
    model.classifier = nn.Linear(25088, 10)

    if "cifar10" not in message[3].lower():
        print(message[3], "not found - default loading CIFAR10")
    _trainset, testset, _tfm = CIFAR10_dataset()

    # Watermarked run: embed, then evaluate the embedder's output parameters.
    MPAI_AIFM_Port_Input_Write("WatermarkEmbedder", "AIM", model)
    start = time.perf_counter()
    MPAI_AIFM_AIM_Start("WatermarkEmbedder")
    elapsed = time.perf_counter() - start

    MPAI_AIFM_Port_Input_Write("AIM", "model", model)
    MPAI_AIFM_Port_Input_Write("AIM", "parameters", MPAI_AIFM_Port_Output_Read("WatermarkEmbedder", "output_0"))
    MPAI_AIFM_Port_Input_Write("AIM", "testingDataset", testset)
    MPAI_AIFM_AIM_Start("AIM")
    print('AIM_watermarked_result : %s' % (MPAI_AIFM_Port_Output_Read("AIM", "output_0")))
    if comp_cost_flag:
        print("time of execution: %.5f sec" % elapsed)

    # Unwatermarked reference: run AIMtrainer and evaluate its output the same way.
    # NOTE(review): the same `model` object goes to every AIM; confirm the
    # embedder does not modify it in place, or the reference is affected.
    MPAI_AIFM_Port_Input_Write("AIMtrainer", "AIM", model)
    MPAI_AIFM_AIM_Start("AIMtrainer")
    MPAI_AIFM_Port_Input_Write("AIM", "model", model)
    MPAI_AIFM_Port_Input_Write("AIM", "parameters", MPAI_AIFM_Port_Output_Read("AIMtrainer", "output_0"))
    MPAI_AIFM_Port_Input_Write("AIM", "testingDataset", testset)
    MPAI_AIFM_AIM_Start("AIM")
    print('AIM_unwatermarked_result : %s' % (MPAI_AIFM_Port_Output_Read("AIM", "output_0")))


# When True, "run" commands also print the duration of their timed AIM step.
CompCostFlag = False
while True:
    # Each connection carries one command: read it, dispatch it, close it.
    clientsocket, addr = serversocket.accept()
    try:
        data = clientsocket.recv(1024)
        if not data:
            break
        message = data.decode(errors="replace").split()
        # Commands match by substring of the lower-cased first word;
        # an all-whitespace message yields "" and falls through to the else.
        command = message[0].lower() if message else ""

        if "help" in command:
            _print_help()

        elif "wget" in command:
            downloaded = wget.download(message[1])
            print(type(downloaded))

        elif "getparse" in command:
            archive_path = _ask_open_filename(
                title='Select the parameter file',
                filetypes=(("Zip archives", "*.zip"), ("all files", "*.*")),
            )
            if not archive_path:
                print("no file selected")
            else:
                json_dict = MPAI_AIFS_GetAndParseArchive(archive_path)
                # NOTE(review): pause kept as-is; its purpose is not visible here.
                time.sleep(.5)
                # AIW/AIMs_files is expected to be supplied by the archive.
                import AIW.AIMs_files as AIMs_file
                config.AIM_dict = json_dict['SubAIMs']
                # Instantiate each sub-AIM class named in the archive's description.
                for sub_aim in config.AIM_dict:
                    config.AIMs[sub_aim["Name"]] = getattr(AIMs_file, sub_aim["Name"])()

        elif "write" in command:
            # message[1] AIM name, message[2] port name, message[3] value to write
            MPAI_AIFM_Port_Input_Write(message[1], message[2], message[3])
        elif "read" in command:
            # message[1] AIM name, message[2] port name
            result = MPAI_AIFM_Port_Output_Read(message[1], message[2])
            print(message[2], "of", message[1], ":", result, type(result))
        elif "reset" in command:
            # message[1] AIM name, message[2] port name
            MPAI_AIFM_Port_Reset(message[1], message[2])

        elif "computationalcost" in command:  # lower-case: `command` is lower-cased
            try:
                CompCostFlag = bool(ast.literal_eval(message[1]))
            except (ValueError, SyntaxError):
                print("expected True or False, got", message[1])
            else:
                print("computational cost reporting:", CompCostFlag)

        elif "run" in command:
            run_kind = message[1].lower()
            if "robustness" in run_kind:
                _run_robustness(message, CompCostFlag)
            elif "imperceptibility" in run_kind:
                _run_imperceptibility(message, CompCostFlag)
            else:
                print(run_kind, "not implemented")

        elif "status" in command:
            print(config.dict_process)
            MPAI_AIFM_AIM_GetStatus(message[1])
        elif "pause" in command:
            MPAI_AIFM_AIM_Pause(message[1])
        elif "resume" in command:
            MPAI_AIFM_AIM_Resume(message[1])

        elif "stop" in command:
            # Look the process up with the same key used to terminate it.
            name = message[1]
            if name in config.dict_process:
                config.dict_process[name].terminate()
                print(name, "stopped")
            else:
                print(name, "isn't running")

        elif "exit" in command:
            print("ending session...")
            break
        else:
            print("input not implemented")
    finally:
        # Runs on every path: normal dispatch, "exit", an empty read, or an error.
        clientsocket.close()

serversocket.close()
print("session ended")

### TO DO https://docs.python.org/3/library/multiprocessing.html