File size: 8,930 Bytes
ce70f59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184

from concrete.ml.deployment import FHEModelClient
from pathlib import Path
import numpy as np
import gradio as gr
import requests

# Store the server's URL
# NOTE(review): 7860 is also Gradio's usual default port — confirm the FHE server really listens here.
SERVER_URL = "http://127.0.0.1:7860/"
# Directory containing this file; every deployment path below is derived from it
CURRENT_DIR = Path(__file__).parent
# Passed as `path_dir` to FHEModelClient by the encrypt/decrypt helpers
DEPLOYMENT_DIR = CURRENT_DIR / "deployment_files"
# Parent of the per-user key directories (KEYS_DIR / <user_id>)
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
# NOTE(review): CLIENT_DIR and SERVER_DIR are not referenced in the code visible here — confirm they are still needed.
CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
SERVER_DIR = DEPLOYMENT_DIR / "server_dir"


# Fixed identifier used to select the key directory
USER_ID = "user_id"
# Link shown to the user when the decrypted result is truthy
EXAMPLE_CLINICAL_TRIAL_LINK = "https://www.trials4us.co.uk/ongoing-clinical-trials/recruiting-healthy-adults-c23026?_gl=1*1ysp815*_up*MQ..&gclid=Cj0KCQjwr9m3BhDHARIsANut04bHqi5zE3sjS3f8JK2WRN3YEgY4bTfWbvTdZTxkUTSISxXX5ZWL7qEaAowwEALw_wcB&gbraid=0AAAAAD3Qci2k_3IERmM6U1FGDuYVayZWH"




# Allowed values per categorical field. The position of a value inside its
# list matters: it is what the categorical encoder turns into an integer code.
additional_categories = {
    "Gender": [
        "Male",
        "Female",
        "Other",
    ],
    "Ethnicity": [
        "White",
        "Black or African American",
        "Asian",
        "American Indian or Alaska Native",
        "Native Hawaiian or Other Pacific Islander",
        "Other",
    ],
    "Geographic_Location": [
        "North America",
        "South America",
        "Europe",
        "Asia",
        "Africa",
        "Australia",
        "Antarctica",
    ],
    "Smoking_Status": [
        "Never",
        "Former",
        "Current",
    ],
    "Diagnoses_ICD10": [
        "E11.9",
        "I10",
        "J45.909",
        "M54.5",
        "F32.9",
        "K21.9",
    ],
    "Medications": [
        "Metformin",
        "Lisinopril",
        "Atorvastatin",
        "Amlodipine",
        "Omeprazole",
        "Simvastatin",
        "Levothyroxine",
        "None",
    ],
    "Allergies": [
        "Penicillin",
        "Peanuts",
        "Shellfish",
        "Latex",
        "Bee stings",
        "None",
    ],
    "Previous_Treatments": [
        "Chemotherapy",
        "Radiation Therapy",
        "Surgery",
        "Physical Therapy",
        "Immunotherapy",
        "None",
    ],
    "Alcohol_Consumption": [
        "None",
        "Occasionally",
        "Regularly",
        "Heavy",
    ],
    "Exercise_Habits": [
        "Sedentary",
        "Light",
        "Moderate",
        "Active",
        "Very Active",
    ],
    "Diet": [
        "Omnivore",
        "Vegetarian",
        "Vegan",
        "Pescatarian",
        "Keto",
        "Mediterranean",
    ],
    "Functional_Status": [
        "Independent",
        "Assisted",
        "Dependent",
    ],
    "Previous_Trial_Participation": [
        "Yes",
        "No",
    ],
}

