import pandas as pd
import streamlit as st
from spacy import displacy
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

###########################
# Utility Function for Cleanup
###########################


def clean_and_group_entities(ner_results, min_score=0.40, *, text=None):
    """Merge adjacent same-type entity spans and drop low-confidence ones.

    Args:
        ner_results: Iterable of entity dicts as returned by the aggregated
            "ner" pipeline built in ``load_model``; each dict is read for its
            "entity_group", "word", "start", "end" and "score" keys.
        min_score: Spans scoring below this are discarded. A discarded span
            also ends any run of spans currently being merged.
        text: Optional original input text. When given, each entity's "word"
            is sliced from ``text[start:end]`` so it keeps the source casing
            and spacing; otherwise the pipeline's decoded "word" is used with
            any "##" sub-word markers removed.

    Returns:
        List of entity dicts with the same five keys, in input order. A span
        is merged into the previous one when it has the same "entity_group"
        and starts exactly where the previous one ends; the merged "score" is
        the minimum ("weakest link") of its parts.
    """
    grouped_entities = []
    current_entity = None

    for result in ner_results:
        if result["score"] < min_score:
            # A low-confidence span breaks the current run. Anything already
            # accumulated passed the threshold, so keep it.
            if current_entity is not None:
                grouped_entities.append(current_entity)
                current_entity = None
            continue

        if text is not None:
            word = text[result["start"]:result["end"]]
        else:
            word = result["word"].replace("##", "")

        continues_current = (
            current_entity is not None
            and result["entity_group"] == current_entity["entity_group"]
            and result["start"] == current_entity["end"]
        )
        if continues_current:
            # Contiguous spans, so concatenating the pieces equals the full slice.
            current_entity["word"] += word
            current_entity["end"] = result["end"]
            current_entity["score"] = min(current_entity["score"], result["score"])
        else:
            if current_entity is not None:
                grouped_entities.append(current_entity)
            current_entity = {
                "entity_group": result["entity_group"],
                "word": word,
                "start": result["start"],
                "end": result["end"],
                "score": result["score"],
            }

    if current_entity is not None:
        grouped_entities.append(current_entity)

    return grouped_entities


###########################
# Constants and Setup
###########################

MODELS = {
    "ModernBERT Base": "disham993/electrical-ner-modernbert-base",
    "BERT Base": "disham993/electrical-ner-bert-base",
    "ModernBERT Large": "disham993/electrical-ner-modernbert-large",
    "BERT Large": "disham993/electrical-ner-bert-large",
    "DistilBERT Base": "disham993/electrical-ner-distilbert-base",
}

ENTITY_COLORS = {
    "COMPONENT": "#FFB6C1",
    "DESIGN_PARAM": "#98FB98",
    "MATERIAL": "#DDA0DD",
    "EQUIPMENT": "#87CEEB",
    "TECHNOLOGY": "#F0E68C",
    "SOFTWARE": "#FFD700",
    "STANDARD": "#FFA07A",
    "VENDOR": "#E6E6FA",
    "PRODUCT": "#98FF98",
}

EXAMPLES = [
    "Texas Instruments LM358 op-amp requires dual power supply.",
    "Using a Multimeter, the technician measured the 10 kΩ resistance of a Copper wire in the circuit.",
    "To improve the reliability of the circuit, the engineer tested a 10k Ohm resistor with a multimeter from Fluke.",
    "During the circuit design, we measured the current flow using a Fluke multimeter to ensure it was within the 10A specification.",
]

# Maximum length of an example shown in the example selector.
_EXAMPLE_LABEL_LEN = 100


