Svngoku committed
Commit 830bf75 · verified · 1 parent: bdaf5da

Create app.py

Files changed (1)
1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
import ast
import pandas as pd
import gradio as gr
import litellm
import plotly.express as px
from collections import defaultdict
from datetime import datetime

def preprocess_dataset(test_data):
    """
    Preprocess the dataset to convert the 'choices' field from a string to a list of strings.
    """
    preprocessed_data = []
    for example in test_data:
        if isinstance(example['choices'], str):
            choices_str = example['choices']
            if choices_str.startswith("'") and choices_str.endswith("'"):
                choices_str = choices_str[1:-1]
            elif choices_str.startswith('"') and choices_str.endswith('"'):
                choices_str = choices_str[1:-1]
            choices_str = choices_str.replace("\\'", "'")
            try:
                example['choices'] = ast.literal_eval(choices_str)
            except (ValueError, SyntaxError):
                print(f"Error parsing choices: {choices_str}")
                continue
        preprocessed_data.append(example)
    return preprocessed_data

def evaluate_afrimmlu(test_data, model_name="deepseek-chat"):
    """
    Evaluate the model on the AfriMMLU dataset.
    """
    results = []
    correct = 0
    total = 0
    subject_results = defaultdict(lambda: {"correct": 0, "total": 0})

    for example in test_data:
        question = example['question']
        choices = example['choices']
        answer = example['answer']
        subject = example['subject']

        prompt = (
            f"Answer the following multiple-choice question. "
            f"Return only the letter corresponding to the correct answer (A, B, C, or D).\n"
            f"Question: {question}\n"
            f"Options:\n"
            f"A. {choices[0]}\n"
            f"B. {choices[1]}\n"
            f"C. {choices[2]}\n"
            f"D. {choices[3]}\n"
            f"Answer:"
        )

        try:
            response = litellm.completion(
                model=model_name,
                messages=[{"role": "user", "content": prompt}]
            )
            model_output = response.choices[0].message.content.strip().upper()

            model_answer = None
            for char in model_output:
                if char in ['A', 'B', 'C', 'D']:
                    model_answer = char
                    break

            is_correct = model_answer == answer.upper()
            if is_correct:
                correct += 1
                subject_results[subject]["correct"] += 1
            total += 1
            subject_results[subject]["total"] += 1

            # Store detailed results
            results.append({
                'timestamp': datetime.now().isoformat(),
                'subject': subject,
                'question': question,
                'model_answer': model_answer,
                'correct_answer': answer.upper(),
                'is_correct': is_correct,
                'total_tokens': response.usage.total_tokens
            })

        except Exception as e:
            print(f"Error processing question: {str(e)}")
            continue

    # Calculate accuracies
    accuracy = (correct / total * 100) if total > 0 else 0
    subject_accuracy = {
        subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
        for subject, stats in subject_results.items()
    }

    # Export results to CSV
    df = pd.DataFrame(results)
    df.to_csv('detailed_results.csv', index=False)

    # Export summary to CSV
    summary_data = [{'subject': subject, 'accuracy': acc}
                    for subject, acc in subject_accuracy.items()]
    summary_data.append({'subject': 'Overall', 'accuracy': accuracy})
    pd.DataFrame(summary_data).to_csv('summary_results.csv', index=False)

    return {
        "accuracy": accuracy,
        "subject_accuracy": subject_accuracy,
        "detailed_results": results
    }

def create_visualization(results_dict):
    """
    Create visualization from evaluation results.
    """
    summary_data = [
        {'Subject': subject, 'Accuracy (%)': accuracy}
        for subject, accuracy in results_dict['subject_accuracy'].items()
    ]
    summary_data.append({'Subject': 'Overall', 'Accuracy (%)': results_dict['accuracy']})
    summary_df = pd.DataFrame(summary_data)

    fig = px.bar(
        summary_df,
        x='Subject',
        y='Accuracy (%)',
        title='AfriMMLU Evaluation Results',
        labels={'Subject': 'Subject', 'Accuracy (%)': 'Accuracy (%)'}
    )
    fig.update_layout(
        xaxis_tickangle=-45,
        showlegend=False,
        height=600
    )

    return summary_df, fig

def evaluate_and_display(test_file, model_name):
    """
    Process uploaded file and run evaluation.
    """
    test_data = pd.read_json(test_file.name)
    preprocessed_data = preprocess_dataset(test_data.to_dict('records'))

    results = evaluate_afrimmlu(preprocessed_data, model_name)

    summary_df, plot = create_visualization(results)
    detailed_df = pd.read_csv('detailed_results.csv')

    return summary_df, plot, detailed_df

def create_gradio_interface():
    """
    Create and configure the Gradio interface.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # AfriMMLU Evaluation Dashboard
        Upload your test data and select a model to evaluate performance on the AfriMMLU benchmark.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Test Data (JSON)",
                    file_types=[".json"]
                )
                model_input = gr.Dropdown(
                    choices=["deepseek-chat", "gpt-3.5-turbo", "gpt-4"],
                    label="Select Model",
                    value="deepseek-chat"
                )
                evaluate_btn = gr.Button("Evaluate", variant="primary")

        with gr.Row():
            with gr.Column():
                summary_table = gr.Dataframe(
                    headers=["Subject", "Accuracy (%)"],
                    label="Summary Results"
                )

        with gr.Row():
            with gr.Column():
                summary_plot = gr.Plot(label="Performance by Subject")

        with gr.Row():
            with gr.Column():
                detailed_results = gr.Dataframe(
                    label="Detailed Results",
                    wrap=True
                )

        evaluate_btn.click(
            fn=evaluate_and_display,
            inputs=[file_input, model_input],
            outputs=[summary_table, summary_plot, detailed_results]
        )

    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True)
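
For reference, the uploaded JSON is expected to be a list of records carrying the fields evaluate_afrimmlu reads: question, choices, answer, and subject, where choices is either a list of four options or a stringified list (the case preprocess_dataset unpacks). A minimal sketch of a compatible test file, with invented example records (not taken from the AfriMMLU dataset), might look like this:

import json

# Hypothetical sample records; questions, subjects, and answers below are
# made up for illustration only.
sample_records = [
    {
        "question": "Which planet is known as the Red Planet?",
        "choices": "['Venus', 'Mars', 'Jupiter', 'Saturn']",  # stringified list, handled by preprocess_dataset
        "answer": "B",
        "subject": "elementary_science",
    },
    {
        "question": "What is 7 multiplied by 8?",
        "choices": ["54", "56", "58", "64"],  # an already-parsed list passes through unchanged
        "answer": "B",
        "subject": "elementary_mathematics",
    },
]

with open("sample_test.json", "w", encoding="utf-8") as f:
    json.dump(sample_records, f, ensure_ascii=False, indent=2)

Running the app also assumes the API key for the selected provider is available in the environment (for example OPENAI_API_KEY for the GPT models); otherwise litellm.completion will fail with an authentication error and the affected questions are skipped by the except branch.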