# Author: CarolXia
# Commit 7961c58: "single thread 50"
import streamlit as st
# from gliner import GLiNER
from datasets import load_dataset
from peft import PeftModel, PeftConfig
import threading
import time
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import DebertaV2ForTokenClassification, DebertaV2Tokenizer, pipeline
# Serializes updates to the shared ``entity_set`` dict. The script calls
# predict_entities from several threads that share one dict, and
# ``d[k] += 1`` is a read-modify-write that is not atomic across threads.
_entity_set_lock = threading.Lock()


def predict_entities(text, labels, entity_set):
    """Run entity recognition on ``text`` and add per-label hit counts to ``entity_set``.

    Args:
        text: The text to scan.
        labels: Empty/falsy -> use the module-level ``recognizer`` pipeline and
            count each prediction's ``'entity'`` tag. Non-empty -> call
            ``model.predict_entities(text, labels, threshold=0.7)`` and count
            each result's ``'label'``.
            NOTE(review): that call is the GLiNER API (see the commented-out
            GLiNER import); the currently loaded DebertaV2ForTokenClassification
            does not appear to provide it, so only the empty-labels path is
            usable as the script stands.
        entity_set: dict mapping label -> count; updated in place.
    """
    if not labels:
        entities = recognizer(text)
        key = 'entity'
    else:
        entities = model.predict_entities(text, labels, threshold=0.7)
        key = 'label'
    # Inference happens outside the lock so threads can overlap it; only the
    # shared counter update is serialized.
    with _entity_set_lock:
        for entity in entities:
            tag = entity[key]
            entity_set[tag] = entity_set.get(tag, 0) + 1
def _iter_chunks(rows, start, end, max_chars):
    """Yield space-joined groups of ``rows[start:end]``; a group is closed once its length reaches ``max_chars``."""
    parts = []
    parts_len = 0  # always equals len(" ".join(parts))
    for i in range(start, end):
        if parts and parts_len >= max_chars:
            yield " ".join(parts)
            parts = []
            parts_len = 0
        row = rows[i]
        parts_len += len(row) + (1 if parts else 0)
        parts.append(row)
    # Flush the trailing group; previously it was silently dropped and never scanned.
    if parts:
        yield " ".join(parts)


def process_datasets(start, end, unmasked_text, sizes, index, entity_set, labels, max_chars=700):
    """Scan rows ``unmasked_text[start:end]`` for entities, in chunks of roughly ``max_chars`` characters.

    Consecutive rows are joined with single spaces into a chunk. Once a chunk's
    length reaches ``max_chars``, it is passed to ``predict_entities`` (which
    updates ``entity_set``) before the next row starts a new chunk. The final,
    possibly shorter, chunk is processed too.

    Args:
        start, end: Half-open row range to process.
        unmasked_text: Indexable sequence of row strings.
        sizes: List whose slot ``index`` receives the total number of characters scanned.
        index: Slot in ``sizes`` owned by this call (one per worker thread).
        entity_set: dict of label -> count, forwarded to ``predict_entities``.
        labels: Forwarded to ``predict_entities``.
        max_chars: Chunk length threshold (default 700, the original hard-coded value).
    """
    size = 0
    for text in _iter_chunks(unmasked_text, start, end, max_chars):
        size += len(text)
        predict_entities(text, labels, entity_set)
    sizes[index] = size
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Load the fine-tuned DeBERTa-v2 token-classification model.
# (An earlier variant used GLiNER "urchade/gliner_multi_pii-v1"; predict_entities'
# non-empty-labels branch targets that API.)
st.write('Loading the pretrained model ...')
model_name = "CarolXia/pii-kd-deberta-v2"
model = DebertaV2ForTokenClassification.from_pretrained(model_name, token=st.secrets["HUGGINGFACE_TOKEN"])
# Pipeline device convention: CUDA ordinal >= 0, or -1 for CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else -1
if device >= 0:
    model = model.to("cuda")  # "cuda" == the current CUDA device, i.e. `device`
tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/mdeberta-v3-base", token=st.secrets["HUGGINGFACE_TOKEN"])
# Pass the device explicitly. Otherwise some transformers versions keep the
# pipeline (and its input tensors) on CPU while the model sits on CUDA.
recognizer = pipeline("ner", model=model, tokenizer=tokenizer, device=device)