@st.cache_resource
def _load_pipeline(model_name):
    """Build (and cache) a token-classification pipeline for ``model_name``.

    Exceptions propagate on purpose: st.cache_resource does not cache a call
    that raises, so a failed load is retried on the next rerun instead of a
    ``None`` being cached for the lifetime of the process.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",  # <-- Aggregation strategy
    )


def load_model(model_name):
    """
    Load and return a token classification pipeline with an aggregation strategy of 'simple'.

    Returns None (after showing an error in the app) if loading fails.
    """
    try:
        return _load_pipeline(model_name)
    except Exception as e:  # UI boundary: report any load failure to the user.
        st.error(f"Error loading model: {str(e)}")
        return None


def get_base_entity_type(entity_label):
    """
    Strips off 'B-' or 'I-' prefix if present.
    """
    if entity_label.startswith("B-") or entity_label.startswith("I-"):
        return entity_label[2:]
    return entity_label


def create_displacy_data(text, entities):
    """
    Create data for spaCy's displacy visualizer.

    Builds a manual-mode displaCy document from ``text`` and the entity dicts
    (reading "start", "end", "entity_group") and returns the rendered
    entity-style HTML string. Labels are coloured from ENTITY_COLORS, and only
    labels listed there are passed as displaCy's ``ents`` highlight filter.
    """
    ents = [
        {
            "start": entity["start"],
            "end": entity["end"],
            "label": get_base_entity_type(entity["entity_group"]),
        }
        for entity in entities
    ]
    options = {"ents": list(ENTITY_COLORS.keys()), "colors": dict(ENTITY_COLORS)}
    doc_data = {"text": text, "ents": ents, "title": None}

    # manual=True: render from this dict rather than a spaCy Doc object.
    return displacy.render(doc_data, style="ent", options=options, manual=True)


###########################
# Main Streamlit App
###########################


def _example_label(idx):
    """Return the selector label for EXAMPLES[idx], adding "..." only if truncated."""
    example = EXAMPLES[idx]
    if len(example) <= _EXAMPLE_LABEL_LEN:
        return example
    return example[:_EXAMPLE_LABEL_LEN] + "..."


def _render_model_sidebar():
    """Render the sidebar model picker and threshold; return (model_name, threshold)."""
    st.sidebar.title("Model Configuration")
    selected_model_name = st.sidebar.selectbox(
        "Select Model",
        list(MODELS.keys()),
        help="Choose which model to use for entity recognition",
    )

    with st.sidebar.expander("Model Details"):
        st.write(f"**Model Path:** {MODELS[selected_model_name]}")
        st.write("This model is fine-tuned specifically for the electrical engineering domain.")

    score_threshold = st.sidebar.slider(
        'Minimum confidence threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.40,
        step=0.01,
    )
    return selected_model_name, score_threshold


def _render_results(text, entities):
    """Show the highlighted text, the entity table, and per-type counts."""
    st.subheader("Identified Entities")
    html_content = create_displacy_data(text, entities)
    st.markdown(html_content, unsafe_allow_html=True)

    if not entities:
        st.info("No entities detected in the text (or all below threshold).")
        return

    df = pd.DataFrame(entities).round({"score": 3})
    df = df.rename(columns={
        "entity_group": "Entity Type",
        "word": "Text",
        "score": "Confidence",
        "start": "Start",
        "end": "End",
    })
    st.subheader("Entity Details")
    st.dataframe(df)

    st.subheader("Entity Distribution")
    st.bar_chart(df["Entity Type"].value_counts())


def _render_entity_legend():
    """Render the static entity-type legend in the sidebar."""
    st.sidebar.title("Entity Types")
    st.sidebar.markdown("""
    - 🔧 **COMPONENT**: Circuit elements
    - 📊 **DESIGN_PARAM**: Values, measurements
    - 🧱 **MATERIAL**: Physical materials
    - 🔌 **EQUIPMENT**: Testing equipment
    - 💻 **TECHNOLOGY**: Tech implementations
    - 💾 **SOFTWARE**: Software tools
    - 📜 **STANDARD**: Technical standards
    - 🏢 **VENDOR**: Manufacturers
    - 📦 **PRODUCT**: Specific products
    """)


def main():
    """Build the page: sidebar configuration, input form, and NER results."""
    # set_page_config must be the first Streamlit call of the run.
    st.set_page_config(page_title="Electrical Engineering NER", page_icon="⚡", layout="wide")

    st.title("⚡ Electrical Engineering Named Entity Recognition")
    st.markdown("""
    This application identifies technical entities in electrical engineering text using a fine-tuned BERT model.
    It can recognize components, parameters, materials, equipment, and more.
    """)

    selected_model_name, score_threshold = _render_model_sidebar()

    model = load_model(MODELS[selected_model_name])
    if model is None:
        st.error("Failed to load model. Please try selecting a different model.")
        return

    st.subheader("Try an example or enter your own text")
    # Kept outside the form so changing the example reruns the script right
    # away and the text area below shows the chosen example immediately.
    example_idx = st.selectbox(
        "Select an example:",
        range(len(EXAMPLES)),
        format_func=_example_label,
    )

    with st.form(key="text_form"):
        text_input = st.text_area(
            "Enter text for analysis:",
            value=EXAMPLES[example_idx],
            height=100,
        )
        # This button triggers form submission
        submit_button = st.form_submit_button(label="Analyze")

    # Only run inference after the user clicks "Analyze"
    if submit_button and text_input.strip():
        with st.spinner("Analyzing text..."):
            try:
                raw_entities = model(text_input)
                cleaned_entities = clean_and_group_entities(
                    raw_entities, min_score=score_threshold, text=text_input
                )
                _render_results(text_input, cleaned_entities)
            except Exception as e:  # UI boundary: surface any inference error.
                st.error(f"Error processing text: {str(e)}")
    elif submit_button:
        st.warning("Please enter some text to analyze.")

    _render_entity_legend()


if __name__ == "__main__":
    main()