import streamlit as st # from gliner import GLiNER from datasets import load_dataset from peft import PeftModel, PeftConfig import threading import time import torch from torch.profiler import profile, record_function, ProfilerActivity from transformers import DebertaV2ForTokenClassification, DebertaV2Tokenizer, pipeline def predict_entities(text, labels, entity_set): if labels == []: entities = recognizer(text) for entity in entities: if entity['entity'] in entity_set: entity_set[entity['entity']] += 1 else: entity_set[entity['entity']] = 1 else: # Use Gliner labels entities = model.predict_entities(text, labels, threshold = 0.7) for entity in entities: if entity['label'] in entity_set: entity_set[entity['label']] += 1 else: entity_set[entity['label']] = 1 def process_datasets(start, end, unmasked_text, sizes, index, entity_set, labels): size = 0 text = "" for i in range(start, end): if len(text) < 700: text = text + " " + unmasked_text[i] else: size += len(text) predict_entities(text, labels, entity_set) text = unmasked_text[i] sizes[index] = size print(f"Is CUDA available: {torch.cuda.is_available()}") # True if torch.cuda.is_available(): print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") # Load the fine-tuned GLiNER model st.write('Loading the pretrained model ...') model_name = "CarolXia/pii-kd-deberta-v2" # config = PeftConfig.from_pretrained(model_name) model = DebertaV2ForTokenClassification.from_pretrained(model_name, token=st.secrets["HUGGINGFACE_TOKEN"]) if torch.cuda.is_available(): model = model.to("cuda") # Try quantization instead # model = AutoModelForTokenClassification.from_pretrained(model_name, device_map="auto", load_in_8bit=True) tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/mdeberta-v3-base", token=st.secrets["HUGGINGFACE_TOKEN"]) recognizer = pipeline("ner", model=model, tokenizer=tokenizer) # model_name = "urchade/gliner_multi_pii-v1" # model = GLiNER.from_pretrained(model_name) # print weights pytorch_total_params = sum(p.numel() for p in model.parameters()) torch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'total params: {pytorch_total_params}. tunable params: {torch_total_params}') # Sample text containing PII/PHI entities text = """ Hello Jane Doe. Your AnyCompany Financial Services, LLC credit card account 4111-0000-1111-0000 has a minimum payment of $24.53 that is due by July 31st. Based on your autopay settings, we will withdraw your payment on the due date from your bank account XXXXXX1111 with the routing number XXXXX0000. Your latest statement was mailed to 100 Main Street, Anytown, WA 98121. After your payment is received, you will receive a confirmation text message at 206-555-0100. If you have questions about your bill, AnyCompany Customer Service is available by phone at 206-555-0199 or email at support@anycompany.com. """ # Define the labels for PII/PHI entities labels = [ "medical_record_number", "date_of_birth", "ssn", "date", "first_name", "email", "last_name", "customer_id", "employee_id", "name", "street_address", "phone_number", "ipv4", "credit_card_number", "license_plate", "address", "user_name", "device_identifier", "bank_routing_number", "date_time", "company_name", "unique_identifier", "biometric_identifier", "account_number", "city", "certificate_license_number", "time", "postcode", "vehicle_identifier", "coordinate", "country", "api_key", "ipv6", "password", "health_plan_beneficiary_number", "national_id", "tax_id", "url", "state", "swift_bic", "cvv", "pin" ] st.write('Trying a sample first') st.write(text) # Predict entities with a confidence threshold of 0.7 # entities = model.predict_entities(text, labels, threshold=0.7) entities = recognizer(text) # Display the detected entities for entity in entities: st.write(entity) st.write('Processing the full dataset now ...') entity_set=dict() dataset = load_dataset("Isotonic/pii-masking-200k", split="train") unmasked_text = dataset['unmasked_text'] # This will load the entire column inmemory. Must do this to avoid I/O delay later st.write('Number of rows in the dataset ', dataset.num_rows) sizes = [0] * 5 start = time.time() t0 = threading.Thread(target=process_datasets, args=(0, 10, unmasked_text, sizes, 0, entity_set, [])) t1 = threading.Thread(target=process_datasets, args=(10, 20, unmasked_text, sizes, 1, entity_set, [])) t2 = threading.Thread(target=process_datasets, args=(20, 30, unmasked_text, sizes, 2, entity_set, [])) t3 = threading.Thread(target=process_datasets, args=(30, 40, unmasked_text, sizes, 3, entity_set, [])) t4 = threading.Thread(target=process_datasets, args=(40, 50, unmasked_text, sizes, 4, entity_set, [])) # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof: # process_datasets(0, 50, unmasked_text, sizes, 0, entity_set, []) t0.start() t1.start() t2.start() t3.start() t4.start() t0.join() t1.join() t2.join() t3.join() t4.join() end = time.time() length = end - start # Show the results : this can be altered however you like st.write('Bytes processed ', sum(sizes)) st.write("It took", length, "seconds!") # Display the summary st.write('Total entities found') for key in entity_set: st.write(key, ' => ', entity_set[key]) st.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))