# Report model size: all parameters vs. those with requires_grad set.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'total params: {total_params}. tunable params: {trainable_params}')
# Sample text containing PII/PHI-like values (names, card/account numbers,
# address, phone numbers), used for a quick sanity check before the dataset run.
text = """
Hello Jane Doe. Your AnyCompany Financial Services, LLC credit card account
4111-0000-1111-0000 has a minimum payment of $24.53 that is due by July 31st.
Based on your autopay settings, we will withdraw your payment on the due date from
your bank account XXXXXX1111 with the routing number XXXXX0000.
Your latest statement was mailed to 100 Main Street, Anytown, WA 98121.
After your payment is received, you will receive a confirmation text message
at 206-555-0100.
If you have questions about your bill, AnyCompany Customer Service is available by
phone at 206-555-0199 or email at [email protected].
"""
# Label vocabulary for the GLiNER-style predict_entities call (the commented-out
# line below and the non-empty-labels branch of predict_entities).
# NOTE(review): the active code path uses `recognizer`, and the dataset run passes an
# empty label list, so this list is not consumed at the moment.
labels = [
    "medical_record_number",
    "date_of_birth",
    "ssn",
    "date",
    "first_name",
    "email",
    "last_name",
    "customer_id",
    "employee_id",
    "name",
    "street_address",
    "phone_number",
    "ipv4",
    "credit_card_number",
    "license_plate",
    "address",
    "user_name",
    "device_identifier",
    "bank_routing_number",
    "date_time",
    "company_name",
    "unique_identifier",
    "biometric_identifier",
    "account_number",
    "city",
    "certificate_license_number",
    "time",
    "postcode",
    "vehicle_identifier",
    "coordinate",
    "country",
    "api_key",
    "ipv6",
    "password",
    "health_plan_beneficiary_number",
    "national_id",
    "tax_id",
    "url",
    "state",
    "swift_bic",
    "cvv",
    "pin"
]
st.write('Trying a sample first')
st.write(text)
# Run the Hugging Face NER pipeline on the sample. No confidence threshold is
# applied here; the commented-out line is the earlier GLiNER call, which used 0.7.
# entities = model.predict_entities(text, labels, threshold=0.7)
entities = recognizer(text)
# Display each prediction exactly as returned by the pipeline.
for entity in entities:
    st.write(entity)
st.write('Processing the full dataset now ...')
entity_set = {}
dataset = load_dataset("Isotonic/pii-masking-200k", split="train")
# Materialize the whole column in memory up front so the workers below index an
# in-memory sequence instead of paying dataset I/O per row.
unmasked_text = dataset['unmasked_text']
st.write('Number of rows in the dataset ', dataset.num_rows)

NUM_THREADS = 5
ROWS_PER_THREAD = 10
# Set True to process the same rows on a single thread under torch.profiler and
# show the top operators. False runs the multi-threaded benchmark.
PROFILE = False

sizes = [0] * NUM_THREADS
prof = None  # bound only when PROFILE is True
start = time.perf_counter()
if PROFILE:
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, profile_memory=True, record_shapes=True) as prof:
        process_datasets(0, NUM_THREADS * ROWS_PER_THREAD, unmasked_text, sizes, 0, entity_set, [])
else:
    # Thread i handles rows [i * ROWS_PER_THREAD, (i + 1) * ROWS_PER_THREAD) and writes sizes[i].
    threads = [
        threading.Thread(
            target=process_datasets,
            args=(i * ROWS_PER_THREAD, (i + 1) * ROWS_PER_THREAD, unmasked_text, sizes, i, entity_set, []),
        )
        for i in range(NUM_THREADS)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
end = time.perf_counter()
length = end - start

# Show the results
# sizes holds len() of the scanned str chunks, i.e. characters, not bytes.
st.write('Characters processed ', sum(sizes))
st.write("It took", length, "seconds!")
# Display the summary
st.write('Total entities found')
for key, count in entity_set.items():
    st.write(key, ' => ', count)
if prof is not None:
    # st.text keeps the fixed-width table layout, which markdown rendering would collapse.
    st.text(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))