"""A gradio app. that runs locally (analytics=False and share=False) about sentiment analysis on tweets.""" import random import numpy as np import gradio as gr from concrete.ml.deployment import FHEModelClient import numpy import os from pathlib import Path import shutil import torch from model import Autoencoder from concrete.ml.torch.compile import compile_torch_model sequence_length = 50 input_size = 12 latent_size = 8 hidden_size = 64 random.seed(0) np.random.seed(0) torch.manual_seed(0) ae_model = Autoencoder( input_size=input_size, hidden_size=hidden_size, latent_size=latent_size, sequence_length=sequence_length, num_lstm_layers=1, ) encoder = ae_model.encoder encoder.load_state_dict(torch.load("deployment/encoder.pth", weights_only=True)) decoder = ae_model.decoder decoder.load_state_dict(torch.load("deployment/decoder.pth", weights_only=True)) criterion = torch.nn.MSELoss() dummy_input = torch.randn(1, latent_size) compiled_decoder = compile_torch_model( decoder, dummy_input.numpy(), n_bits=6, rounding_threshold_bits={"n_bits": 6, "method": "approximate"}, ) # Encrypted data limit for the browser to display # (encrypted data is too large to display in the browser) ENCRYPTED_DATA_BROWSER_LIMIT = 100 N_USER_KEY_STORED = 20 FHE_MODEL_PATH = "deployment" def clean_tmp_directory(): # Allow 20 user keys to be stored. # Once that limitation is reached, deleted the oldest. path_sub_directories = sorted([f for f in Path(".fhe_keys/").iterdir() if f.is_dir()], key=os.path.getmtime) user_ids = [] if len(path_sub_directories) > N_USER_KEY_STORED: n_files_to_delete = len(path_sub_directories) - N_USER_KEY_STORED for p in path_sub_directories[:n_files_to_delete]: user_ids.append(p.name) shutil.rmtree(p) list_files_tmp = Path("tmp/").iterdir() # Delete all files related to user_id for file in list_files_tmp: for user_id in user_ids: if file.name.endswith(f"{user_id}.npy"): file.unlink() def keygen(): # Clean tmp directory if needed clean_tmp_directory() print("Initializing FHEModelClient...") # Let's create a user_id user_id = numpy.random.randint(0, 2**32) fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}") fhe_api.load() # Generate a fresh key fhe_api.generate_private_and_evaluation_keys(force=True) evaluation_key = fhe_api.get_serialized_evaluation_keys() # Save evaluation_key in a file, since too large to pass through regular Gradio # buttons, https://github.com/gradio-app/gradio/issues/1877 numpy.save(f"tmp/tmp_evaluation_key_{user_id}.npy", evaluation_key) return [list(evaluation_key)[:ENCRYPTED_DATA_BROWSER_LIMIT], user_id] def run_fhe(packets_ids, threshold=0.05): int_values = np.array([int(h[0], 16) for h in packets_ids.split(" ")]) binary_rep = np.array([list(bin(x)[2:].zfill(12)) for x in int_values]) packets_ids = binary_rep.astype(float) packets_ids = torch.tensor(packets_ids).unsqueeze(0).float() latent = encoder(packets_ids) with torch.no_grad(): # Disable gradient computation for validation decrypted_output = compiled_decoder.forward(latent.numpy(), fhe="simulate") decrypted_output = torch.tensor(decrypted_output).view( -1, ae_model.sequence_length, packets_ids.size(2) ) loss = criterion(decrypted_output, packets_ids) pred = loss.item() > threshold return [loss, pred] demo = gr.Blocks() with demo: gr.Markdown( """