shreyasmeher committed
Commit 558b857 · verified · 1 Parent(s): 1f74d02

Upload 3 files

Files changed (3):
  1. README.md +5 -5
  2. app.py +717 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: ConfliBERT GUI V2
- emoji: 🐠
- colorFrom: indigo
- colorTo: gray
+ title: ConfliBERT Demo
+ emoji:
+ colorFrom: red
+ colorTo: green
  sdk: gradio
- sdk_version: 5.13.2
+ sdk_version: 4.38.1
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,717 @@
1
+ import torch
2
+ import tensorflow as tf
3
+ from tf_keras import models, layers
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TFAutoModelForQuestionAnswering
5
+ import gradio as gr
6
+ import re
7
+ import pandas as pd
8
+ import io
9
+
10
+ # Check if GPU is available and use it if possible
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ MAX_TOKEN_LENGTH = 512 # Adjust based on your model's limits
14
+
15
+ def truncate_text(text, tokenizer, max_length=MAX_TOKEN_LENGTH):
16
+ """Truncate text to max token length"""
17
+ tokens = tokenizer.encode(text, truncation=False)
18
+ if len(tokens) > max_length:
19
+ tokens = tokens[:max_length-1] + [tokenizer.sep_token_id]
20
+ return tokenizer.decode(tokens, skip_special_tokens=True)
21
+ return text
22
+
23
+ def safe_process(func, text, tokenizer):
24
+ """Safely process text with proper error handling"""
25
+ try:
26
+ truncated_text = truncate_text(text, tokenizer)
27
+ return func(truncated_text)
28
+ except Exception as e:
29
+ error_msg = str(e)
30
+ if 'out of memory' in error_msg.lower():
31
+ return "Error: Text too long for processing"
32
+ elif 'cuda' in error_msg.lower():
33
+ return "Error: GPU processing error"
34
+ else:
35
+ return f"Error: {error_msg}"
36
+
37
+ # Load the models and tokenizers
38
+ qa_model_name = 'salsarra/ConfliBERT-QA'
39
+ qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)
40
+ qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
41
+
42
+ ner_model_name = 'eventdata-utd/conflibert-named-entity-recognition'
43
+ ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
44
+ ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
45
+
46
+ clf_model_name = 'eventdata-utd/conflibert-binary-classification'
47
+ clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)
48
+ clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
49
+
50
+ multi_clf_model_name = 'eventdata-utd/conflibert-satp-relevant-multilabel'
51
+ multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)
52
+ multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)
53
+
54
+ # Define the class names for text classification
55
+ class_names = ['Negative', 'Positive']
56
+ multi_class_names = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"] # Updated labels
57
+
58
+ # Define the NER labels and colors
59
+ ner_labels = {
60
+ 'Organisation': 'blue',
61
+ 'Person': 'red',
62
+ 'Location': 'green',
63
+ 'Quantity': 'orange',
64
+ 'Weapon': 'purple',
65
+ 'Nationality': 'cyan',
66
+ 'Temporal': 'magenta',
67
+ 'DocumentReference': 'brown',
68
+ 'MilitaryPlatform': 'yellow',
69
+ 'Money': 'pink'
70
+ }
71
+
72
+ def handle_error_message(e, default_limit=512):
73
+ error_message = str(e)
74
+ pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
75
+ match = pattern.search(error_message)
76
+ if match:
77
+ number_1, number_2 = match.groups()
78
+ return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"
79
+ pattern_qa = re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)")
80
+ match_qa = pattern_qa.search(error_message)
81
+ if match_qa:
82
+ number_1, number_2 = match_qa.groups()
83
+ return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"
84
+ return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}</span>"
85
+
86
+ # Define the functions for each task
87
+ def question_answering(context, question):
88
+ try:
89
+ inputs = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
90
+ outputs = qa_model(inputs)
91
+ answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
92
+ answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
93
+ answer = qa_tokenizer.convert_tokens_to_string(qa_tokenizer.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
94
+ return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
95
+ except Exception as e:
96
+ return handle_error_message(e)
97
+
98
+ def replace_unk(tokens):
99
+ return [token.replace('[UNK]', "'") for token in tokens]
100
+
101
+ def named_entity_recognition(text, output_format='html'):
102
+ """
103
+ Process text for named entity recognition.
104
+ output_format: 'html' for GUI display, 'csv' for CSV processing
105
+ """
106
+ try:
107
+ inputs = ner_tokenizer(text, return_tensors='pt', truncation=True)
108
+ with torch.no_grad():
109
+ outputs = ner_model(**inputs)
110
+ ner_results = outputs.logits.argmax(dim=2).squeeze().tolist()
111
+ tokens = ner_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze().tolist())
112
+ tokens = replace_unk(tokens)
113
+
114
+ entities = []
115
+ seen_labels = set()
116
+ current_entity = []
117
+ current_label = None
118
+
119
+ # Process tokens and group consecutive entities
120
+ for i in range(len(tokens)):
121
+ token = tokens[i]
122
+ label = ner_model.config.id2label[ner_results[i]].split('-')[-1]
123
+
124
+ # Handle subwords
125
+ if token.startswith('##'):
126
+ if entities:
127
+ if output_format == 'html':
128
+ entities[-1][0] += token[2:]
129
+ elif current_entity:
130
+ current_entity[-1] = current_entity[-1] + token[2:]
131
+ else:
132
+ # For CSV format, group consecutive tokens of same entity type
133
+ if output_format == 'csv':
134
+ if label != 'O':
135
+ if label == current_label:
136
+ current_entity.append(token)
137
+ else:
138
+ if current_entity:
139
+ entities.append([' '.join(current_entity), current_label])
140
+ current_entity = [token]
141
+ current_label = label
142
+ else:
143
+ if current_entity:
144
+ entities.append([' '.join(current_entity), current_label])
145
+ current_entity = []
146
+ current_label = None
147
+ else:
148
+ entities.append([token, label])
149
+
150
+ if label != 'O':
151
+ seen_labels.add(label)
152
+
153
+ # Don't forget the last entity for CSV format
154
+ if output_format == 'csv' and current_entity:
155
+ entities.append([' '.join(current_entity), current_label])
156
+
157
+ if output_format == 'csv':
158
+ # Group by entity type
159
+ grouped_entities = {}
160
+ for token, label in entities:
161
+ if label != 'O':
162
+ if label not in grouped_entities:
163
+ grouped_entities[label] = []
164
+ grouped_entities[label].append(token)
165
+
166
+ # Format the output
167
+ result_parts = []
168
+ for label, tokens in grouped_entities.items():
169
+ unique_tokens = list(dict.fromkeys(tokens)) # Remove duplicates
170
+ result_parts.append(f"{label}: {' | '.join(unique_tokens)}")
171
+
172
+ return ' || '.join(result_parts)
173
+ else:
174
+ # Original HTML output
175
+ highlighted_text = ""
176
+ for token, label in entities:
177
+ color = ner_labels.get(label, 'black')
178
+ if label != 'O':
179
+ highlighted_text += f"<span style='color: {color}; font-weight: bold;'>{token}</span> "
180
+ else:
181
+ highlighted_text += f"{token} "
182
+
183
+ legend = "<div><strong>NER Tags Found:</strong><ul style='list-style-type: disc; padding-left: 20px;'>"
184
+ for label in seen_labels:
185
+ color = ner_labels.get(label, 'black')
186
+ legend += f"<li style='color: {color}; font-weight: bold;'>{label}</li>"
187
+ legend += "</ul></div>"
188
+
189
+ return f"<div>{highlighted_text}</div>{legend}"
190
+
191
+ except Exception as e:
192
+ return handle_error_message(e)
193
+
194
+ def text_classification(text):
195
+ try:
196
+ inputs = clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
197
+ with torch.no_grad():
198
+ outputs = clf_model(**inputs)
199
+ logits = outputs.logits.squeeze().tolist()
200
+ predicted_class = torch.argmax(outputs.logits, dim=1).item()
201
+ confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100
202
+
203
+ if predicted_class == 1: # Positive class
204
+ result = f"<span style='color: green; font-weight: bold;'>Positive: The text is related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
205
+ else: # Negative class
206
+ result = f"<span style='color: red; font-weight: bold;'>Negative: The text is not related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
207
+ return result
208
+ except Exception as e:
209
+ return handle_error_message(e)
210
+
211
+ def multilabel_classification(text):
212
+ try:
213
+ inputs = multi_clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
214
+ with torch.no_grad():
215
+ outputs = multi_clf_model(**inputs)
216
+ predicted_classes = torch.sigmoid(outputs.logits).squeeze().tolist()
217
+ if len(predicted_classes) != len(multi_class_names):
218
+ return f"Error: Number of predicted classes ({len(predicted_classes)}) does not match number of class names ({len(multi_class_names)})."
219
+
220
+ results = []
221
+ for i in range(len(predicted_classes)):
222
+ confidence = predicted_classes[i] * 100
223
+ if predicted_classes[i] >= 0.5:
224
+ results.append(f"<span style='color: green; font-weight: bold;'>{multi_class_names[i]} (Confidence: {confidence:.2f}%)</span>")
225
+ else:
226
+ results.append(f"<span style='color: red; font-weight: bold;'>{multi_class_names[i]} (Confidence: {confidence:.2f}%)</span>")
227
+
228
+ return " / ".join(results)
229
+ except Exception as e:
230
+ return handle_error_message(e)
231
+
232
+ def clean_html_tags(text):
233
+ """Remove HTML tags and formatting from the output."""
234
+ # Remove HTML tags but keep the text content
235
+ clean_text = re.sub(r'<[^>]+>', '', text)
236
+ # Remove multiple spaces
237
+ clean_text = re.sub(r'\s+', ' ', clean_text)
238
+ # Remove [CLS] and [SEP] tokens
239
+ clean_text = re.sub(r'\[CLS\]|\[SEP\]', '', clean_text)
240
+ return clean_text.strip()
241
+
242
+ def extract_ner_entities(html_output):
243
+ """Extract entities and their types from NER output using a simpler approach."""
244
+ # Map colors to entity types
245
+ color_to_type = {
246
+ 'blue': 'Organisation',
247
+ 'red': 'Person',
248
+ 'green': 'Location',
249
+ 'orange': 'Quantity',
250
+ 'purple': 'Weapon',
251
+ 'cyan': 'Nationality',
252
+ 'magenta': 'Temporal',
253
+ 'brown': 'DocumentReference',
254
+ 'yellow': 'MilitaryPlatform',
255
+ 'pink': 'Money'
256
+ }
257
+
258
+ # Find all colored spans
259
+ pattern = r"<span style='color: ([^']+)[^>]+>([^<]+)</span>"
260
+ matches = re.findall(pattern, html_output)
261
+
262
+ # Group by entity type
263
+ entities = {}
264
+
265
+ # Process each match
266
+ for color, text in matches:
267
+ if color in color_to_type:
268
+ entity_type = color_to_type[color]
269
+ if entity_type not in entities:
270
+ entities[entity_type] = []
271
+
272
+ # Clean and store the text
273
+ text = text.strip()
274
+ if text and not text.isspace():
275
+ entities[entity_type].append(text)
276
+
277
+ # Join consecutive words for each entity type
278
+ result_parts = []
279
+ for entity_type, words in entities.items():
280
+ # Join consecutive words
281
+ phrases = []
282
+ current_phrase = []
283
+
284
+ for word in words:
285
+ if word in [',', '/', ':', '-']: # Skip punctuation
286
+ continue
287
+ if not current_phrase:
288
+ current_phrase.append(word)
289
+ else:
290
+ # If it's a continuation (e.g., part of a date or name)
291
+ if word.startswith(':') or word == 'of' or current_phrase[-1].endswith('/'):
292
+ current_phrase.append(word)
293
+ else:
294
+ # If it's a new entity
295
+ phrases.append(' '.join(current_phrase))
296
+ current_phrase = [word]
297
+
298
+ if current_phrase:
299
+ phrases.append(' '.join(current_phrase))
300
+
301
+ # Remove duplicates while preserving order
302
+ unique_phrases = []
303
+ seen = set()
304
+ for phrase in phrases:
305
+ clean_phrase = phrase.strip()
306
+ if clean_phrase and clean_phrase not in seen:
307
+ unique_phrases.append(clean_phrase)
308
+ seen.add(clean_phrase)
309
+
310
+ if unique_phrases:
311
+ result_parts.append(f"{entity_type}: {' | '.join(unique_phrases)}")
312
+
313
+ return ' || '.join(result_parts)
314
+
315
+
316
+ def clean_classification_output(html_output):
317
+ """Extract classification results without HTML formatting."""
318
+ if "Positive" in html_output:
319
+ # Binary classification
320
+ match = re.search(r">(Positive|Negative).*?Confidence: ([\d.]+)%", html_output)
321
+ if match:
322
+ class_name, confidence = match.groups()
323
+ return f"{class_name} ({confidence}%)"
324
+ else:
325
+ # Multilabel classification
326
+ results = []
327
+ matches = re.finditer(r">([^<]+)\s*\(Confidence:\s*([\d.]+)%\)", html_output)
328
+ for match in matches:
329
+ class_name, confidence = match.groups()
330
+ if float(confidence) >= 50: # Only include classes with confidence >= 50%
331
+ results.append(f"{class_name.strip()} ({confidence}%)")
332
+ return " | ".join(results) if results else "No classes above 50% confidence"
333
+
334
+ return "Unknown"
335
+
336
+
337
+ def process_csv_ner(file):
338
+ try:
339
+ df = pd.read_csv(file.name)
340
+
341
+ if 'text' not in df.columns:
342
+ return "Error: CSV must contain a 'text' column"
343
+
344
+ entities = []
345
+ for text in df['text']:
346
+ if pd.isna(text):
347
+ entities.append("")
348
+ continue
349
+
350
+ # Use CSV output format
351
+ result = named_entity_recognition(str(text), output_format='csv')
352
+ entities.append(result)
353
+
354
+ df['entities'] = entities
355
+
356
+ output_path = "processed_results.csv"
357
+ df.to_csv(output_path, index=False)
358
+ return output_path
359
+ except Exception as e:
360
+ return f"Error processing CSV: {str(e)}"
361
+
362
+ def process_csv_classification(file, is_multi=False):
363
+ try:
364
+ df = pd.read_csv(file.name)
365
+
366
+ if 'text' not in df.columns:
367
+ return "Error: CSV must contain a 'text' column"
368
+
369
+ results = []
370
+ for text in df['text']:
371
+ if pd.isna(text):
372
+ results.append("")
373
+ continue
374
+
375
+ if is_multi:
376
+ html_result = multilabel_classification(str(text))
377
+ else:
378
+ html_result = text_classification(str(text))
379
+ results.append(clean_classification_output(html_result))
380
+
381
+ result_column = 'multilabel_results' if is_multi else 'classification_results'
382
+ df[result_column] = results
383
+
384
+ output_path = "processed_results.csv"
385
+ df.to_csv(output_path, index=False)
386
+ return output_path
387
+ except Exception as e:
388
+ return f"Error processing CSV: {str(e)}"
389
+
390
+
391
+ # Define the Gradio interface
392
+ def chatbot(task, text=None, context=None, question=None, file=None):
393
+ if file is not None: # Handle CSV file input
394
+ if task == "Named Entity Recognition":
395
+ return process_csv_ner(file)
396
+ elif task == "Text Classification":
397
+ return process_csv_classification(file, is_multi=False)
398
+ elif task == "Multilabel Classification":
399
+ return process_csv_classification(file, is_multi=True)
400
+ else:
401
+ return "CSV processing is not supported for Question Answering task"
402
+
403
+ # Handle regular text input (previous implementation)
404
+ if task == "Question Answering":
405
+ if context and question:
406
+ return question_answering(context, question)
407
+ else:
408
+ return "Please provide both context and question for the Question Answering task."
409
+ elif task == "Named Entity Recognition":
410
+ if text:
411
+ return named_entity_recognition(text)
412
+ else:
413
+ return "Please provide text for the Named Entity Recognition task."
414
+ elif task == "Text Classification":
415
+ if text:
416
+ return text_classification(text)
417
+ else:
418
+ return "Please provide text for the Text Classification task."
419
+ elif task == "Multilabel Classification":
420
+ if text:
421
+ return multilabel_classification(text)
422
+ else:
423
+ return "Please provide text for the Multilabel Classification task."
424
+ else:
425
+ return "Please select a valid task."
426
+
427
+
428
+ css = """
429
+ :root {
430
+ --primary-color: #2563eb;
431
+ --secondary-color: #1e40af;
432
+ --accent-color: #3b82f6;
433
+ --background-color: #f8fafc;
434
+ --card-background: #ffffff;
435
+ --text-color: #1e293b;
436
+ --border-color: #e2e8f0;
437
+ }
438
+
439
+ body {
440
+ background-color: var(--background-color);
441
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
442
+ color: var(--text-color);
443
+ }
444
+
445
+ .gradio-container {
446
+ max-width: 1200px !important;
447
+ margin: 2rem auto !important;
448
+ padding: 0 1rem;
449
+ }
450
+
451
+ .header-container {
452
+ background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
453
+ padding: 2rem 1rem;
454
+ margin: -1rem -1rem 2rem -1rem;
455
+ border-radius: 1rem;
456
+ box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
457
+ }
458
+
459
+ .header-title-center a {
460
+ font-size: 2.5rem !important;
461
+ font-weight: 800;
462
+ color: white !important;
463
+ text-align: center;
464
+ display: block;
465
+ text-decoration: none;
466
+ letter-spacing: -0.025em;
467
+ margin-bottom: 0.5rem;
468
+ }
469
+
470
+ .task-container {
471
+ background: var(--card-background);
472
+ padding: 2rem;
473
+ border-radius: 1rem;
474
+ box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
475
+ margin-bottom: 2rem;
476
+ }
477
+
478
+ .gr-input, .gr-box {
479
+ border: 1px solid var(--border-color) !important;
480
+ border-radius: 0.75rem !important;
481
+ padding: 1rem !important;
482
+ background: var(--card-background) !important;
483
+ transition: border-color 0.15s ease;
484
+ }
485
+
486
+ .gr-input:focus, .gr-box:focus {
487
+ border-color: var(--accent-color) !important;
488
+ outline: none !important;
489
+ box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1) !important;
490
+ }
491
+
492
+ .gr-button {
493
+ background: var(--primary-color) !important;
494
+ border: none;
495
+ padding: 0.75rem 1.5rem !important;
496
+ font-weight: 600 !important;
497
+ border-radius: 0.75rem !important;
498
+ cursor: pointer;
499
+ transition: all 0.15s ease;
500
+ }
501
+
502
+ .gr-button:hover {
503
+ background: var(--secondary-color) !important;
504
+ transform: translateY(-1px);
505
+ }
506
+
507
+ .gr-button:active {
508
+ transform: translateY(0);
509
+ }
510
+
511
+ select.gr-box {
512
+ cursor: pointer;
513
+ padding-right: 2.5rem !important;
514
+ appearance: none;
515
+ background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 24 24' stroke='%23475569'%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M19 9l-7 7-7-7'%3E%3C/path%3E%3C/svg%3E");
516
+ background-repeat: no-repeat;
517
+ background-position: right 1rem center;
518
+ background-size: 1.5em 1.5em;
519
+ }
520
+
521
+ .footer {
522
+ text-align: center;
523
+ margin-top: 2rem;
524
+ padding: 2rem 0;
525
+ border-top: 1px solid var(--border-color);
526
+ color: #64748b;
527
+ }
528
+
529
+ .footer a {
530
+ color: var(--primary-color);
531
+ font-weight: 500;
532
+ text-decoration: none;
533
+ transition: color 0.15s ease;
534
+ }
535
+
536
+ .footer a:hover {
537
+ color: var(--secondary-color);
538
+ }
539
+
540
+ /* File upload styles */
541
+ .gr-file-drop {
542
+ border: 2px dashed var(--border-color) !important;
543
+ border-radius: 0.75rem !important;
544
+ padding: 2rem !important;
545
+ text-align: center;
546
+ transition: all 0.15s ease;
547
+ }
548
+
549
+ .gr-file-drop:hover {
550
+ border-color: var(--accent-color) !important;
551
+ background-color: rgba(59, 130, 246, 0.05) !important;
552
+ }
553
+
554
+ /* Output container */
555
+ .output-html {
556
+ background: var(--card-background);
557
+ padding: 1.5rem;
558
+ border-radius: 0.75rem;
559
+ box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
560
+ }
561
+
562
+ /* Labels */
563
+ label {
564
+ font-weight: 500;
565
+ margin-bottom: 0.5rem;
566
+ color: #475569;
567
+ }
568
+
569
+ /* Spacing between elements */
570
+ .gr-form {
571
+ gap: 1.5rem !important;
572
+ }
573
+
574
+ .gr-row {
575
+ gap: 1rem !important;
576
+ }
577
+ """
578
+
579
+ with gr.Blocks(css=css) as demo:
580
+ with gr.Column():
581
+ with gr.Row(elem_id="header", elem_classes="header-container"):
582
+ gr.Markdown("<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/'>ConfliBERT</a></div>")
583
+
584
+ with gr.Column(elem_classes="task-container"):
585
+ gr.Markdown("<h2 style='font-size: 1.25rem; font-weight: 600; margin-bottom: 1.5rem; color: #0f172a;'>Select a task and provide the necessary inputs:</h2>")
586
+
587
+ task = gr.Dropdown(
588
+ choices=["Question Answering", "Named Entity Recognition", "Text Classification", "Multilabel Classification"],
589
+ label="Select Task",
590
+ value="Named Entity Recognition"
591
+ )
592
+
593
+ with gr.Row():
594
+ text_input = gr.Textbox(
595
+ lines=5,
596
+ placeholder="Enter the text here...",
597
+ label="Text",
598
+ elem_classes="input-text"
599
+ )
600
+ context_input = gr.Textbox(
601
+ lines=5,
602
+ placeholder="Enter the context here...",
603
+ label="Context",
604
+ visible=False,
605
+ elem_classes="input-text"
606
+ )
607
+ question_input = gr.Textbox(
608
+ lines=2,
609
+ placeholder="Enter your question here...",
610
+ label="Question",
611
+ visible=False,
612
+ elem_classes="input-text"
613
+ )
614
+
615
+ with gr.Row():
616
+ file_input = gr.File(
617
+ label="Or upload a CSV file (must contain a 'text' column)",
618
+ file_types=[".csv"],
619
+ elem_classes="file-upload"
620
+ )
621
+ file_output = gr.File(
622
+ label="Download processed results",
623
+ visible=False,
624
+ elem_classes="file-download"
625
+ )
626
+
627
+ with gr.Row():
628
+ submit_button = gr.Button(
629
+ "Submit",
630
+ elem_id="submit-button",
631
+ elem_classes="submit-btn"
632
+ )
633
+
634
+ output = gr.HTML(label="Output", elem_classes="output-html")
635
+
636
+ with gr.Row(elem_classes="footer"):
637
+ gr.Markdown("<a href='https://eventdata.utdallas.edu/'>UTD Event Data</a> | <a href='https://www.utdallas.edu/'>University of Texas at Dallas</a>")
638
+ gr.Markdown("Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a> and <a href='http://shreyasmeher.com' target='_blank'>Shreyas Meher</a>")
639
+
640
+ # Define the update_inputs function
641
+ def update_inputs(task_name):
642
+ """Updates the visibility of input components based on the selected task."""
643
+ if task_name == "Question Answering":
644
+ return [
645
+ gr.update(visible=False),
646
+ gr.update(visible=True),
647
+ gr.update(visible=True),
648
+ gr.update(visible=False),
649
+ gr.update(visible=False)
650
+ ]
651
+ else:
652
+ return [
653
+ gr.update(visible=True),
654
+ gr.update(visible=False),
655
+ gr.update(visible=False),
656
+ gr.update(visible=True),
657
+ gr.update(visible=True)
658
+ ]
659
+
660
+ # Define the chatbot_interface function
661
+ def chatbot_interface(task, text, context, question, file):
662
+ """Handles both file and text inputs for different tasks."""
663
+ if file:
664
+ result = chatbot(task, file=file)
665
+ if isinstance(result, str) and result.endswith('.csv'):
666
+ return gr.update(visible=False), gr.update(value=result, visible=True)
667
+ return gr.update(value=result, visible=True), gr.update(visible=False)
668
+ else:
669
+ result = chatbot(task, text, context, question)
670
+ return gr.update(value=result, visible=True), gr.update(visible=False)
671
+
672
+ # Define the main chatbot function
673
+ def chatbot(task, text=None, context=None, question=None, file=None):
674
+ """Main function to process different types of inputs and tasks."""
675
+ if file is not None: # Handle CSV file input
676
+ if task == "Named Entity Recognition":
677
+ return process_csv_ner(file)
678
+ elif task == "Text Classification":
679
+ return process_csv_classification(file, is_multi=False)
680
+ elif task == "Multilabel Classification":
681
+ return process_csv_classification(file, is_multi=True)
682
+ else:
683
+ return "CSV processing is not supported for Question Answering task"
684
+
685
+ # Handle regular text input
686
+ if task == "Question Answering":
687
+ if context and question:
688
+ return question_answering(context, question)
689
+ else:
690
+ return "Please provide both context and question for the Question Answering task."
691
+ elif task == "Named Entity Recognition":
692
+ if text:
693
+ return named_entity_recognition(text)
694
+ else:
695
+ return "Please provide text for the Named Entity Recognition task."
696
+ elif task == "Text Classification":
697
+ if text:
698
+ return text_classification(text)
699
+ else:
700
+ return "Please provide text for the Text Classification task."
701
+ elif task == "Multilabel Classification":
702
+ if text:
703
+ return multilabel_classification(text)
704
+ else:
705
+ return "Please provide text for the Multilabel Classification task."
706
+ else:
707
+ return "Please select a valid task."
708
+
709
+ # Event handlers
710
+ task.change(fn=update_inputs, inputs=task, outputs=[text_input, context_input, question_input, file_input, file_output])
711
+ submit_button.click(
712
+ fn=chatbot_interface,
713
+ inputs=[task, text_input, context_input, question_input, file_input],
714
+ outputs=[output, file_output]
715
+ )
716
+
717
+ demo.launch(share=True)
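
The task helpers in app.py can also be exercised without the Gradio front end, which is useful when debugging model behaviour. The snippet below is a minimal sketch rather than part of the commit: it assumes the definitions above are already loaded in a Python session (for example by running app.py with the trailing demo.launch(share=True) call commented out so the script does not block on the web UI), and the sample sentence is purely illustrative.

```python
# Minimal sketch (not part of this commit): calling the app.py helpers directly.
# Assumes the functions defined above are already available in the current session.

sample = "Protesters clashed with police in Karachi on Tuesday; three people were injured."

# HTML-highlighted NER output, as rendered in the Gradio interface
print(named_entity_recognition(sample))

# Pipe-delimited NER output, as written to the 'entities' column for CSV uploads
print(named_entity_recognition(sample, output_format='csv'))

# Binary relevance classification with confidence score
print(text_classification(sample))

# Extractive QA, using the same sentence as the context
print(question_answering(sample, "Where did the clash take place?"))
```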
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ tensorflow
+ transformers
+ gradio
+ tf-keras
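
Every dependency in requirements.txt is left unpinned, so the versions actually installed depend on when the Space image is built (the Gradio version itself is pinned separately by sdk_version: 4.38.1 in README.md). The check below is a suggested sketch, not part of the commit; it assumes only that the five packages import cleanly and expose the usual __version__ attribute, and it prints pin-style lines that could be copied back into requirements.txt if reproducibility becomes a concern.

```python
# Suggested sketch (not part of this commit): record the resolved versions of the
# unpinned dependencies listed in requirements.txt.
import torch
import tensorflow as tf
import transformers
import gradio as gr
import tf_keras  # assumed to expose __version__ like the Keras 2 package it tracks

for name, version in [
    ("torch", torch.__version__),
    ("tensorflow", tf.__version__),
    ("transformers", transformers.__version__),
    ("gradio", gr.__version__),
    ("tf-keras", tf_keras.__version__),
]:
    print(f"{name}=={version}")
```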