# Client front end for the FHE clinical-trial matcher.
# File-view metadata from the hosting page: author TomSmail, commit ce70f59
# ("feat: add user frontend"), 8.98 kB.
import os
from pathlib import Path

import gradio as gr
import numpy as np
import requests
from concrete.ml.deployment import FHEModelClient
from sklearn.preprocessing import OneHotEncoder
# Address of the FHE server that receives the encrypted patient vector.
# Override with the SERVER_URL environment variable; the default is unchanged.
# NOTE(review): 7860 is also Gradio's default port, which this client's own
# demo.launch() tries first — confirm the server really listens here.
SERVER_URL = os.environ.get("SERVER_URL", "http://127.0.0.1:7860/")
CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR / "deployment_files"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
# Name of the per-user key sub-directory under KEYS_DIR.
USER_ID = "user_id"
# Link shown to the user on a positive match. Keep it free of ad-click tracking
# parameters (_gl, gclid, gbraid), which identify a single browser session.
EXAMPLE_CLINICAL_TRIAL_LINK = "https://www.trials4us.co.uk/ongoing-clinical-trials/recruiting-healthy-adults-c23026"
# Answer choices offered by each form field. For the single-choice fields the
# list order also fixes the integer code produced by encode_categorical_data
# (position in the list + 1; 0 for an answer that is not listed).
additional_categories = {
    "Gender": ["Male", "Female", "Other"],
    "Ethnicity": ["White", "Black or African American", "Asian", "American Indian or Alaska Native", "Native Hawaiian or Other Pacific Islander", "Other"],
    "Geographic_Location": ["North America", "South America", "Europe", "Asia", "Africa", "Australia", "Antarctica"],
    "Smoking_Status": ["Never", "Former", "Current"],
    "Diagnoses_ICD10": ["E11.9", "I10", "J45.909", "M54.5", "F32.9", "K21.9"],
    "Medications": ["Metformin", "Lisinopril", "Atorvastatin", "Amlodipine", "Omeprazole", "Simvastatin", "Levothyroxine", "None"],
    "Allergies": ["Penicillin", "Peanuts", "Shellfish", "Latex", "Bee stings", "None"],
    "Previous_Treatments": ["Chemotherapy", "Radiation Therapy", "Surgery", "Physical Therapy", "Immunotherapy", "None"],
    "Alcohol_Consumption": ["None", "Occasionally", "Regularly", "Heavy"],
    "Exercise_Habits": ["Sedentary", "Light", "Moderate", "Active", "Very Active"],
    "Diet": ["Omnivore", "Vegetarian", "Vegan", "Pescatarian", "Keto", "Mediterranean"],
    "Functional_Status": ["Independent", "Assisted", "Dependent"],
    "Previous_Trial_Participation": ["Yes", "No"],
}
# Form input components, declared through three small factories so that each
# field fits on one line. The `field` argument is a key of additional_categories.
def _slider(label, minimum, maximum):
    """Return an integer-step Gradio slider spanning [minimum, maximum]."""
    return gr.Slider(minimum=minimum, maximum=maximum, label=label, step=1)


def _radio(field, label):
    """Return a single-choice Gradio radio group offering the choices of *field*."""
    return gr.Radio(choices=additional_categories[field], label=label)


def _checkbox_group(field, label):
    """Return a multi-choice Gradio checkbox group offering the choices of *field*."""
    return gr.CheckboxGroup(choices=additional_categories[field], label=label)


