"""Streamlit demo for the CAROLL research group's NER models.

The user enters text, picks either the legal NER model or the GDPR
privacy-policy NER model, and the app renders the text with every predicted
entity token highlighted. Hovering a highlight shows its class name.
"""

import html
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)

# Silence every Python warning (Warning is the base category) and all
# non-error transformers log output so the demo UI stays clean.
warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)

st.set_page_config(page_title="CAROLL Language Models - Demo", layout="wide")

# NOTE(review): this payload is empty in this copy of the file; it presumably
# held page CSS at some point -- confirm or remove.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)
st.markdown(
    """

Demonstrating CAROLL Research Group's Language Models

""",
    unsafe_allow_html=True,
)


@st.cache_resource
def load_ner_pipeline(model_name):
    """Return a token-classification ("ner") pipeline for *model_name*.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the pipeline with ``st.cache_resource`` keeps the tokenizer and
    model weights from being re-loaded on each rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline("ner", model=model, tokenizer=tokenizer)


# Initialization for Legal NER
ner_legal = load_ner_pipeline("PaDaS-Lab/gbert-legal-ner")
# Initialization for GDPR Privacy Policy NER
ner_gdpr = load_ner_pipeline("PaDaS-Lab/gdpr-privacy-policy-ner")

# Define class labels (code -> human-readable name) for Legal and GDPR NER models
classes_legal = {
    "AN": "Lawyer",
    "EUN": "European legal norm",
    "GRT": "Court",
    "GS": "Law",
    "INN": "Institution",
    "LD": "Country",
    "LDS": "Landscape",
    "LIT": "Legal literature",
    "MRK": "Brand",
    "ORG": "Organization",
    "PER": "Person",
    "RR": "Judge",
    "RS": "Court decision",
    "ST": "City",
    "STR": "Street",
    "UN": "Company",
    "VO": "Ordinance",
    "VS": "Regulation",
    "VT": "Contract",
}
classes_gdpr = {
    "DC": "Data Controller",
    "DP": "Data Processor",
    "DPO": "Data Protection Officer",
    "R": "Recipient",
    "TP": "Third Party",
    "A": "Authority",
    "DS": "Data Subject",
    "DSO": "Data Source",
    "RP": "Required Purpose",
    "NRP": "Not-Required Purpose",
    "P": "Processing",
    "NPD": "Non-Personal Data",
    "PD": "Personal Data",
    "OM": "Organisational Measure",
    "TM": "Technical Measure",
    "LB": "Legal Basis",
    "CONS": "Consent",
    "CONT": "Contract",
    "LI": "Legitimate Interest",
    "ADM": "Automated Decision Making",
    "RET": "Retention",
    "SEU": "Scale EU",
    "SNEU": "Scale Non-EU",
    "RI": "Right",
    "DSR15": "Art. 15 Right of access by the data subject",
    "DSR16": "Art. 16 Right to rectification",
    "DSR17": "Art. 17 Right to erasure (‘right to be forgotten’)",
    "DSR18": "Art. 18 Right to restriction of processing",
    "DSR19": "Art. 19 Notification obligation regarding rectification or erasure of personal data or restriction of processing",
    "DSR20": "Art. 20 Right to data portability",
    "DSR21": "Art. 21 Right to object",
    "DSR22": "Art. 22 Automated individual decision-making, including profiling",
    "LC": "Lodge Complaint",
}

# Extract the keys (labels) from the class dictionaries; their order fixes
# which colour each label gets.
ner_labels_legal = list(classes_legal.keys())
ner_labels_gdpr = list(classes_gdpr.keys())

# Colour used for a predicted label that is missing from the label list.
_FALLBACK_COLOR = "#dddddd"


def generate_colors(num_colors):
    """Return *num_colors* hex colour strings sampled evenly from "tab20".

    "tab20" is a 20-entry qualitative colormap, so with more than 20 labels
    some colours repeat.
    """
    cm = plt.get_cmap("tab20")
    return [mcolors.rgb2hex(cm(i / num_colors)) for i in range(num_colors)]


def color_substrings(input_string, model_output, ner_labels, current_classes):
    """Return *input_string* as HTML with every predicted entity span highlighted.

    Args:
        input_string: the raw text that was passed to the NER pipeline.
        model_output: iterable of dicts with "start"/"end" character offsets
            into *input_string* and a "label" class code.
        ner_labels: all label codes of the model; their order fixes each
            label's colour.
        current_classes: maps label codes to human-readable class names,
            shown as the hover tooltip.

    Returns:
        An HTML string. Every piece of *input_string* is HTML-escaped because
        the result is rendered with ``unsafe_allow_html=True`` and the text is
        user input. Escaping happens per slice so the model's character
        offsets stay valid.
    """
    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }

    parts = []
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        # Clamp so overlapping spans never re-emit text already written.
        start = max(start, last_end)
        if end <= start:
            continue
        parts.append(html.escape(input_string[last_end:start]))
        color = label_to_color.get(label, _FALLBACK_COLOR)
        tooltip = html.escape(current_classes.get(label, label), quote=True)
        parts.append(
            f'<span style="background-color: {color}" title="{tooltip}">'
            f"{html.escape(input_string[start:end])}</span>"
        )
        last_end = end
    parts.append(html.escape(input_string[last_end:]))
    return "".join(parts)


st.title("CAROLL Language Models - Demo")
st.markdown("<br>", unsafe_allow_html=True)

test_sentence = st.text_area("Enter Text:", height=200)
model_choice = st.selectbox(
    "Choose a model:", ["Legal NER", "GDPR Privacy Policy NER"], index=0
)

if st.button("Analyze"):
    if model_choice == "Legal NER":
        ner_model = ner_legal
        current_classes = classes_legal
        current_ner_labels = ner_labels_legal
    else:
        ner_model = ner_gdpr
        current_classes = classes_gdpr
        current_ner_labels = ner_labels_gdpr

    if not test_sentence.strip():
        st.warning("Please enter some text to analyze.")
    else:
        results = ner_model(test_sentence)
        processed_results = [
            {
                "start": result["start"],
                "end": result["end"],
                # Drop the BIO prefix: "B-PER" -> "PER".
                "label": result["entity"].split("-")[-1],
            }
            for result in results
            # Offsets may be absent/None when the tokenizer cannot provide
            # them; such tokens cannot be located in the text, so skip them.
            if result.get("start") is not None and result.get("end") is not None
        ]

        colored_html = color_substrings(
            test_sentence, processed_results, current_ner_labels, current_classes
        )

        st.markdown(
            "- Original text -<br><br>{}".format(html.escape(test_sentence)),
            unsafe_allow_html=True,
        )
        st.markdown(
            "- Analyzed text -<br><br>{}".format(colored_html),
            unsafe_allow_html=True,
        )
        st.markdown(
            "Tip: Hover over the colored words to see its class.",
            unsafe_allow_html=True,
        )