Upload 3 files
- PE_main.py +19 -2
- fhe_utils.py +223 -0
- train_FHE.py +171 -0
PE_main.py
CHANGED
@@ -76,9 +76,9 @@ def get_version_info(pe):
     return res
 
 #extract the info for a given file using pefile
-def extract_infos(file):
+def extract_infos(fpath):
     res = {}
-    pe = pefile.PE(file)
+    pe = pefile.PE(fpath)
     res['Machine'] = pe.FILE_HEADER.Machine
     res['SizeOfOptionalHeader'] = pe.FILE_HEADER.SizeOfOptionalHeader
     res['Characteristics'] = pe.FILE_HEADER.Characteristics
@@ -182,3 +182,20 @@ def extract_infos(file):
     res['VersionInformationSize'] = 0
     return res
 
+
+if __name__ == '__main__':
+
+    #Loading the classifier.pkl and features.pkl
+    clf = joblib.load('Classifier/classifier.pkl')
+    features = pickle.loads(open(os.path.join('Classifier/features.pkl'), 'rb').read())
+
+    #extracting features from the PE file given as the command-line argument
+    data = extract_infos(sys.argv[1])
+
+    #matching it with the features saved in features.pkl
+    pe_features = list(map(lambda x: data[x], features))
+    print("Features used for classification: ", pe_features)
+
+    #predicting if the PE is malicious or not based on the extracted features
+    res = clf.predict([pe_features])[0]
+    print('The file %s is %s' % (os.path.basename(sys.argv[1]), ['malicious', 'legitimate'][res]))
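With the new __main__ block, PE_main.py classifies the PE file passed as its first command-line argument (python PE_main.py path/to/sample.exe), assuming the Classifier/classifier.pkl and Classifier/features.pkl artifacts exist. Below is a minimal sketch of the same logic used programmatically; the predict_file wrapper and the sample.exe path are illustrative, not part of the commit:

import os
import pickle
import joblib
from PE_main import extract_infos

def predict_file(fpath, classifier_dir="Classifier"):
    """Return 'malicious' or 'legitimate' for a PE file, mirroring the new __main__ block."""
    clf = joblib.load(os.path.join(classifier_dir, "classifier.pkl"))
    with open(os.path.join(classifier_dir, "features.pkl"), "rb") as f:
        features = pickle.load(f)
    data = extract_infos(fpath)                       # raw PE-header features as a dict
    pe_features = [data[name] for name in features]   # order them as the classifier expects
    return ["malicious", "legitimate"][clf.predict([pe_features])[0]]

print(predict_file("sample.exe"))  # "sample.exe" is a placeholder path

Preserving the feature order stored in features.pkl is what keeps the vector aligned with the columns the classifier was trained on.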
fhe_utils.py
ADDED
@@ -0,0 +1,223 @@
import sys
import os
import pdb
import random
import json
import shutil
import time
import gzip
import pickle

import numpy as np
import pandas as pd
import requests
import matplotlib.pyplot as plt
import xgboost as xgb
from io import BytesIO
from shutil import copyfile
from tempfile import TemporaryDirectory
from tqdm import tqdm

from scipy.stats import pearsonr
from sklearn.model_selection import KFold, GridSearchCV  # GridSearchCV is used by train_xgb_regressor below
from sklearn.svm import LinearSVR as LinearSVR

from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.sklearn import DecisionTreeClassifier as DecisionTreeClassifierZAMA
from concrete.ml.sklearn import LinearSVC as LinearSVCZAMA

random.seed(42)


def convert_numpy(obj):
    """Convert numpy scalars/arrays to plain Python types, e.g. as a json.dumps(default=...) hook."""
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj

class OnDiskNetwork:
    """Simulate a network on disk."""

    def __init__(self):
        # Create 3 temporary folders for server, client and dev with tempfile
        self.server_dir = TemporaryDirectory()
        self.client_dir = TemporaryDirectory()
        self.dev_dir = TemporaryDirectory()

    def client_send_evaluation_key_to_server(self, serialized_evaluation_keys):
        """Send the serialized evaluation keys to the server."""
        with open(self.server_dir.name + "/serialized_evaluation_keys.ekl", "wb") as f:
            f.write(serialized_evaluation_keys)

    def client_send_input_to_server_for_prediction(self, encrypted_input):
        """Send the encrypted input to the server, run the prediction in FHE, and return the execution time."""
        with open(self.server_dir.name + "/serialized_evaluation_keys.ekl", "rb") as f:
            serialized_evaluation_keys = f.read()
        time_begin = time.time()
        encrypted_prediction = FHEModelServer(self.server_dir.name).run(
            encrypted_input, serialized_evaluation_keys
        )
        time_end = time.time()
        with open(self.server_dir.name + "/encrypted_prediction.enc", "wb") as f:
            f.write(encrypted_prediction)
        return time_end - time_begin

    def dev_send_model_to_server(self):
        """Send the model (server.zip) to the server."""
        copyfile(
            self.dev_dir.name + "/server.zip", self.server_dir.name + "/server.zip"
        )

    def server_send_encrypted_prediction_to_client(self):
        """Send the encrypted prediction to the client."""
        with open(self.server_dir.name + "/encrypted_prediction.enc", "rb") as f:
            encrypted_prediction = f.read()
        return encrypted_prediction

    def dev_send_clientspecs_and_modelspecs_to_client(self):
        """Send the client specs (client.zip) to the client."""
        copyfile(
            self.dev_dir.name + "/client.zip", self.client_dir.name + "/client.zip"
        )

    def cleanup(self):
        """Clean up the temporary folders."""
        self.server_dir.cleanup()
        self.client_dir.cleanup()
        self.dev_dir.cleanup()


# The three helpers below (generate_fingerprint, train_xgb_regressor, evaluate_model)
# are not used by train_FHE.py.
def generate_fingerprint(smiles, radius=2, bits=512):
    """Compute a Morgan fingerprint bit vector for a SMILES string (np.nan if parsing fails)."""
    # RDKit is only needed by this helper, so import it lazily here.
    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return np.nan

    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=bits)

    return np.array(fp)


def train_xgb_regressor(X_train, y_train, param_grid=None, verbose=10):
    """Grid-search an XGBoost regressor with 5-fold cross-validation."""
    if param_grid is None:
        param_grid = {
            "max_depth": [3, 6],
            "learning_rate": [0.01, 0.1, 0.2],
            "n_estimators": [20],
            "colsample_bytree": [0.3, 0.7],
        }

    xgb_regressor = xgb.XGBRegressor(objective="reg:squarederror")

    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    grid_search = GridSearchCV(
        estimator=xgb_regressor,
        param_grid=param_grid,
        cv=kfold,
        verbose=verbose,
        n_jobs=-1,
    )

    grid_search.fit(X_train, y_train)
    return (
        grid_search.best_params_,
        grid_search.best_score_,
        grid_search.best_estimator_,
    )


def evaluate_model(model, X_test, y_test):
    """Return the Pearson correlation between predictions and targets."""
    y_pred = model.predict(X_test)
    pearsonr_score = pearsonr(y_test, y_pred).statistic
    return pearsonr_score

def setup_network(model_dev):
    """Create the on-disk network and save the compiled model for deployment."""
    network = OnDiskNetwork()
    fhemodel_dev = FHEModelDev(network.dev_dir.name, model_dev)
    fhemodel_dev.save(via_mlir=True)
    return network, fhemodel_dev


def copy_directory(source, destination="deployment"):
    try:
        # Check if the source directory exists
        if not os.path.exists(source):
            return False, "Source directory does not exist."

        # Check if the destination directory exists
        if not os.path.exists(destination):
            os.makedirs(destination)

        # Copy each item in the source directory
        for item in os.listdir(source):
            s = os.path.join(source, item)
            d = os.path.join(destination, item)
            if os.path.isdir(s):
                shutil.copytree(
                    s, d, dirs_exist_ok=True
                )  # dirs_exist_ok is available from Python 3.8
            else:
                shutil.copy2(s, d)

        return True, None

    except Exception as e:
        return False, str(e)


def client_server_interaction(network, fhemodel_client, X_client):
    """Encrypt each row on the client, run it on the server in FHE, and decrypt the result."""
    decrypted_predictions = []
    execution_time = []
    for i in tqdm(range(X_client.shape[0])):
        clear_input = X_client[[i], :]
        encrypted_input = fhemodel_client.quantize_encrypt_serialize(clear_input)
        execution_time.append(
            network.client_send_input_to_server_for_prediction(encrypted_input)
        )
        encrypted_prediction = network.server_send_encrypted_prediction_to_client()
        decrypted_prediction = fhemodel_client.deserialize_decrypt_dequantize(
            encrypted_prediction
        )[0]
        decrypted_predictions.append(decrypted_prediction)
        # pdb.set_trace()
    return decrypted_predictions, execution_time


def train_zama(X_train, y_train):
    """Train and compile a Concrete ML model (LinearSVC by default)."""
    model_dev = LinearSVCZAMA()
    # Alternative: DecisionTreeClassifierZAMA()

    print("Training Zama model...")
    model_dev.fit(X_train, y_train)
    print("compiling model...")
    model_dev.compile(X_train)
    print("done")

    return model_dev


def time_prediction(model, X_sample):
    """Time a prediction executed in FHE."""
    time_begin = time.time()
    y_pred_fhe = model.predict(X_sample, fhe="execute")
    time_end = time.time()
    return time_end - time_begin


def setup_client(network, key_dir):
    """Create the FHE client, generate keys, and return the serialized evaluation keys."""
    fhemodel_client = FHEModelClient(network.client_dir.name, key_dir=key_dir)
    fhemodel_client.generate_private_and_evaluation_keys()
    serialized_evaluation_keys = fhemodel_client.get_serialized_evaluation_keys()
    return fhemodel_client, serialized_evaluation_keys
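time_prediction above measures a real homomorphic run (fhe="execute"), which is typically far slower than clear inference. When iterating on the model it can help to time Concrete ML's simulation mode as well. A minimal, self-contained sketch with toy data (LinearSVC as in train_zama; the fhe="simulate" option is an assumption about the installed concrete-ml version, not something this commit uses):

import time
import numpy as np
from concrete.ml.sklearn import LinearSVC
from fhe_utils import time_prediction

# Toy data standing in for the PE-header features
X = np.random.rand(200, 8)
y = np.random.randint(0, 2, 200)

model = LinearSVC()
model.fit(X, y)
model.compile(X)                      # required before any fhe="simulate"/"execute" call

t0 = time.time()
model.predict(X[:5], fhe="simulate")  # fast simulation, no encryption
print("simulate: %.3fs" % (time.time() - t0))
print("execute : %.3fs" % time_prediction(model, X[:5]))  # real FHE, helper defined above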
train_FHE.py
ADDED
@@ -0,0 +1,171 @@
import os
import pandas as pd
import numpy
import pickle
import pefile
import sklearn.ensemble as ek
from sklearn.feature_selection import SelectFromModel
import joblib
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix
from sklearn import svm
import sklearn.metrics as metrics
from sklearn.model_selection import train_test_split
import pdb

from tqdm import tqdm

dataset = pd.read_csv("data.csv", sep="|")


# Features: drop the object-type columns because the classification models only accept numeric (float/int) values
X = dataset.drop(["Name", "md5", "legitimate"], axis=1).values

# Hand-picked PE-header feature columns, in the order expected downstream
ugly = [
    "Machine",
    "SizeOfOptionalHeader",
    "Characteristics",
    "MajorLinkerVersion",
    "MinorLinkerVersion",
    "SizeOfCode",
    "SizeOfInitializedData",
    "SizeOfUninitializedData",
    "AddressOfEntryPoint",
    "BaseOfCode",
    "BaseOfData",
    "ImageBase",
    "SectionAlignment",
    "FileAlignment",
    "MajorOperatingSystemVersion",
    "MinorOperatingSystemVersion",
    "MajorImageVersion",
    "MinorImageVersion",
    "MajorSubsystemVersion",
    "MinorSubsystemVersion",
    "SizeOfImage",
    "SizeOfHeaders",
    "CheckSum",
    "Subsystem",
    "DllCharacteristics",
    "SizeOfStackReserve",
    "SizeOfStackCommit",
    "SizeOfHeapReserve",
    "SizeOfHeapCommit",
    "LoaderFlags",
    "NumberOfRvaAndSizes",
    "SectionsNb",
    "SectionsMeanEntropy",
    "SectionsMinEntropy",
    "SectionsMaxEntropy",
    "SectionsMeanRawsize",
    "SectionsMinRawsize",
    #"SectionsMaxRawsize",
    "SectionsMeanVirtualsize",
    "SectionsMinVirtualsize",
    "SectionMaxVirtualsize",
    "ImportsNbDLL",
    "ImportsNb",
    "ImportsNbOrdinal",
    "ExportNb",
    "ResourcesNb",
    "ResourcesMeanEntropy",
    "ResourcesMinEntropy",
    "ResourcesMaxEntropy",
    "ResourcesMeanSize",
    "ResourcesMinSize",
    "ResourcesMaxSize",
    "LoadConfigurationSize",
    "VersionInformationSize",
]

X = dataset[ugly].values

# Target variable: 1 = legitimate, 0 = malicious
y = dataset["legitimate"].values


# Rank features with an ExtraTrees model (fitted on the first 1000 rows for speed)
extratrees = ek.ExtraTreesClassifier().fit(X[:1000], y[:1000])
model = SelectFromModel(extratrees, prefit=True)
X_new = model.transform(X)
nbfeatures = X_new.shape[1]


# Splitting the data (~70% training and ~30% testing)
X_train, X_test, y_train, y_test = train_test_split(
    X_new, y, test_size=0.29, stratify=y
)


index = numpy.argsort(extratrees.feature_importances_)[::-1][:nbfeatures]

# Print the most important features (importances are indexed against the `ugly` column list)
for f in range(nbfeatures):
    print(
        "%d. feature %s (%f)"
        % (
            f + 1,
            ugly[index[f]],
            extratrees.feature_importances_[index[f]],
        )
    )

# Save the names of the selected columns, in the same order as the columns of X_new,
# so that PE_main.py can rebuild the feature vector consistently
features = [name for name, keep in zip(ugly, model.get_support()) if keep]


model = {
    "DecisionTree": DecisionTreeClassifier(max_depth=10),
    "RandomForest": ek.RandomForestClassifier(n_estimators=50),
}


results = {}
for algo in model:
    clf = model[algo]
    clf.fit(X_train, y_train)
    score = clf.score(X_test, y_test)
    print("%s : %s " % (algo, score))
    results[algo] = score


winner = max(results, key=results.get)  # Select the classifier with the best test score
print("Using", winner, "for classification, with", len(features), "features.")


joblib.dump(model[winner], "classifier.pkl")
with open("features.pkl", "wb") as f:
    f.write(pickle.dumps(features))


from fhe_utils import (
    client_server_interaction,
    train_zama,
    setup_network,
    copy_directory,
    setup_client,
)

model_dev_fhe = train_zama(X_train, y_train)
# pdb.set_trace()
network, _ = setup_network(model_dev_fhe)
copied, error_message = copy_directory(network.dev_dir.name, destination="fhe_model")

if not copied:
    print(f"Error copying directory: {error_message}")

network.dev_send_model_to_server()
network.dev_send_clientspecs_and_modelspecs_to_client()

fhemodel_client, serialized_evaluation_keys = setup_client(
    network, network.client_dir.name
)
print(f"Evaluation keys size: {len(serialized_evaluation_keys)} B")

network.client_send_evaluation_key_to_server(serialized_evaluation_keys)


decrypted_predictions, execution_time = client_server_interaction(
    network, fhemodel_client, X_test[:100]
)