# Form widgets: sliders for numeric fields, radios for single-choice fields,
# checkbox groups for multi-choice fields. Choices come from additional_categories.
age_input = gr.Slider(
    label="Age ",
    minimum=18,
    maximum=100,
    step=1,
)
gender_input = gr.Radio(
    label="Gender",
    choices=additional_categories["Gender"],
)
ethnicity_input = gr.Radio(
    label="Ethnicity",
    choices=additional_categories["Ethnicity"],
)
geographic_location_input = gr.Radio(
    label="Geographic Location",
    choices=additional_categories["Geographic_Location"],
)
diagnoses_icd10_input = gr.CheckboxGroup(
    label="Diagnoses (ICD-10)",
    choices=additional_categories["Diagnoses_ICD10"],
)
medications_input = gr.CheckboxGroup(
    label="Medications",
    choices=additional_categories["Medications"],
)
allergies_input = gr.CheckboxGroup(
    label="Allergies",
    choices=additional_categories["Allergies"],
)
previous_treatments_input = gr.CheckboxGroup(
    label="Previous Treatments",
    choices=additional_categories["Previous_Treatments"],
)
blood_glucose_level_input = gr.Slider(
    label="Blood Glucose Level",
    minimum=0,
    maximum=300,
    step=1,
)
blood_pressure_systolic_input = gr.Slider(
    label="Blood Pressure (Systolic)",
    minimum=80,
    maximum=200,
    step=1,
)
blood_pressure_diastolic_input = gr.Slider(
    label="Blood Pressure (Diastolic)",
    minimum=40,
    maximum=120,
    step=1,
)
bmi_input = gr.Slider(
    label="BMI ",
    minimum=10,
    maximum=50,
    step=1,
)
smoking_status_input = gr.Radio(
    label="Smoking Status",
    choices=additional_categories["Smoking_Status"],
)
alcohol_consumption_input = gr.Radio(
    label="Alcohol Consumption",
    choices=additional_categories["Alcohol_Consumption"],
)
exercise_habits_input = gr.Radio(
    label="Exercise Habits",
    choices=additional_categories["Exercise_Habits"],
)
diet_input = gr.Radio(
    label="Diet",
    choices=additional_categories["Diet"],
)
condition_severity_input = gr.Slider(
    label="Condition Severity",
    minimum=1,
    maximum=10,
    step=1,
)
functional_status_input = gr.Radio(
    label="Functional Status",
    choices=additional_categories["Functional_Status"],
)
previous_trial_participation_input = gr.Radio(
    label="Previous Trial Participation",
    choices=additional_categories["Previous_Trial_Participation"],
)


def encrypt_array(user_symptoms: np.ndarray, user_id: str) -> bytes:
    """
    Encrypt the user symptoms vector.

    The vector is reshaped to a single row, quantized, encrypted and
    serialized by the FHE client, then also written to
    ``KEYS_DIR / <user_id> / encrypted_input``.

    Args:
        user_symptoms (np.ndarray): The vector of symptoms provided by the user.
        user_id (str): The current user's ID.

    Returns:
        bytes: Encrypted and serialized symptoms.

    Raises:
        TypeError: If the client does not return a ``bytes`` payload.
    """

    user_key_dir = KEYS_DIR / f"{user_id}"

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=user_key_dir)
    client.load()

    # Ensure the symptoms are properly formatted as a single-row 2D array
    user_symptoms = np.array(user_symptoms).reshape(1, -1)

    # Encrypt and serialize the symptoms
    encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)

    # Validate with an explicit raise: a bare `assert` is removed under `python -O`
    if not isinstance(encrypted_quantized_user_symptoms, bytes):
        raise TypeError(
            "quantize_encrypt_serialize returned "
            f"{type(encrypted_quantized_user_symptoms).__name__}, expected bytes"
        )

    # Save the encrypted data to a file (optional). Make sure the per-user
    # directory exists first so the write cannot fail on a missing parent.
    user_key_dir.mkdir(parents=True, exist_ok=True)
    encrypted_input_path = user_key_dir / "encrypted_input"
    encrypted_input_path.write_bytes(encrypted_quantized_user_symptoms)

    # Return the encrypted data
    return encrypted_quantized_user_symptoms


def decrypt_result(encrypted_answer: bytes, user_id: str) -> bool:
    """
    Deserialize and decrypt the encrypted answer using the user's keys.

    Args:
        encrypted_answer (bytes): The encrypted result.
        user_id (str): The current user's ID; selects the key directory
            ``KEYS_DIR / <user_id>``.

    Returns:
        bool: The decrypted result, passed through unchanged from the client.
    """
    # NOTE(review): the annotation says bool, but the value is whatever
    # `decrypt_deserialize` returns — confirm the actual type.
    user_key_dir = KEYS_DIR / f"{user_id}"

    # Build and load the client API for this user
    fhe_client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=user_key_dir)
    fhe_client.load()

    # Decrypt and hand the result straight back
    return fhe_client.decrypt_deserialize(encrypted_answer)