age_input = _slider("Age ", 18, 100)
gender_input = _radio("Gender", "Gender")
ethnicity_input = _radio("Ethnicity", "Ethnicity")
geographic_location_input = _radio("Geographic_Location", "Geographic Location")
diagnoses_icd10_input = _checkbox_group("Diagnoses_ICD10", "Diagnoses (ICD-10)")
medications_input = _checkbox_group("Medications", "Medications")
allergies_input = _checkbox_group("Allergies", "Allergies")
previous_treatments_input = _checkbox_group("Previous_Treatments", "Previous Treatments")
blood_glucose_level_input = _slider("Blood Glucose Level", 0, 300)
blood_pressure_systolic_input = _slider("Blood Pressure (Systolic)", 80, 200)
blood_pressure_diastolic_input = _slider("Blood Pressure (Diastolic)", 40, 120)
bmi_input = _slider("BMI ", 10, 50)
smoking_status_input = _radio("Smoking_Status", "Smoking Status")
alcohol_consumption_input = _radio("Alcohol_Consumption", "Alcohol Consumption")
exercise_habits_input = _radio("Exercise_Habits", "Exercise Habits")
diet_input = _radio("Diet", "Diet")
condition_severity_input = _slider("Condition Severity", 1, 10)
functional_status_input = _radio("Functional_Status", "Functional Status")
previous_trial_participation_input = _radio("Previous_Trial_Participation", "Previous Trial Participation")
def encrypt_array(user_symptoms: np.ndarray, user_id: str) -> bytes:
    """
    Quantize, encrypt and serialize one patient's feature vector.

    The vector is reshaped into a single row and encrypted with the FHE client
    stored in DEPLOYMENT_DIR, using the key directory ``KEYS_DIR / user_id``.
    A copy of the ciphertext is also written to
    ``KEYS_DIR / user_id / "encrypted_input"``.

    Args:
        user_symptoms (np.ndarray): The patient's feature values (any array-like).
        user_id (str): The current user's ID; selects the key directory.

    Returns:
        bytes: Encrypted and serialized features.

    Raises:
        TypeError: If the client does not return serialized bytes.
    """
    user_key_dir = KEYS_DIR / f"{user_id}"

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=user_key_dir)
    client.load()

    # Present the vector as a single row (a batch of one sample)
    user_symptoms = np.array(user_symptoms).reshape(1, -1)

    encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)

    # Explicit check rather than `assert`, which is stripped under `python -O`
    if not isinstance(encrypted_quantized_user_symptoms, bytes):
        raise TypeError(
            "Expected bytes from quantize_encrypt_serialize, got "
            f"{type(encrypted_quantized_user_symptoms).__name__}"
        )

    # Keep a copy of the ciphertext next to the user's keys; create the
    # directory first so the write cannot fail on a missing folder.
    user_key_dir.mkdir(parents=True, exist_ok=True)
    (user_key_dir / "encrypted_input").write_bytes(encrypted_quantized_user_symptoms)

    return encrypted_quantized_user_symptoms
def decrypt_result(encrypted_answer: bytes, user_id: str) -> bool:
    """
    Decrypt the server's encrypted answer and reduce it to a yes/no decision.

    Args:
        encrypted_answer (bytes): The serialized encrypted result returned by the server.
        user_id (str): The current user's ID; selects the key directory under KEYS_DIR.

    Returns:
        bool: True when the decrypted model output indicates a match.

    Raises:
        ValueError: If the decrypted output is empty.
    """
    # Use the same key directory as encrypt_array, so the private key that
    # encrypted the input is the one that decrypts the answer.
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # FHEModelClient's decryption entry points are deserialize_decrypt and
    # deserialize_decrypt_dequantize (deserialize, then decrypt, then dequantize);
    # it has no `decrypt_deserialize` method.
    output = np.asarray(client.deserialize_decrypt_dequantize(encrypted_answer))

    if output.size == 0:
        raise ValueError("The decrypted result is empty.")
    if output.size == 1:
        # A single value: use its plain truthiness.
        return bool(output.item())

    # NOTE(review): assumes one row of per-class scores/probabilities in which
    # class index 1 means "match" (presumably a binary classifier) — confirm
    # against the deployed model.
    return int(np.argmax(np.atleast_2d(output)[0])) == 1
def encode_categorical_data(data):
    """
    Map each single-choice answer to its 1-based position in that field's choices.

    The i-th element of ``data`` is looked up among the choices of the i-th field
    in the fixed order below. An answer that is not one of the choices (for
    example ``None`` when the field was left empty) is encoded as 0.

    Args:
        data: Indexable answers ordered as Gender, Ethnicity, Geographic_Location,
            Smoking_Status, Alcohol_Consumption, Exercise_Habits, Diet,
            Functional_Status, Previous_Trial_Participation.

    Returns:
        list[int]: One integer code per field.
    """
    field_order = (
        "Gender",
        "Ethnicity",
        "Geographic_Location",
        "Smoking_Status",
        "Alcohol_Consumption",
        "Exercise_Habits",
        "Diet",
        "Functional_Status",
        "Previous_Trial_Participation",
    )

    def code_for(position, field):
        choices = additional_categories[field]
        answer = data[position]
        return choices.index(answer) + 1 if answer in choices else 0

    return [code_for(position, field) for position, field in enumerate(field_order)]
