File size: 6,343 Bytes
174cd37 d812385 174cd37 67fa189 d812385 67fa189 1dfccc3 67fa189 174cd37 67fa189 7cb14dd 174cd37 67fa189 174cd37 67fa189 174cd37 67fa189 174cd37 7cb14dd ce217e0 174cd37 67fa189 7cb14dd 67fa189 174cd37 cf6aebf 67fa189 cf6aebf 174cd37 ce217e0 174cd37 1dfccc3 174cd37 1dfccc3 174cd37 1dfccc3 174cd37 1dfccc3 628fe8f d812385 174cd37 d812385 174cd37 d812385 174cd37 67fa189 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import json
import os
import pickle as pkl
import re
import shutil
import string
from collections import Counter
from pathlib import Path
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
from pathlib import Path
# Core application URL
SERVER_URL = "http://localhost:8000/"

# Maximum number of characters accepted for a free-form user query
MAX_USER_QUERY_LEN = 128

# Base directories, resolved relative to the directory containing this module
CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR.joinpath("deployment")
DATA_PATH = CURRENT_DIR.joinpath("files")

# Deployment sub-directories (keys, client side, server side)
KEYS_DIR = DEPLOYMENT_DIR.joinpath(".fhe_keys")
CLIENT_DIR = DEPLOYMENT_DIR.joinpath("client_dir")
SERVER_DIR = DEPLOYMENT_DIR.joinpath("server_dir")

# Every deployment directory, keys first
ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]

# Model and data files
LOGREG_MODEL_PATH = CURRENT_DIR.joinpath("models", "cml_logreg.model")
ORIGINAL_FILE_PATH = DATA_PATH.joinpath("original_document.txt")
ANONYMIZED_FILE_PATH = DATA_PATH.joinpath("anonymized_document.txt")
MAPPING_UUID_PATH = DATA_PATH.joinpath("original_document_uuid_mapping.json")
MAPPING_ANONYMIZED_SENTENCES_PATH = DATA_PATH.joinpath("mapping_clear_to_anonymized.pkl")
MAPPING_ENCRYPTED_SENTENCES_PATH = DATA_PATH.joinpath("mapping_clear_to_encrypted.pkl")
MAPPING_DOC_EMBEDDING_PATH = DATA_PATH.joinpath("mapping_doc_embedding_path.pkl")
PROMPT_PATH = DATA_PATH.joinpath("chatgpt_prompt.txt")

# Ready-made example queries offered to the user; keys are display labels
DEFAULT_QUERIES = {
    "Example Query 1": "What is the amount of the contract between David and Kate?",
    "Example Query 2": "What's the duration of the contract?",
    "Example Query 3": "Does Kate have an international bank account?",
}
# Tokenizer and encoder from the same Hugging Face checkpoint, loaded once when this
# module is imported (this downloads/reads the checkpoint at import time).
_DEID_CHECKPOINT = "obi/deid_roberta_i2b2"
TOKENIZER = AutoTokenizer.from_pretrained(_DEID_CHECKPOINT)
EMBEDDINGS_MODEL = AutoModel.from_pretrained(_DEID_CHECKPOINT)
# Punctuation characters: every ASCII punctuation mark from `string.punctuation` except "%"
# and "$" (presumably kept so amounts such as "$100" or "5%" retain their symbols — confirm
# against callers), followed by the degree sign "°". Despite the name this is a `str`.
# The former debug `print` of this value on every import has been removed.
PUNCTUATION_LIST = "".join(char for char in string.punctuation if char not in "%$") + "°"
def clean_directory(directories=None) -> None:
    """Clear directories: delete each one (with its contents) and recreate it empty.

    Args:
        directories (Iterable[str | os.PathLike] | None): Directories to reset. Defaults to
            ``ALL_DIRS`` (the keys, client and server deployment directories).
    """
    print("Cleaning...\n")
    targets = ALL_DIRS if directories is None else directories
    for target_dir in map(Path, targets):
        # `is_dir()` already implies existence; a missing directory is simply created below.
        if target_dir.is_dir():
            shutil.rmtree(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)
def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
    """Embed each text by mean-pooling the model's last hidden layer over its real tokens.

    Texts are processed `batch_size` at a time. Each chunk is tokenized with padding and
    truncation, and the model runs under `torch.no_grad()`. The hidden vectors are then
    averaged per text, using the attention mask so that padding positions are ignored.

    Args:
        texts: Sequence of strings to embed (must support slicing and `len`).
        model: Callable returning an object with a `last_hidden_state` tensor.
        tokenizer: Callable returning a mapping that contains an `attention_mask` tensor.
        batch_size (int): Number of texts fed to the model per forward pass.

    Returns:
        np.ndarray: One pooled vector per input text, stacked in input order.
    """
    pooled_rows = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        encoded = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            hidden = model(**encoded, output_hidden_states=False).last_hidden_state

        # Add a trailing axis to the mask and broadcast it to the hidden-state shape, so
        # padded positions are multiplied by 0 and excluded from both sum and count.
        weights = encoded["attention_mask"].unsqueeze(-1).expand(hidden.size()).float()
        totals = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1)
        pooled_rows.extend((totals / counts).cpu().detach().numpy())
    return np.array(pooled_rows)
