disham993 committed on
Commit
b4171e7
·
1 Parent(s): 6c21374

First Commit.

Browse files
Files changed (2) hide show
  1. app.py +254 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
+ import pandas as pd
4
+ from spacy import displacy
5
+
6
+ ###########################
7
+ # Utility Function for Cleanup
8
+ ###########################
9
def clean_and_group_entities(ner_results, min_score=0.40):
    """
    Merge touching spans of the same entity type and drop low-confidence ones.

    Args:
        ner_results: Iterable of dicts with the keys ``entity_group``,
            ``word``, ``start``, ``end`` and ``score``. ``None`` is treated
            as an empty input.
        min_score: Minimum score a span must have to be kept.

    Returns:
        A list of new dicts with the same five keys. The input dicts are
        never mutated. A merged entity carries the lowest score of its
        parts ("weakest link").
    """
    grouped_entities = []
    current_entity = None

    for result in ner_results or ():
        if result["score"] < min_score:
            # A rejected span also terminates the entity being built, so the
            # spans on either side of it are never merged across the gap.
            if current_entity is not None:
                grouped_entities.append(current_entity)
                current_entity = None
            continue

        # Strip only a leading "##" subword marker; a "##" elsewhere in the
        # word is part of the text and must be preserved.
        word = result["word"]
        if word.startswith("##"):
            word = word[2:]

        continues_current = (
            current_entity is not None
            and result["entity_group"] == current_entity["entity_group"]
            and result["start"] == current_entity["end"]
        )

        if continues_current:
            # Spans touch (start == previous end), so join without a space.
            current_entity["word"] += word
            current_entity["end"] = result["end"]
            # Both scores already passed the threshold above, so the minimum
            # cannot fall below it; no re-check is needed.
            current_entity["score"] = min(current_entity["score"], result["score"])
        else:
            if current_entity is not None:
                grouped_entities.append(current_entity)
            current_entity = {
                "entity_group": result["entity_group"],
                "word": word,
                "start": result["start"],
                "end": result["end"],
                "score": result["score"],
            }

    if current_entity is not None:
        grouped_entities.append(current_entity)

    return grouped_entities
