kcelia's picture
chore: update
5062baf unverified
raw
history blame
19 kB
import os
import shutil
import subprocess
import time
from pathlib import Path
from typing import List, Tuple, Union
import gradio as gr
import numpy as np
import pandas as pd
import requests
from preprocessing import pretty_print
from symptoms_categories import SYMPTOMS_LIST
from concrete.ml.common.serialization.loaders import load
from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
INPUT_BROWSER_LIMIT = 635
SERVER_URL = "http://localhost:8000/"
# This repository's main necessary folders
REPO_DIR = Path(__file__).parent
MODEL_PATH = REPO_DIR / "client_folder"
KEYS_PATH = REPO_DIR / ".fhe_keys"
CLIENT_TMP_PATH = REPO_DIR / "client_tmp"
SERVER_TMP_PATH = REPO_DIR / "server_tmp"
# Create the necessary folders
KEYS_PATH.mkdir(exist_ok=True)
CLIENT_TMP_PATH.mkdir(exist_ok=True)
SERVER_TMP_PATH.mkdir(exist_ok=True)
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
time.sleep(3)
def clean_directory():
target_dir = ".fhe_keys"
if os.path.exists(target_dir) and os.path.isdir(target_dir):
shutil.rmtree(target_dir)
print("The .fhe_keys directory and its contents have been successfully removed.")
else:
print("The .keys directory does not exist.")
def load_data():
# Load data
df_train = pd.read_csv("./data/Training_preprocessed.csv")
df_test = pd.read_csv("./data/Testing_preprocessed.csv")
# Separate the traget from the training set
# df['prognosis] contains the name of the disease
# df['y] contains the numeric label of the disease
y_train = df_train["y"]
X_train = df_train.drop(columns=["y", "prognosis"], axis=1, errors="ignore")
y_test = df_train["y"]
X_test = df_test.drop(columns=["y", "prognosis"], axis=1, errors="ignore")
return (df_train, X_train, X_test), (df_test, y_train, y_test)
def load_model(X_train, y_train):
concrete_args = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}
classifier = ConcreteXGBoostClassifier(**concrete_args)
classifier.fit(X_train, y_train)
circuit = classifier.compile(X_train)
return classifier, circuit
def get_user_vect_symptoms_from_checkboxgroup(*user_symptoms) -> np.array:
symptoms_vector = {key: 0 for key in VALID_COLUMNS}
for symptom_box in user_symptoms:
for pretty_symptom in symptom_box:
symptom = "_".join((pretty_symptom.lower().split(" ")))
if symptom not in symptoms_vector.keys():
raise KeyError(
f"The symptom '{symptom}' you provided is not recognized as a valid "
f"symptom.\nHere is the list of valid symptoms: {symptoms_vector}"
)
symptoms_vector[symptom] = 1.0
user_symptoms_vect = np.fromiter(symptoms_vector.values(), dtype=float)[np.newaxis, :]
assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten())
return user_symptoms_vect
def get_user_vector_from_default_disease(disease):
user_symptom_vector = df_test[df_test["prognosis"] == disease].iloc[0].values
user_symptoms_vect = np.fromiter(user_symptom_vector[:-2], dtype=float)[np.newaxis, :]
assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten())
return user_symptoms_vect
def get_user_symptoms_from_default_disease(disease):
df_filtred = df_test[df_test["prognosis"] == disease]
columns_with_1 = df_filtred.columns[df_filtred.eq(1).any()].to_list()
return pretty_print(columns_with_1)
def get_user_symptoms_vector_fn(selected_default_disease, *selected_symptoms):
# Display an error box, if:
# 1. The user has already selected a default disease and added more symptoms, or
# 2. The the user has not selected a default disease or symptoms
if (
any(lst for lst in selected_symptoms if lst)
and (selected_default_disease is not None and len(selected_default_disease) > 0)
and set(pretty_print(selected_symptoms))
- set(get_user_symptoms_from_default_disease(selected_default_disease))
) or (
not any(lst for lst in selected_symptoms if lst)
and (
selected_default_disease is None
or (selected_default_disease is not None and len(selected_default_disease) < 1)
)
):
return {
error_box_1: gr.update(
visible=True, value="Enter a default disease or select your own symptoms"
),
}
# Case 1: The user has checked his own symptoms
if any(lst for lst in selected_symptoms if lst):
return {
error_box_1: gr.update(visible=False),
user_vector_textbox: get_user_vect_symptoms_from_checkboxgroup(*selected_symptoms),
}
# Case 2: The user has selected a default disease
if selected_default_disease is not None and len(selected_default_disease) > 0:
return {
user_vector_textbox: get_user_vector_from_default_disease(selected_default_disease),
error_box_1: gr.update(visible=False),
**{
box: get_user_symptoms_from_default_disease(selected_default_disease)
for box in check_boxes
},
}
def key_gen_fn(user_symptoms):
print("Cleaning directory ...")
clean_directory()
if user_symptoms is None or (user_symptoms is not None and len(user_symptoms) < 1):
print("Please submit your symptoms first")
return {
error_box_2: gr.update(visible=True, value="Please submit your symptoms first"),
}
# Key serialization
user_id = np.random.randint(0, 2**32)
client = FHEModelClient(path_dir=MODEL_PATH, key_dir=KEYS_PATH / f"{user_id}")
client.load()
# The client first need to create the private and evaluation keys.
client.generate_private_and_evaluation_keys()
# Get the serialized evaluation keys
serialized_evaluation_keys = client.get_serialized_evaluation_keys()
assert isinstance(serialized_evaluation_keys, bytes)
# np.save(f".fhe_keys/{user_id}/eval_key.npy", serialized_evaluation_keys)
evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key"
with evaluation_key_path.open("wb") as f:
f.write(serialized_evaluation_keys)
serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
return {
error_box_2: gr.update(visible=False),
eval_key_textbox: serialized_evaluation_keys_shorten_hex,
user_id_textbox: user_id,
eval_key_len_textbox: f"{len(serialized_evaluation_keys) / (10**6):.2f} MB",
}
def encrypt_fn(user_symptoms, user_id):
if not user_symptoms or not user_symptoms:
return {
error_box_3: gr.update(
visible=True, value="Please ensure that the evaluation key has been generated!"
)
}
# Retrieve the client API
client = FHEModelClient(path_dir=MODEL_PATH, key_dir=KEYS_PATH / f"{user_id}")
client.load()
user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)
quant_user_symptoms = client.model.quantize_input(user_symptoms)
encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
assert isinstance(encrypted_quantized_user_symptoms, bytes)
encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms"
with encrypted_input_path.open("wb") as f:
f.write(encrypted_quantized_user_symptoms)
# print(client.model.predict(vect_x, fhe="simulate"), client.model.predict(vect_x, fhe="execute"))
# pred_s = client.model.fhe_circuit.simulate(quant_vect)
# pred_fhe = client.model.fhe_circuit.encrypt_run_decrypt(quant_vect) #
# non alpha -> \X1124, base64 ou en exa
# Compute size
# np.save(f".fhe_keys/{user_id}/encrypted_quant_vect.npy", encrypted_quantized_user_symptoms)
encrypted_quantized_user_symptoms_shorten_hex = encrypted_quantized_user_symptoms.hex()[
:INPUT_BROWSER_LIMIT
]
return {
error_box_3: gr.update(visible=False),
vect_textbox: user_symptoms,
quant_vect_textbox: quant_user_symptoms,
encrypted_vect_textbox: encrypted_quantized_user_symptoms_shorten_hex,
}
def is_nan(input):
return input is None or (input is not None and len(input) < 1)
def send_input_fn(user_id, user_symptoms):
"""Send the encrypted input image as well as the evaluation key to the server.
Args:
user_id (int): The current user's ID.
filter_name (str): The current filter to consider.
"""
# Get the evaluation key path
if is_nan(user_id) or is_nan(user_symptoms):
return {
error_box_4: gr.update(
visible=True,
value="Please ensure that the evaluation key has been generated "
"and the symptoms have been submitted before sending the data to the server",
)
}
evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key"
encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms"
if not evaluation_key_path.is_file():
print(f"Please generate the private key, first.{evaluation_key_path.is_file()=}")
return {
error_box_4: gr.update(visible=True, value="Please generate the private key first.")
}
if not encrypted_input_path.is_file():
print(f"Please submit your symptoms, first.{encrypted_input_path.is_file()=}")
return {
error_box_4: gr.update(
visible=True,
value="Please generate the private key and then encrypt an image first.",
)
}
# Define the data and files to post
data = {
"user_id": user_id,
"filter": user_symptoms,
}
files = [
("files", open(encrypted_input_path, "rb")),
("files", open(evaluation_key_path, "rb")),
]
# Send the encrypted input image and evaluation key to the server
url = SERVER_URL + "send_input"
with requests.post(
url=url,
data=data,
files=files,
) as response:
print(f"response.ok: {response.ok}")
return {error_box_4: gr.update(visible=False), server_response_box: gr.update(visible=True)}
# def decrypt_prediction(encrypted_quantized_vect, user_id):
# fhe_api = FHEModelClient(path_dir=REPO_DIR, key_dir=f".fhe_keys/{user_id}")
# fhe_api.load()
# fhe_api.generate_private_and_evaluation_keys(force=False)
# predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_quantized_vect)
# return predictions
def clear_all_btn():
clean_directory()
return {
box_default: None,
vect_textbox: None,
user_id_textbox: None,
eval_key_textbox: None,
quant_vect_textbox: None,
user_vector_textbox: None,
eval_key_len_textbox: None,
encrypted_vect_textbox: None,
error_box_1: gr.update(visible=False),
error_box_2: gr.update(visible=False),
error_box_3: gr.update(visible=False),
error_box_4: gr.update(visible=False),
server_response_box: gr.update(visible=False),
**{box: None for box in check_boxes},
}
if __name__ == "__main__":
print("Starting demo ...")
(df_train, X_train, X_test), (df_test, y_train, y_test) = load_data()
VALID_COLUMNS = X_train.columns.to_list()
# Load the model
with open("ConcreteXGBoostClassifier.pkl", "r", encoding="utf-8") as file:
concrete_classifier = load(file)
with gr.Blocks() as demo:
# Link + images
gr.Markdown(
"""
<p align="center">
<img width=200 src="https://user-images.githubusercontent.com/5758427/197816413-d9cddad3-ba38-4793-847d-120975e1da11.png">
</p>
<h2 align="center">Health Prediction On Encrypted Data Using Homomorphic Encryption.</h2>
<p align="center">
<a href="https://github.com/zama-ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197972109-faaaff3e-10e2-4ab6-80f5-7531f7cfb08f.png">Concrete-ML</a>
<a href="https://docs.zama.ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197976802-fddd34c5-f59a-48d0-9bff-7ad1b00cb1fb.png">Documentation</a>
<a href="https://zama.ai/community"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197977153-8c9c01a7-451a-4993-8e10-5a6ed5343d02.png">Community</a>
<a href="https://twitter.com/zama_fhe"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197975044-bab9d199-e120-433b-b3be-abd73b211a54.png">@zama_fhe</a>
</p>
<p align="center">
<img src="https://raw.githubusercontent.com/kcelia/Img/main/demo-img2.png" width="60%" height="60%">
</p>
"""
)
# Gentle introduction
gr.Markdown("## Introduction")
gr.Markdown("""Blablabla""")
# User symptoms
gr.Markdown("# Step 1: Provide your symptoms")
gr.Markdown("Client side")
# Default disease, picked from the dataframe
with gr.Row():
default_diseases = list(set(df_test["prognosis"]))
box_default = gr.Dropdown(default_diseases, label="Disease")
# Box symptoms
check_boxes = []
for i, category in enumerate(SYMPTOMS_LIST):
check_box = gr.CheckboxGroup(
pretty_print(category.values()),
label=pretty_print(category.keys()),
info=f"Symptoms related to `{pretty_print(category.values())}`",
max_batch_size=45,
)
check_boxes.append(check_box)
error_box_1 = gr.Textbox(label="Error", visible=False)
# User symptom vector
with gr.Row():
user_vector_textbox = gr.Textbox(
label="User symptoms (vector)",
interactive=False,
max_lines=100,
)
with gr.Row():
# Submit botton
with gr.Column():
submit_button = gr.Button("Submit")
# Clear botton
with gr.Column():
clear_button = gr.Button("Clear")
# Click submit botton
submit_button.click(
fn=get_user_symptoms_vector_fn,
inputs=[box_default, *check_boxes],
outputs=[user_vector_textbox, error_box_1, *check_boxes],
)
gr.Markdown("# Step 2: Generate the keys")
gr.Markdown("Client side")
gen_key_btn = gr.Button("Generate the keys and send public part to server")
error_box_2 = gr.Textbox(label="Error", visible=False)
with gr.Row():
# User ID
with gr.Column(scale=1, min_width=600):
user_id_textbox = gr.Textbox(
label="User ID:",
max_lines=4,
interactive=False,
)
# Evaluation key size
with gr.Column(scale=1, min_width=600):
eval_key_len_textbox = gr.Textbox(
label="Evaluation key size:", max_lines=4, interactive=False
)
with gr.Row():
# Evaluation key (truncated)
with gr.Column(scale=2, min_width=600):
eval_key_textbox = gr.Textbox(
label="Evaluation key (truncated):",
max_lines=4,
interactive=False,
)
gen_key_btn.click(
key_gen_fn,
inputs=user_vector_textbox,
outputs=[eval_key_textbox, user_id_textbox, eval_key_len_textbox, error_box_2],
)
gr.Markdown("# Step 3: Encode the message with the private key")
gr.Markdown("Client side")
encrypt_btn = gr.Button("Encode the message with the private key")
error_box_3 = gr.Textbox(label="Error", visible=False)
with gr.Row():
with gr.Column(scale=1, min_width=600):
vect_textbox = gr.Textbox(
label="Vector:",
max_lines=4,
interactive=False,
)
with gr.Column(scale=1, min_width=600):
quant_vect_textbox = gr.Textbox(
label="Quant vector:", max_lines=4, interactive=False
)
with gr.Column(scale=1, min_width=600):
encrypted_vect_textbox = gr.Textbox(
label="Encrypted vector:", max_lines=4, interactive=False
)
encrypt_btn.click(
encrypt_fn,
inputs=[user_vector_textbox, user_id_textbox],
outputs=[vect_textbox, quant_vect_textbox, encrypted_vect_textbox, error_box_3],
)
gr.Markdown("# Step 4: Send the encrypted data to the server.")
gr.Markdown("Client side")
send_input_btn = gr.Button("Send the encrypted data to the server.")
error_box_4 = gr.Textbox(label="Error", visible=False)
server_response_box = gr.Textbox(value="Data sent", visible=False, show_label=False)
send_input_btn.click(
send_input_fn,
inputs=[user_id_textbox, user_vector_textbox],
outputs=[error_box_4, server_response_box],
)
gr.Markdown("# Step 5: Run the FHE evaluation")
gr.Markdown("Server side")
run_fhe = gr.Button("Run the FHE evaluation")
gr.Markdown("# Step 6: Decrypt the sentiment")
gr.Markdown("Server side")
decrypt_target_botton = gr.Button("Decrypt the sentiment")
decrypt_target_textbox = gr.Textbox(
label="Encrypted vector:", max_lines=4, interactive=False
)
# decrypt_target_botton.click(
# decrypt_prediction,
# inputs=[encrypted_vect_textbox, user_id_textbox],
# outputs=[decrypt_target_textbox],
# )
clear_button.click(
clear_all_btn,
outputs=[
box_default,
error_box_1,
error_box_2,
error_box_3,
error_box_4,
vect_textbox,
user_id_textbox,
eval_key_textbox,
quant_vect_textbox,
user_vector_textbox,
server_response_box,
eval_key_len_textbox,
encrypted_vect_textbox,
*check_boxes,
],
)
demo.launch()