def is_user_query_valid(user_query: str) -> bool:
    """
    Decide whether `user_query` must be rejected.

    NOTE: despite its name, this function returns True when the query is NOT acceptable.
    That polarity is preserved because callers presumably treat True as an error — verify
    before renaming or inverting it.

    A query is accepted (False) when it is one of the `DEFAULT_QUERIES` values (regardless
    of length), or when it is a non-blank string of at most `MAX_USER_QUERY_LEN` characters.

    Args:
        user_query (str): The input text to be checked. May be None.

    Returns:
        bool: True if the query should be rejected (None, empty or whitespace-only, or longer
        than `MAX_USER_QUERY_LEN` without being a default query), False otherwise.
    """
    # Built-in example queries are always accepted.
    if user_query in DEFAULT_QUERIES.values():
        return False

    # Missing or blank queries are rejected (previously "" slipped through as acceptable).
    if user_query is None or not user_query.strip():
        return True

    # Otherwise, reject only queries that exceed the length limit.
    return len(user_query) > MAX_USER_QUERY_LEN
def compare_texts_ignoring_extra_spaces(original_text, modified_text):
    """Tell whether two texts are equal once whitespace layout is disregarded.

    Both texts are split on runs of any whitespace (spaces, tabs, newlines), with leading
    and trailing whitespace dropped. They match when the resulting word sequences are equal.
    Adding or removing whitespace *between* words is tolerated; splitting or joining words
    (e.g. "ab" vs "a b") is not.

    Args:
        original_text (str): The reference text.
        modified_text (str): The text to compare against the reference.

    Returns:
        (bool): True if the texts differ only in whitespace; False otherwise.
    """
    # `str.split()` with no argument never yields empty or whitespace-containing tokens,
    # so comparing the token lists equals comparing the single-space-joined forms.
    return original_text.split() == modified_text.split()
def is_strict_deletion_only(original_text, modified_text):
    """Check that `modified_text` uses no token beyond what `original_text` provides.

    Each text is tokenized case-insensitively, after punctuation glued to a word has been
    detached (e.g. "world!" -> "world", "!"). The result is True when every token of
    `modified_text` occurs in `original_text` at least as many times.

    NOTE(review): token order is not compared, so a reordering of the original tokens also
    returns True.

    Args:
        original_text (str): The reference text.
        modified_text (str): The candidate edited text.

    Returns:
        bool: True if `modified_text`'s token multiset is contained in `original_text`'s.
    """
    # Zero-width positions between a word character and an adjacent punctuation mark (in
    # either order, with no space between them); a space is inserted there before splitting.
    boundary = re.compile(r"(?<=[\w])(?=[^\w\s])|(?<=[^\w\s])(?=[\w])")

    def _token_counts(text):
        return Counter(boundary.sub(" ", text).lower().split())

    # Counter subtraction keeps only tokens that modified_text uses more often than
    # original_text; an empty result means nothing was added.
    surplus = _token_counts(modified_text) - _token_counts(original_text)
    return not surplus
def read_txt(file_path):
    """Return the entire content of the UTF-8 text file at `file_path`."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents
def write_txt(file_path, data):
    """Store the string `data` in `file_path` as UTF-8, replacing any previous content."""
    with open(file_path, mode="w", encoding="utf-8") as handle:
        handle.write(data)
def write_pickle(file_path, data):
    """Pickle `data` into `file_path`, overwriting the file if it already exists."""
    with open(file_path, mode="wb") as handle:
        pkl.dump(data, handle)
def read_pickle(file_name):
    """Unpickle and return the object stored in `file_name`.

    Security note: unpickling can execute arbitrary code. Only use this on files produced
    by this application (e.g. with `write_pickle`), never on untrusted or uploaded data.
    """
    with open(file_name, mode="rb") as handle:
        obj = pkl.load(handle)
    return obj
def read_json(file_name):
    """Load and return the data stored in the JSON file `file_name`.

    The file is decoded explicitly as UTF-8, which matches `write_json` and the text helpers.
    Without this, the platform's locale-dependent default encoding would be used and could
    mis-decode non-ASCII content (e.g. on Windows).
    """
    with open(file_name, "r", encoding="utf-8") as file:
        return json.load(file)
def write_json(file_name, data):
    """Serialize `data` to `file_name` as UTF-8 JSON with a 4-space indent and sorted keys."""
    with open(file_name, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, sort_keys=True, indent=4)
def write_bytes(path, data):
    """Save binary `data` to `path`, overwriting any existing file.

    Args:
        path (str | os.PathLike): Destination file. Plain strings are accepted as well as
            `pathlib.Path` objects (previously only objects with an `.open()` method worked).
        data (bytes-like): Content to write.
    """
    with Path(path).open("wb") as f:
        f.write(data)
def read_bytes(path):
    """Load and return the raw content of the binary file at `path`.

    Args:
        path (str | os.PathLike): Source file. Plain strings are accepted as well as
            `pathlib.Path` objects (previously only objects with an `.open()` method worked).

    Returns:
        bytes: The whole file content.
    """
    return Path(path).read_bytes()
|