62
+
63
+ ###########################
64
+ # Constants and Setup
65
+ ###########################
66
# Display name shown in the sidebar -> Hugging Face Hub model id.
MODELS = {
    "ModernBERT Base": "disham993/electrical-ner-modernbert-base",
    "BERT Base": "disham993/electrical-ner-bert-base",
    "ModernBERT Large": "disham993/electrical-ner-modernbert-large",
    "BERT Large": "disham993/electrical-ner-bert-large",
    "DistilBERT Base": "disham993/electrical-ner-distilbert-base",
}
73
+
74
# Entity label -> highlight colour (hex) used when rendering entities.
# NOTE(review): DESIGN_PARAM (#98FB98) and PRODUCT (#98FF98) are nearly the
# same green and may be hard to tell apart — confirm this is intended.
ENTITY_COLORS = {
    "COMPONENT": "#FFB6C1",
    "DESIGN_PARAM": "#98FB98",
    "MATERIAL": "#DDA0DD",
    "EQUIPMENT": "#87CEEB",
    "TECHNOLOGY": "#F0E68C",
    "SOFTWARE": "#FFD700",
    "STANDARD": "#FFA07A",
    "VENDOR": "#E6E6FA",
    "PRODUCT": "#98FF98",
}
85
+
86
# Sample sentences offered in the example dropdown.
EXAMPLES = [
    "Texas Instruments LM358 op-amp requires dual power supply.",
    "Using a Multimeter, the technician measured the 10 kΩ resistance of a Copper wire in the circuit.",
    "To improve the reliability of the circuit, the engineer tested a 10k Ohm resistor with a multimeter from Fluke.",
    "During the circuit design, we measured the current flow using a Fluke multimeter to ensure it was within the 10A specification.",
]
92
+
93
@st.cache_resource
def _load_pipeline(model_name):
    """
    Build the token-classification pipeline for ``model_name``.

    Any exception from the tokenizer/model download propagates, so a failed
    attempt is not stored in the resource cache and can be retried on the
    next rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )


def load_model(model_name):
    """
    Load and return a token classification pipeline with an aggregation strategy of 'simple'.

    Args:
        model_name: Model id or path passed to ``from_pretrained``.

    Returns:
        The pipeline, or ``None`` if loading failed (the error is shown via
        ``st.error``).
    """
    # The try/except sits outside the cached function on purpose: returning
    # None from inside it would cache the failure for the process lifetime.
    try:
        return _load_pipeline(model_name)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None
110
+
111
def get_base_entity_type(entity_label):
    """
    Strips off 'B-' or 'I-' prefix if present.

    Args:
        entity_label: Entity label string, e.g. ``"B-COMPONENT"``.

    Returns:
        The label without its two-character prefix, or the label unchanged
        when it has neither prefix.
    """
    if entity_label.startswith(("B-", "I-")):
        return entity_label[2:]
    return entity_label
118
+
119
def create_displacy_data(text, entities):
    """
    Create data for spaCy's displacy visualizer and render it to HTML.

    Args:
        text: The analysed text; entity offsets index into this string.
        entities: Iterable of dicts with ``entity_group``, ``start`` and
            ``end`` keys.

    Returns:
        The markup returned by ``displacy.render`` for the highlighted text.
    """
    ents = [
        {
            "start": entity["start"],
            "end": entity["end"],
            "label": get_base_entity_type(entity["entity_group"]),
        }
        for entity in entities
    ]

    # Pass copies so the renderer can never mutate the module-level mapping.
    options = {"ents": list(ENTITY_COLORS), "colors": dict(ENTITY_COLORS)}

    doc_data = {
        "text": text,
        "ents": ents,
        "title": None,
    }

    # manual=True: doc_data is a plain dict, not a spaCy Doc.
    # jupyter=False: always return the HTML string instead of auto-detecting
    # a notebook and displaying it there.
    return displacy.render(
        doc_data,
        style="ent",
        options=options,
        manual=True,
        jupyter=False,
    )
144
+
145
+ ###########################
146
+ # Main Streamlit App
147
+ ###########################
148
def _example_label(index, max_chars=100):
    """Return the dropdown label for EXAMPLES[index]; add "..." only if truncated."""
    example = EXAMPLES[index]
    if len(example) <= max_chars:
        return example
    return example[:max_chars] + "..."


def main():
    """
    Render the Streamlit NER demo.

    Builds the sidebar (model choice, confidence threshold, entity legend),
    loads the selected model, and — once the form is submitted with
    non-blank text — runs inference and shows the highlighted text, a table
    of entities and a per-type count chart.
    """
    st.set_page_config(page_title="Electrical Engineering NER", page_icon="⚡", layout="wide")

    st.title("⚡ Electrical Engineering Named Entity Recognition")
    st.markdown("""
    This application identifies technical entities in electrical engineering text using a fine-tuned BERT model.
    It can recognize components, parameters, materials, equipment, and more.
    """)

    # Sidebar - Model Selection
    st.sidebar.title("Model Configuration")
    selected_model_name = st.sidebar.selectbox(
        "Select Model",
        list(MODELS.keys()),
        help="Choose which model to use for entity recognition"
    )

    with st.sidebar.expander("Model Details"):
        st.write(f"**Model Path:** {MODELS[selected_model_name]}")
        st.write("This model is fine-tuned specifically for the electrical engineering domain.")

    # Confidence threshold
    score_threshold = st.sidebar.slider(
        'Minimum confidence threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.40,
        step=0.01
    )

    # Load selected model
    model = load_model(MODELS[selected_model_name])

    if model is None:
        st.error("Failed to load model. Please try selecting a different model.")
        return

    # Create a form to collect user text and an Analyze button
    # NOTE(review): widgets inside st.form presumably only report new values
    # on submit, so picking another example would not refresh the text area
    # until "Analyze" is clicked — confirm this is the intended UX.
    with st.form(key="text_form"):
        st.subheader("Try an example or enter your own text")
        example_idx = st.selectbox(
            "Select an example:",
            range(len(EXAMPLES)),
            format_func=_example_label
        )

        text_input = st.text_area(
            "Enter text for analysis:",
            value=EXAMPLES[example_idx],
            height=100
        )

        # This button triggers form submission
        submit_button = st.form_submit_button(label="Analyze")

    # Only run inference after the user clicks "Analyze"
    if submit_button and text_input.strip():
        with st.spinner("Analyzing text..."):
            # Broad catch is deliberate: this is the UI boundary and any
            # failure should surface as a message, not a stack trace.
            try:
                raw_entities = model(text_input)
                cleaned_entities = clean_and_group_entities(raw_entities, min_score=score_threshold)

                # Visualization
                st.subheader("Identified Entities")
                html_content = create_displacy_data(text_input, cleaned_entities)
                st.markdown(html_content, unsafe_allow_html=True)

                # Create DataFrame
                if cleaned_entities:
                    df = pd.DataFrame(cleaned_entities).round({"score": 3})

                    df = df.rename(columns={
                        "entity_group": "Entity Type",
                        "word": "Text",
                        "score": "Confidence",
                        "start": "Start",
                        "end": "End"
                    })

                    st.subheader("Entity Details")
                    st.dataframe(df)

                    st.subheader("Entity Distribution")
                    entity_counts = df["Entity Type"].value_counts()
                    st.bar_chart(entity_counts)
                else:
                    st.info("No entities detected in the text (or all below threshold).")

            except Exception as e:
                st.error(f"Error processing text: {e}")

    # Entity type legend
    st.sidebar.title("Entity Types")
    st.sidebar.markdown("""
    - 🔧 **COMPONENT**: Circuit elements
    - 📊 **DESIGN_PARAM**: Values, measurements
    - 🧱 **MATERIAL**: Physical materials
    - 🔌 **EQUIPMENT**: Testing equipment
    - 💻 **TECHNOLOGY**: Tech implementations
    - 💾 **SOFTWARE**: Software tools
    - 📜 **STANDARD**: Technical standards
    - 🏢 **VENDOR**: Manufacturers
    - 📦 **PRODUCT**: Specific products
    """)
252
+
253
# Script entry point: run the app only when executed directly, not on import.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git@6e0515e99c39444caae39472ee1b2fd76ece32f1
2
+ streamlit
3
+ spacy
4
+ pandas
5
+ torch