# data_viewer.py
import base64
import json
from functools import lru_cache
from io import BytesIO
import gradio as gr
from datasets import load_dataset
from PIL import Image
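# When True, hide the metadata rows (judge, query source, models, rationale, ground truth) in the UI.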
IGNORE_DETAILS = True
DATASET_NAME = "MMInstruction/VRewardBench"
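# Cache the loaded split so repeated UI interactions reuse the in-memory dataset instead of re-fetching it.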
@lru_cache(maxsize=1)
def load_cached_dataset(dataset_name, split):
    return load_dataset(dataset_name, split=split)
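# Helper for decoding images stored as base64 strings (not called by the viewer below,
# but useful if a split ships encoded images instead of PIL objects).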
def base64_to_image(base64_string):
    img_data = base64.b64decode(base64_string)
    return Image.open(BytesIO(img_data))
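# In "human_ranking", 0 marks the preferred (chosen) response and 1 the rejected one.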
def get_responses(responses, rankings):
    if isinstance(responses, str):
        responses = json.loads(responses)
    if isinstance(rankings, str):
        rankings = json.loads(rankings)
    chosen = next((resp for resp, rank in zip(responses, rankings) if rank == 0), "No chosen response")
    rejected = next((resp for resp, rank in zip(responses, rankings) if rank == 1), "No rejected response")
    return chosen, rejected
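# Fetch one sample (clamping the index to the valid range) and unpack it into the flat
# tuple expected by the Gradio output components below.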
def load_and_display_sample(split, idx):
    try:
        dataset = load_cached_dataset(DATASET_NAME, split)
        max_idx = len(dataset) - 1
        idx = min(max(0, int(idx)), max_idx)
        sample = dataset[idx]
        # Get responses
        chosen_response, rejected_response = get_responses(sample["response"], sample["human_ranking"])
        # Process JSON data
        models = json.loads(sample["models"]) if isinstance(sample["models"], str) else sample["models"]
        return (
            sample["image"],  # image
            sample["id"],  # sample_id
            chosen_response,  # chosen_response
            rejected_response,  # rejected_response
            sample["judge"],  # judge
            sample["query_source"],  # query_source
            sample["query"],  # query
            json.dumps(models, indent=2),  # models_json
            sample["rationale"],  # rationale
            sample["ground_truth"],  # ground_truth
            f"Total samples: {len(dataset)}",  # total_samples
        )
    except Exception as e:
        raise gr.Error(f"Error loading dataset: {str(e)}")
def create_data_viewer():
    # Pre-fetch initial data
    initial_split = "test"
    initial_idx = 0
    initial_data = load_and_display_sample(initial_split, initial_idx)
    (
        init_image,
        init_sample_id,
        init_chosen_response,
        init_rejected_response,
        init_judge,
        init_query_source,
        init_query,
        init_models_json,
        init_rationale,
        init_ground_truth,
        init_total_samples,
    ) = initial_data
    with gr.Column():
        with gr.Row():
            dataset_split = gr.Radio(choices=["test"], value=initial_split, label="Dataset Split")
            sample_idx = gr.Number(label="Sample Index", value=initial_idx, minimum=0, step=1, interactive=True)
            total_samples = gr.Textbox(
                label="Total Samples", value=init_total_samples, interactive=False  # Set initial total samples
            )
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Sample Image", type="pil", value=init_image)  # Set initial image
                query = gr.Textbox(label="Query", value=init_query, interactive=False)  # Set initial query
            with gr.Column():
                sample_id = gr.Textbox(
                    label="Sample ID", value=init_sample_id, interactive=False  # Set initial sample ID
                )
                chosen_response = gr.TextArea(
                    label="Chosen Response ✅",
                    value=init_chosen_response,  # Set initial chosen response
                    interactive=False,
                )
                rejected_response = gr.TextArea(
                    label="Rejected Response ❌",
                    value=init_rejected_response,  # Set initial rejected response
                    interactive=False,
                )
        with gr.Row(visible=not IGNORE_DETAILS):
            judge = gr.Textbox(label="Judge", value=init_judge, interactive=False)  # Set initial judge
            query_source = gr.Textbox(
                label="Query Source", value=init_query_source, interactive=False  # Set initial query source
            )
        with gr.Row(visible=not IGNORE_DETAILS):
            with gr.Column():
                models_json = gr.JSON(label="Models", value=json.loads(init_models_json))  # Set initial models
                rationale = gr.TextArea(
                    label="Rationale", value=init_rationale, interactive=False  # Set initial rationale
                )
            with gr.Column():
                ground_truth = gr.TextArea(
                    label="Ground Truth", value=init_ground_truth, interactive=False  # Set initial ground truth
                )
        # Auto-update when any input changes
        for input_component in [dataset_split, sample_idx]:
            input_component.change(
                fn=load_and_display_sample,
                inputs=[dataset_split, sample_idx],
                outputs=[
                    image,
                    sample_id,
                    chosen_response,
                    rejected_response,
                    judge,
                    query_source,
                    query,
                    models_json,
                    rationale,
                    ground_truth,
                    total_samples,
                ],
            )
    return dataset_split, sample_idx
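# A minimal sketch of how create_data_viewer() could be mounted as a standalone app,
# assuming this module is run directly (the Markdown title is illustrative):
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown(f"# Data viewer for {DATASET_NAME}")
        create_data_viewer()
    demo.launch()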