|
import html
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from transformers import AutoModelForTokenClassification, AutoTokenizer, logging, pipeline
|
|
|
# Keep the demo output clean: ignore every Python warning category ...
warnings.simplefilter("ignore", Warning)
# ... and let the transformers logger report errors only.
logging.set_verbosity(logging.ERROR)
|
|
|
# Page-level configuration: browser tab title and a full-width layout.
st.set_page_config(page_title="CAROLL Language Models - Demo", layout="wide")

# Custom stylesheet injected into the page (fonts, header banner, logo,
# container, buttons and heading sizes).
PAGE_CSS = """
    <style>
        body {
            font-family: 'Poppins', sans-serif;
            background-color: #f4f4f8;
        }
        .header {
            background-color: rgba(220, 219, 219, 0.25);
            color: #000;
            padding: 5px 0;
            text-align: center;
            border-radius: 7px;
            margin-bottom: 13px;
            border-bottom: 2px solid #333;
        }
        #logo {
            width: auto;
            height: 75px;
            margin-top: -15px;
            margin-bottom: 15px;
        }
        .container {
            background-color: #fff;
            padding: 30px;
            border-radius: 10px;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            width: 100%;
            max-width: 1000px;
            margin: 0 auto;
            position: absolute;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
        }
        .btn-primary {
            background-color: #5477d1;
            border: none;
            transition: background-color 0.3s, transform 0.2s;
            border-radius: 25px;
            box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
        }
        .btn-primary:hover {
            background-color: #4c6cbe;
            transform: translateY(-1px);
        }
        h2 {
            font-weight: 600;
            font-size: 24px;
            margin-bottom: 20px;
        }
        h4 {
            font-weight: 500;
            font-size: 15px;
            margin-top: 15px;
            margin-bottom: 15px;
        }
        label {
            font-weight: 500;
        }
    </style>
"""

# Header banner: research-group logo plus a link to the group's website.
HEADER_HTML = """
    <div class="header">
        <img src="https://raw.githubusercontent.com/ca-roll/ca-roll.github.io/release/images/logopic/caroll.png" alt="Research Group Logo" id="logo">
        <h4>Demonstrating <a href="https://ca-roll.github.io/" target="_blank">CAROLL Research Group</a>'s Language Models</h4>
    </div>
"""

# Both snippets are raw HTML, so they must be rendered with HTML enabled.
for markup in (PAGE_CSS, HEADER_HTML):
    st.markdown(markup, unsafe_allow_html=True)
