# data_viewer.py
import base64
import json
from functools import lru_cache
from io import BytesIO

import gradio as gr
from datasets import load_dataset
from PIL import Image


@lru_cache(maxsize=1)
def load_cached_dataset(dataset_name, split):
    """Load a dataset split, caching only the most recent (name, split) pair."""
    return load_dataset(dataset_name, split=split)


def base64_to_image(base64_string):
    """Decode a base64-encoded image string into a PIL image."""
    img_data = base64.b64decode(base64_string)
    return Image.open(BytesIO(img_data))


def get_responses(responses, rankings):
    """Split a response list into (chosen, rejected) using human rankings.

    A rank of 0 marks the chosen response; a rank of 1 marks the rejected one.
    """
    if isinstance(responses, str):
        responses = json.loads(responses)
    if isinstance(rankings, str):
        rankings = json.loads(rankings)
    chosen = next((resp for resp, rank in zip(responses, rankings) if rank == 0), "No chosen response")
    rejected = next((resp for resp, rank in zip(responses, rankings) if rank == 1), "No rejected response")
    return chosen, rejected


def load_and_display_sample(dataset_name, split, idx):
    """Fetch one sample and return the tuple of values bound to the UI outputs."""
    try:
        dataset = load_cached_dataset(dataset_name, split)
        max_idx = len(dataset) - 1
        idx = min(max(0, int(idx)), max_idx)  # Clamp the index into range
        sample = dataset[idx]

        # Process image
        image = base64_to_image(sample["image"])

        # Get responses
        chosen_response, rejected_response = get_responses(sample["response"], sample["human_ranking"])

        # Process JSON data
        models = json.loads(sample["models"]) if isinstance(sample["models"], str) else sample["models"]
        meta = json.loads(sample["meta"]) if isinstance(sample["meta"], str) else sample["meta"]
        error_analysis = (
            json.loads(sample["human_error_analysis"])
            if isinstance(sample["human_error_analysis"], str)
            else sample["human_error_analysis"]
        )

        return (
            image,  # image
            sample["id"],  # sample_id
            chosen_response,  # chosen_response
            rejected_response,  # rejected_response
            sample["judge"],  # judge
            sample["query_source"],  # query_source
            sample["query"],  # query
            json.dumps(models, indent=2),  # models_json
            json.dumps(meta, indent=2),  # meta_json
            sample["rationale"],  # rationale
            json.dumps(error_analysis, indent=2),  # error_analysis_json
            sample["ground_truth"],  # ground_truth
            f"Total samples: {len(dataset)}",  # total_samples
        )
    except Exception as e:
        raise gr.Error(f"Error loading dataset: {e}") from e
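

# A minimal sketch of the ranking convention get_responses relies on, assuming
# a VLRewardBench-style sample with two responses. The helper name
# _example_get_responses is illustrative only and is not wired into the app;
# it can serve as a quick smoke test from a REPL.
def _example_get_responses():
    chosen, rejected = get_responses(["Answer A", "Answer B"], [0, 1])
    assert chosen == "Answer A"  # rank 0 -> chosen
    assert rejected == "Answer B"  # rank 1 -> rejected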
def create_data_viewer():
    """Build the Gradio UI and wire its inputs to load_and_display_sample."""
    # Pre-fetch initial data
    initial_dataset_name = "MMInstruction/VRewardBench"
    initial_split = "test"
    initial_idx = 0
    initial_data = load_and_display_sample(initial_dataset_name, initial_split, initial_idx)

    with gr.Column():
        with gr.Row():
            dataset_name = gr.Textbox(label="Dataset Name", value=initial_dataset_name, interactive=True)
            dataset_split = gr.Radio(choices=["test"], value=initial_split, label="Dataset Split")
            sample_idx = gr.Number(label="Sample Index", value=initial_idx, minimum=0, step=1, interactive=True)
            total_samples = gr.Textbox(
                label="Total Samples", value=initial_data[12], interactive=False  # Set initial total samples
            )
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Sample Image", type="pil", value=initial_data[0])  # Set initial image
            with gr.Column():
                sample_id = gr.Textbox(
                    label="Sample ID", value=initial_data[1], interactive=False  # Set initial sample ID
                )
                chosen_response = gr.TextArea(
                    label="Chosen Response ✅", value=initial_data[2], interactive=False  # Set initial chosen response
                )
                rejected_response = gr.TextArea(
                    label="Rejected Response ❌",
                    value=initial_data[3],  # Set initial rejected response
                    interactive=False,
                )
        with gr.Row():
            judge = gr.Textbox(label="Judge", value=initial_data[4], interactive=False)  # Set initial judge
            query_source = gr.Textbox(
                label="Query Source", value=initial_data[5], interactive=False  # Set initial query source
            )
            query = gr.Textbox(label="Query", value=initial_data[6], interactive=False)  # Set initial query
        with gr.Row():
            with gr.Column():
                models_json = gr.JSON(label="Models", value=json.loads(initial_data[7]))  # Set initial models
                meta_json = gr.JSON(label="Meta", value=json.loads(initial_data[8]))  # Set initial meta
                rationale = gr.TextArea(
                    label="Rationale", value=initial_data[9], interactive=False  # Set initial rationale
                )
            with gr.Column():
                error_analysis_json = gr.JSON(
                    label="Human Error Analysis", value=json.loads(initial_data[10])  # Set initial error analysis
                )
                ground_truth = gr.TextArea(
                    label="Ground Truth", value=initial_data[11], interactive=False  # Set initial ground truth
                )

    # Auto-update when any input changes
    for input_component in [dataset_name, dataset_split, sample_idx]:
        input_component.change(
            fn=load_and_display_sample,
            inputs=[dataset_name, dataset_split, sample_idx],
            outputs=[
                image,
                sample_id,
                chosen_response,
                rejected_response,
                judge,
                query_source,
                query,
                models_json,
                meta_json,
                rationale,
                error_analysis_json,
                ground_truth,
                total_samples,
            ],
        )

    return dataset_name, dataset_split, sample_idx
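

# A minimal standalone entry point, assuming this module is launched directly;
# the surrounding repo may instead import create_data_viewer from a separate
# app.py, in which case this block is simply never executed.
if __name__ == "__main__":
    with gr.Blocks(title="Data Viewer") as demo:
        create_data_viewer()
    demo.launch()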