# (connect, read) timeouts in seconds for the server request. The read timeout is
# generous because the server presumably runs encrypted inference before replying.
_REQUEST_TIMEOUT = (10, 600)


def process_patient_data(age, gender, ethnicity, geographic_location, diagnoses_icd10, medications, allergies, previous_treatments, blood_glucose_level, blood_pressure_systolic, blood_pressure_diastolic, bmi, smoking_status, alcohol_consumption, exercise_habits, diet, condition_severity, functional_status, previous_trial_participation):
    """
    Encode, encrypt and send the form values, then decrypt the server's answer.

    Args:
        One argument per form component, in the order of the Gradio ``inputs`` list.

    Returns:
        tuple[str, str, str]: The encrypted payload (as text), the decrypted result,
        and a Markdown message for the user. When the request fails, the last
        two entries describe the error instead.
    """
    # NOTE(review): diagnoses_icd10, medications, allergies and previous_treatments
    # are collected by the form but never encoded or sent — confirm the model does
    # not expect them.
    categorical_data = [gender, ethnicity, geographic_location, smoking_status, alcohol_consumption, exercise_habits, diet, functional_status, previous_trial_participation]
    print(f"Categorical data: {categorical_data}")
    encoded_categorical_data = encode_categorical_data(categorical_data)
    numerical_data = np.array([age, blood_glucose_level, blood_pressure_systolic, blood_pressure_diastolic, bmi, condition_severity])
    print(f"Numerical data: {numerical_data}")
    # The categorical codes are ordinal (choice position + 1), not one-hot.
    print(f"Encoded categorical data: {encoded_categorical_data}")
    # Feature order: the six numerical values, then the nine categorical codes.
    combined_data = np.hstack((numerical_data, encoded_categorical_data))
    print(f"Combined data: {combined_data}")

    # Use the shared USER_ID so that encryption and decryption read the same key directory.
    encrypted_array = encrypt_array(combined_data, USER_ID)

    # NOTE(review): only the ciphertext is posted. An FHE server normally also
    # needs the client's serialized evaluation keys
    # (FHEModelClient.get_serialized_evaluation_keys) — confirm how the server gets them.
    try:
        response = requests.post(SERVER_URL, data=encrypted_array, timeout=_REQUEST_TIMEOUT)
    except requests.RequestException as err:
        print(f"Error sending data: {err}")
        return (
            f"Encrypted data: {encrypted_array}",
            "Decrypted result: unavailable",
            f"Could not reach the server: {err}",
        )

    if response.status_code != 200:
        # An error body is not a ciphertext; decrypting it would fail or yield garbage.
        print("Error sending data.")
        return (
            f"Encrypted data: {encrypted_array}",
            "Decrypted result: unavailable",
            f"The server returned an error (HTTP {response.status_code}); no result is available.",
        )
    print("Data sent successfully.")

    decrypted_result = decrypt_result(response.content, USER_ID)

    # If the answer is True, return the link
    if decrypted_result:
        message = f"You may now access the link to the [clinical trial]({EXAMPLE_CLINICAL_TRIAL_LINK})"
    else:
        message = "Unfortunately, there are no clinical trials available for the provided criteria."
    return (
        f"Encrypted data: {encrypted_array}",
        f"Decrypted result: {decrypted_result}",
        message,
    )
# Create the Gradio interface.
# process_patient_data returns three values, so three output components are
# declared. The last one is Markdown so that the clinical-trial link renders
# as a clickable link instead of raw text.
demo = gr.Interface(
    fn=process_patient_data,
    inputs=[
        age_input,
        gender_input,
        ethnicity_input,
        geographic_location_input,
        diagnoses_icd10_input,
        medications_input,
        allergies_input,
        previous_treatments_input,
        blood_glucose_level_input,
        blood_pressure_systolic_input,
        blood_pressure_diastolic_input,
        bmi_input,
        smoking_status_input,
        alcohol_consumption_input,
        exercise_habits_input,
        diet_input,
        condition_severity_input,
        functional_status_input,
        previous_trial_participation_input,
    ],
    outputs=[
        gr.Textbox(label="Encrypted data"),
        gr.Textbox(label="Decrypted result"),
        gr.Markdown(),
    ],
    title="Patient Data Criteria Form",
    description="Please fill in the criteria for the type of patients you are looking for.",
)
# Launch the app
demo.launch()