def encode_categorical_data(data):
    """
    Map each categorical answer to an integer code.

    Args:
        data: Indexable sequence of answers, positionally aligned with the
            field order listed below.

    Returns:
        list: One integer per field: the 1-based position of the answer in
        that field's list in ``additional_categories``, or 0 when the answer
        is not one of the listed values.
    """
    field_order = [
        "Gender",
        "Ethnicity",
        "Geographic_Location",
        "Smoking_Status",
        "Alcohol_Consumption",
        "Exercise_Habits",
        "Diet",
        "Functional_Status",
        "Previous_Trial_Participation",
    ]

    codes = []
    for position, field in enumerate(field_order):
        allowed_values = additional_categories[field]
        answer = data[position]
        # 0 is reserved for "not a known value"; known values start at 1
        code = allowed_values.index(answer) + 1 if answer in allowed_values else 0
        codes.append(code)

    return codes


def process_patient_data(age, gender, ethnicity, geographic_location, diagnoses_icd10, medications, allergies, previous_treatments, blood_glucose_level, blood_pressure_systolic, blood_pressure_diastolic, bmi, smoking_status, alcohol_consumption, exercise_habits, diet, condition_severity, functional_status, previous_trial_participation):
    """
    Encode the form values, encrypt them, query the server and decrypt its answer.

    The feature vector is the six numerical fields followed by the integer
    codes of the nine single-choice categorical fields. ``diagnoses_icd10``,
    ``medications``, ``allergies`` and ``previous_treatments`` are accepted
    (the form supplies them) but are not part of the vector.

    Args:
        age, blood_glucose_level, blood_pressure_systolic,
        blood_pressure_diastolic, bmi, condition_severity: Numerical fields.
        gender, ethnicity, geographic_location, smoking_status,
        alcohol_consumption, exercise_habits, diet, functional_status,
        previous_trial_participation: Single-choice categorical fields.
        diagnoses_icd10, medications, allergies, previous_treatments:
            Multi-choice fields, currently unused.

    Returns:
        tuple: Three strings — the encrypted payload, the decrypted result
        (or a notice that none is available), and a user-facing message.
    """

    # Encode the data
    categorical_data = [gender, ethnicity, geographic_location, smoking_status, alcohol_consumption, exercise_habits, diet, functional_status, previous_trial_participation]
    print(f"Categorical data: {categorical_data}")
    encoded_categorical_data = encode_categorical_data(categorical_data)
    numerical_data = np.array([age, blood_glucose_level, blood_pressure_systolic, blood_pressure_diastolic, bmi, condition_severity])
    print(f"Numerical data: {numerical_data}")
    # The log label says "One-hot", but the values are per-field integer codes
    print(f"One-hot encoded data: {encoded_categorical_data}")
    combined_data = np.hstack((numerical_data, encoded_categorical_data))
    print(f"Combined data: {combined_data}")
    # Use the same USER_ID constant as the decryption step so both sides
    # always resolve the same key directory
    encrypted_array = encrypt_array(combined_data, USER_ID)

    # Send the encrypted data to the server
    # NOTE(review): no timeout is set, so this blocks until the server answers.
    response = requests.post(SERVER_URL, data=encrypted_array)

    # An error response body is not a ciphertext: stop here instead of
    # trying to decrypt it
    if response.status_code != 200:
        print("Error sending data.")
        return (
            f"Encrypted data: {encrypted_array}",
            "Decrypted result: unavailable",
            f"The server returned HTTP status {response.status_code}, so no result could be decrypted.",
        )
    print("Data sent successfully.")

    # Decrypt the result
    decrypted_result = decrypt_result(response.content, USER_ID)

    # If the answer is truthy, point the user to the trial link
    if decrypted_result:
        message = f"You may now access the link to the [clinical trial]({EXAMPLE_CLINICAL_TRIAL_LINK})"
    else:
        message = "Unfortunately, there are no clinical trials available for the provided criteria."

    return (
        f"Encrypted data: {encrypted_array}",
        f"Decrypted result: {decrypted_result}",
        message,
    )

# Wire the form widgets to the processing callback
demo = gr.Interface(
    title="Patient Data Criteria Form",
    description="Please fill in the criteria for the type of patients you are looking for.",
    fn=process_patient_data,
    # Order must match the positional parameters of process_patient_data
    inputs=[
        age_input,
        gender_input,
        ethnicity_input,
        geographic_location_input,
        diagnoses_icd10_input,
        medications_input,
        allergies_input,
        previous_treatments_input,
        blood_glucose_level_input,
        blood_pressure_systolic_input,
        blood_pressure_diastolic_input,
        bmi_input,
        smoking_status_input,
        alcohol_consumption_input,
        exercise_habits_input,
        diet_input,
        condition_severity_input,
        functional_status_input,
        previous_trial_participation_input,
    ],
    outputs="text",
)

# Start serving the app
demo.launch()