# data_reviewer.py
import base64
import json
from functools import lru_cache
from io import BytesIO

import gradio as gr
from datasets import load_dataset
from PIL import Image

IGNORE_DETAILS = True
DATASET_NAME = "MMInstruction/VRewardBench"

@lru_cache(maxsize=1)
def load_cached_dataset(dataset_name, split):
    # Cache the loaded split so repeated UI interactions don't re-fetch it.
    return load_dataset(dataset_name, split=split)

def base64_to_image(base64_string):
    # Decode a base64-encoded image payload into a PIL Image.
    img_data = base64.b64decode(base64_string)
    return Image.open(BytesIO(img_data))

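# Illustrative round-trip (hypothetical snippet, not part of the original
# module): encode a PIL image to base64, then recover it with the helper.
#
#     buf = BytesIO()
#     Image.new("RGB", (4, 4)).save(buf, format="PNG")
#     img = base64_to_image(base64.b64encode(buf.getvalue()).decode())
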
def get_responses(responses, rankings):
    # Fields may arrive JSON-encoded; normalize them to Python lists first.
    if isinstance(responses, str):
        responses = json.loads(responses)
    if isinstance(rankings, str):
        rankings = json.loads(rankings)
    # Rank 0 marks the human-preferred response, rank 1 the rejected one.
    chosen = next((resp for resp, rank in zip(responses, rankings) if rank == 0), "No chosen response")
    rejected = next((resp for resp, rank in zip(responses, rankings) if rank == 1), "No rejected response")
    return chosen, rejected

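# Example with hypothetical data: rank 0 marks the human-preferred response,
# so get_responses(["resp_a", "resp_b"], [1, 0]) returns ("resp_b", "resp_a").
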
def load_and_display_sample(split, idx):
    try:
        dataset = load_cached_dataset(DATASET_NAME, split)

        # Clamp the requested index to the valid range [0, len(dataset) - 1].
        max_idx = len(dataset) - 1
        idx = min(max(0, int(idx)), max_idx)
        sample = dataset[idx]

        # Split the paired responses into chosen/rejected via human rankings.
        chosen_response, rejected_response = get_responses(sample["response"], sample["human_ranking"])

        # Some fields may be stored as JSON strings; decode them if needed.
        models = json.loads(sample["models"]) if isinstance(sample["models"], str) else sample["models"]
        meta = json.loads(sample["meta"]) if isinstance(sample["meta"], str) else sample["meta"]
        error_analysis = (
            json.loads(sample["human_error_analysis"])
            if isinstance(sample["human_error_analysis"], str)
            else sample["human_error_analysis"]
        )

        return (
            sample["image"],  # image
            sample["id"],  # sample_id
            chosen_response,  # chosen_response
            rejected_response,  # rejected_response
            sample["judge"],  # judge
            sample["query_source"],  # query_source
            sample["query"],  # query
            json.dumps(models, indent=2),  # models_json
            json.dumps(meta, indent=2),  # meta_json
            sample["rationale"],  # rationale
            json.dumps(error_analysis, indent=2),  # error_analysis_json
            sample["ground_truth"],  # ground_truth
            f"Total samples: {len(dataset)}",  # total_samples
        )
    except Exception as e:
        raise gr.Error(f"Error loading dataset: {e}") from e

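# The 13-tuple returned above must stay in sync with the `outputs` list wired
# up in create_data_viewer() below; e.g. index 12 feeds the Total Samples box.
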
def create_data_viewer():
    # Pre-fetch initial data so the UI is populated on first render.
    initial_split = "test"
    initial_idx = 0
    initial_data = load_and_display_sample(initial_split, initial_idx)

    with gr.Column():
        with gr.Row():
            dataset_split = gr.Radio(choices=["test"], value=initial_split, label="Dataset Split")
            sample_idx = gr.Number(label="Sample Index", value=initial_idx, minimum=0, step=1, interactive=True)
            total_samples = gr.Textbox(label="Total Samples", value=initial_data[12], interactive=False)

        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Sample Image", type="pil", value=initial_data[0])
                query = gr.Textbox(label="Query", value=initial_data[6], interactive=False)
            with gr.Column():
                sample_id = gr.Textbox(label="Sample ID", value=initial_data[1], interactive=False)
                chosen_response = gr.TextArea(label="Chosen Response ✅", value=initial_data[2], interactive=False)
                rejected_response = gr.TextArea(
                    label="Rejected Response ❌",
                    value=initial_data[3],
                    interactive=False,
                )

        # The metadata rows below are hidden while IGNORE_DETAILS is True.
        with gr.Row(visible=not IGNORE_DETAILS):
            judge = gr.Textbox(label="Judge", value=initial_data[4], interactive=False)
            query_source = gr.Textbox(label="Query Source", value=initial_data[5], interactive=False)

        with gr.Row(visible=not IGNORE_DETAILS):
            with gr.Column():
                models_json = gr.JSON(label="Models", value=json.loads(initial_data[7]))
                meta_json = gr.JSON(label="Meta", value=json.loads(initial_data[8]))
                rationale = gr.TextArea(label="Rationale", value=initial_data[9], interactive=False)
            with gr.Column():
                error_analysis_json = gr.JSON(label="Human Error Analysis", value=json.loads(initial_data[10]))
                ground_truth = gr.TextArea(label="Ground Truth", value=initial_data[11], interactive=False)

    # Auto-update every output whenever the split or sample index changes.
    for input_component in [dataset_split, sample_idx]:
        input_component.change(
            fn=load_and_display_sample,
            inputs=[dataset_split, sample_idx],
            outputs=[
                image,
                sample_id,
                chosen_response,
                rejected_response,
                judge,
                query_source,
                query,
                models_json,
                meta_json,
                rationale,
                error_analysis_json,
                ground_truth,
                total_samples,
            ],
        )

    return dataset_split, sample_idx
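
# Usage sketch (assumption: this module is normally mounted by a separate app
# entry point; the Blocks wrapper below is illustrative, not original code).
if __name__ == "__main__":
    with gr.Blocks(title="VL-RewardBench Data Viewer") as demo:
        create_data_viewer()
    demo.launch()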