|
|
|
|
|
@st.cache_resource
def _load_ner_pipeline(model_name):
    """Build the token-classification ("ner") pipeline for ``model_name``.

    Streamlit re-executes the whole script on every widget interaction.
    Caching the pipeline as a resource means the tokenizer and model are
    instantiated once per model name instead of on every rerun.

    Args:
        model_name: Hugging Face Hub identifier of the checkpoint to load.

    Returns:
        The ``pipeline("ner", ...)`` object wrapping the loaded model and
        tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline("ner", model=model, tokenizer=tokenizer)


# Legal-domain NER model. The tokenizer/model names are kept at module level
# for backward compatibility; they are the objects held by the pipeline.
ner_legal = _load_ner_pipeline("PaDaS-Lab/gbert-legal-ner")
tokenizer_legal = ner_legal.tokenizer
model_legal = ner_legal.model

# GDPR privacy-policy NER model, loaded the same way.
ner_gdpr = _load_ner_pipeline("PaDaS-Lab/gdpr-privacy-policy-ner")
tokenizer_gdpr = ner_gdpr.tokenizer
model_gdpr = ner_gdpr.model
|
|
|
|
|
# Label code -> human-readable description for the legal NER model.
# The descriptions are shown as hover tooltips in the analyzed text.
classes_legal = dict(
    AN="Lawyer",
    EUN="European legal norm",
    GRT="Court",
    GS="Law",
    INN="Institution",
    LD="Country",
    LDS="Landscape",
    LIT="Legal literature",
    MRK="Brand",
    ORG="Organization",
    PER="Person",
    RR="Judge",
    RS="Court decision",
    ST="City",
    STR="Street",
    UN="Company",
    VO="Ordinance",
    VS="Regulation",
    VT="Contract",
)

# Label code -> human-readable description for the GDPR privacy-policy model.
classes_gdpr = dict(
    DC="Data Controller",
    DP="Data Processor",
    DPO="Data Protection Officer",
    R="Recipient",
    TP="Third Party",
    A="Authority",
    DS="Data Subject",
    DSO="Data Source",
    RP="Required Purpose",
    NRP="Not-Required Purpose",
    P="Processing",
    NPD="Non-Personal Data",
    PD="Personal Data",
    OM="Organisational Measure",
    TM="Technical Measure",
    LB="Legal Basis",
    CONS="Consent",
    CONT="Contract",
    LI="Legitimate Interest",
    ADM="Automated Decision Making",
    RET="Retention",
    SEU="Scale EU",
    SNEU="Scale Non-EU",
    RI="Right",
    DSR15="Art. 15 Right of access by the data subject",
    DSR16="Art. 16 Right to rectification",
    DSR17="Art. 17 Right to erasure (‘right to be forgotten’)",
    DSR18="Art. 18 Right to restriction of processing",
    DSR19=(
        "Art. 19 Notification obligation regarding rectification or erasure "
        "of personal data or restriction of processing"
    ),
    DSR20="Art. 20 Right to data portability",
    DSR21="Art. 21 Right to object",
    DSR22="Art. 22 Automated individual decision-making, including profiling",
    LC="Lodge Complaint",
)

# Ordered label codes per model; the order fixes which colour each label gets.
ner_labels_legal = list(classes_legal)
ner_labels_gdpr = list(classes_gdpr)
|
|
|
|
|
|
|
def generate_colors(num_colors):
    """Return ``num_colors`` hex colour strings taken from the "tab20" colormap.

    The colormap is sampled at the evenly spaced positions ``i / num_colors``
    for ``i`` in ``range(num_colors)``; each sampled colour is converted to a
    hex string. An empty list is returned when ``num_colors`` is 0.

    NOTE(review): "tab20" presumably offers a limited set of distinct
    colours, so a large ``num_colors`` may yield repeats - confirm.
    """
    palette = plt.get_cmap("tab20")
    hex_codes = []
    for index in range(num_colors):
        position = 1.0 * index / num_colors
        hex_codes.append(mcolors.rgb2hex(palette(position)))
    return hex_codes
|
|
|
|
|
|
|
def color_substrings(input_string, model_output, ner_labels, current_classes):
    """Render ``input_string`` as HTML with every detected entity colour-coded.

    Args:
        input_string: The raw text that was analyzed.
        model_output: Iterable of dicts with ``"start"`` and ``"end"``
            character offsets into ``input_string`` and a ``"label"`` code.
        ner_labels: Ordered label codes; each one is assigned its own colour.
        current_classes: Mapping from label code to the description shown as
            the hover tooltip (missing labels get an empty tooltip).

    Returns:
        An HTML string in which each entity is wrapped in a bold, coloured
        ``<span>`` and all text taken from ``input_string`` is HTML-escaped.
    """
    colors = generate_colors(len(ner_labels))
    label_to_color = dict(zip(ner_labels, colors))

    pieces = []
    last_end = 0
    for entity in sorted(model_output, key=lambda item: item["start"]):
        label = entity["label"]
        end = entity["end"]
        # Clamp the start so overlapping spans never re-emit text that has
        # already been rendered; spans left empty by clamping are dropped.
        start = max(entity["start"], last_end)
        if start >= end:
            continue

        # The result is rendered with HTML enabled, so user text and the
        # tooltip must be escaped to avoid markup injection.
        pieces.append(html.escape(input_string[last_end:start]))
        color = label_to_color.get(label, "inherit")
        tooltip = html.escape(current_classes.get(label, ""), quote=True)
        text = html.escape(input_string[start:end])
        pieces.append(
            f'<span style="color: {color}; font-weight: bold;" '
            f'title="{tooltip}">{text}</span>'
        )
        last_end = end

    # Whatever follows the last entity is plain (escaped) text.
    pieces.append(html.escape(input_string[last_end:]))
    return "".join(pieces)
|
|
|
|
|
# --- Main page: input widgets -------------------------------------------
st.title("CAROLL Language Models - Demo")
st.markdown("<hr>", unsafe_allow_html=True)

test_sentence = st.text_area("Enter Text:", height=200)
model_choice = st.selectbox(
    "Choose a model:", ["Legal NER", "GDPR Privacy Policy NER"], index=0
)

# --- Analysis: runs only when the button is pressed ----------------------
if st.button("Analyze"):
    if not test_sentence.strip():
        # Nothing to analyze; avoid sending empty input to the model.
        st.warning("Please enter some text to analyze.")
    else:
        # Select the pipeline together with its label table and label order.
        if model_choice == "Legal NER":
            ner_model = ner_legal
            current_classes = classes_legal
            current_ner_labels = ner_labels_legal
        else:
            ner_model = ner_gdpr
            current_classes = classes_gdpr
            current_ner_labels = ner_labels_gdpr

        results = ner_model(test_sentence)
        # Keep only the character span and the bare label code, i.e. the
        # part after the last "-" of the predicted tag.
        processed_results = [
            {
                "start": result["start"],
                "end": result["end"],
                "label": result["entity"].split("-")[-1],
            }
            for result in results
        ]

        colored_html = color_substrings(
            test_sentence, processed_results, current_ner_labels, current_classes
        )

        # The user's text is echoed into HTML-enabled markdown, so escape it
        # to keep characters such as "<" or "&" from being read as markup.
        st.markdown(
            f"<strong>- Original text -</strong><br><br>{html.escape(test_sentence)}",
            unsafe_allow_html=True,
        )
        st.markdown(
            f"<strong>- Analyzed text -</strong><br><br>{colored_html}",
            unsafe_allow_html=True,
        )
        st.markdown(
            "<mark><strong>Tip:</strong> Hover over the colored words to see its class.</mark>",
            unsafe_allow_html=True,
        )
|
|