import gradio as gr
import os
import base64
import random

import pandas as pd
from collections import defaultdict
from datasets import load_dataset

def encode_image_to_base64(image_path):
    """Encode an image or GIF file to base64."""
    with open(image_path, "rb") as file:
        encoded_string = base64.b64encode(file.read()).decode()
    return encoded_string

def create_html_media(media_path, is_gif=False):
    """Create HTML for displaying an image or GIF."""
    media_base64 = encode_image_to_base64(media_path)
    media_type = "gif" if is_gif else "jpeg"

    html_string = f"""
    <div style="display: flex; justify-content: center; align-items: center; width: 100%; text-align: center;">
        <div style="max-width: 450px; margin: auto;">
            <img src="data:image/{media_type};base64,{media_base64}"
                 style="max-width: 75%; height: auto; display: block; margin: 0 auto; margin-top: 50px;"
                 alt="Displayed Media">
        </div>
    </div>
    """
    return html_string
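
# Design note: embedding the media as a base64 data URI keeps the generated
# HTML self-contained, so no separate static-file route is needed to serve it.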
class LMBattleArena:
    def __init__(self, dataset_path):
        """Initialize the battle arena with a CSV dataset of model outputs."""
        self.df = pd.read_csv(dataset_path)
        print(self.df.head())
        self.current_index = 0  # row of self.df currently being served
        self.saving_freq = 10  # checkpoint the leaderboard every N battles
        self.evaluation_results = []  # raw per-vote records
        self.model_scores = defaultdict(lambda: {'wins': 0, 'total_comparisons': 0})

    def get_next_battle_pair(self):
        """Retrieve the next pair of model outputs for comparison."""
        if self.current_index >= len(self.df):
            return None

        row = self.df.iloc[self.current_index]
        # Every column except the prompt holds one model's output.
        model_summary_cols = [col for col in row.index if col.upper() != 'PROMPT']
        # Sample two distinct models so each battle is an anonymous random pairing.
        selected_models = random.sample(model_summary_cols, 2)
        battle_data = {
            'prompt': row['prompt'],
            'model_1': row[selected_models[0]],
            'model_2': row[selected_models[1]],
            'model1_name': selected_models[0],
            'model2_name': selected_models[1],
        }
        self.current_index += 1
        return battle_data

    def record_evaluation(self, preferred_models, input_text, output1, output2, model1_name, model2_name):
        """Record the user's preference and update both models' scores."""
        self.model_scores[model1_name]['total_comparisons'] += 1
        self.model_scores[model2_name]['total_comparisons'] += 1

        # "Both Good" credits both models with a win; "Both Bad" credits neither.
        if preferred_models == "Both Good":
            self.model_scores[model1_name]['wins'] += 1
            self.model_scores[model2_name]['wins'] += 1
        elif preferred_models == "Model A":
            self.model_scores[model1_name]['wins'] += 1
        elif preferred_models == "Model B":
            self.model_scores[model2_name]['wins'] += 1

        evaluation = {
            'input_text': input_text,
            'output1': output1,
            'output2': output2,
            'model1_name': model1_name,
            'model2_name': model2_name,
            'preferred_models': preferred_models,
        }
        self.evaluation_results.append(evaluation)

        return self.get_model_scores_df()

    def get_model_scores_df(self):
        """Convert accumulated model scores into a leaderboard DataFrame."""
        scores_data = []
        for model, stats in self.model_scores.items():
            win_rate = (stats['wins'] / stats['total_comparisons'] * 100) if stats['total_comparisons'] > 0 else 0
            scores_data.append({
                'Model': model,
                'Wins': stats['wins'],
                'Total Comparisons': stats['total_comparisons'],
                'Win Rate (%)': round(win_rate, 2),
            })
        results_df = pd.DataFrame(scores_data).sort_values('Win Rate (%)', ascending=False)

        # Periodically checkpoint the leaderboard so progress survives restarts.
        if self.current_index % self.saving_freq == 0 and self.current_index > 0:
            results_df.to_csv('human_eval_results.csv', index=False)

        return results_df
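
    # The raw vote records in self.evaluation_results accumulate in memory but
    # are never written to disk above. A minimal sketch for persisting them;
    # the filename is an assumption, not part of the original script:
    def save_evaluations(self, path='human_eval_votes.csv'):
        """Persist the raw per-vote records as a CSV."""
        pd.DataFrame(self.evaluation_results).to_csv(path, index=False)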
def create_battle_arena(dataset_path, is_gif):
    arena = LMBattleArena(dataset_path)

    def battle_round():
        battle_data = arena.get_next_battle_pair()

        if battle_data is None:
            # Hide the leaderboard once the dataset is exhausted.
            return "No more texts to evaluate!", "", "", "", "", gr.DataFrame(visible=False)

        return (
            battle_data['prompt'],
            battle_data['model_1'],
            battle_data['model_2'],
            battle_data['model1_name'],
            battle_data['model2_name'],
            gr.DataFrame(visible=True),
        )

    def submit_preference(input_text, output_1, output_2, model1_name, model2_name, preferred_models):
        scores_df = arena.record_evaluation(
            preferred_models, input_text, output_1, output_2, model1_name, model2_name
        )
        # Serve the next pair, swapping battle_round's placeholder DataFrame
        # update for the freshly computed leaderboard.
        next_battle = battle_round()
        return (*next_battle[:-1], scores_df)

    with gr.Blocks(css="footer{display:none !important}") as demo:
        base_path = os.path.dirname(__file__)
        local_image_path = os.path.join(base_path, 'battle_leaderboard.gif')
        gr.HTML(create_html_media(local_image_path, is_gif=is_gif))

        with gr.Tabs():
            with gr.Tab("Battle Arena"):
                gr.Markdown("# 🤖 Pretrained SmolLMs Battle Arena")

                input_text = gr.Textbox(
                    label="Input prompt",
                    interactive=False,
                )

                with gr.Row():
                    output_1 = gr.Textbox(
                        label="Model A",
                        interactive=False
                    )
                    # Hidden state tracking which model produced output A.
                    model1_name = gr.State()

                with gr.Row():
                    output_2 = gr.Textbox(
                        label="Model B",
                        interactive=False
                    )
                    # Hidden state tracking which model produced output B.
                    model2_name = gr.State()

                preferred_models = gr.Radio(
                    label="Which model is better?",
                    choices=["Model A", "Model B", "Both Good", "Both Bad"]
                )
                submit_btn = gr.Button("Vote", variant="primary")

                scores_table = gr.DataFrame(
                    headers=['Model', 'Wins', 'Total Comparisons', 'Win Rate (%)'],
                    label="🏆 Leaderboard"
                )

                submit_btn.click(
                    submit_preference,
                    inputs=[input_text, output_1, output_2, model1_name, model2_name, preferred_models],
                    outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table]
                )

        # Populate the first battle pair when the page loads.
        demo.load(battle_round, outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table])

    return demo
if __name__ == "__main__":
    # Export the evaluation split to a local CSV once so the arena can read it.
    load_dataset("atlasia/Moroccan-Darija-LLM-Battle-Al-Atlas", split='train').to_csv('human_eval_dataset.csv')

    dataset_path = 'human_eval_dataset.csv'
    is_gif = True
    demo = create_battle_arena(dataset_path, is_gif)
    demo.launch(debug=True)
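    # Note: demo.launch(debug=True, share=True) would additionally create a
    # temporary public link, useful when collecting votes from remote annotators.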