Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import os | |
import uuid | |
import datetime | |
import logging | |
from huggingface_hub import hf_hub_download, upload_file, list_repo_tree | |
from dotenv import load_dotenv | |
from collections import defaultdict | |
# Load variables from a local .env file into the environment before the
# configuration constants below read them.
load_dotenv()
# Configuration
# Every value is read from the environment; os.getenv returns None when the
# variable is not set.
# Input: Hub dataset repo, the file inside it, and the columns to display.
HF_INPUT_DATASET = os.getenv("HF_INPUT_DATASET")
HF_INPUT_DATASET_PATH = os.getenv("HF_INPUT_DATASET_PATH")
HF_INPUT_DATASET_ID_COLUMN = os.getenv("HF_INPUT_DATASET_ID_COLUMN")
HF_INPUT_DATASET_COLUMN_A = os.getenv("HF_INPUT_DATASET_COLUMN_A")
HF_INPUT_DATASET_COLUMN_B = os.getenv("HF_INPUT_DATASET_COLUMN_B")
# Optional: the URL link is only shown when this is set and the column exists.
HF_INPUT_DATASET_URL_COLUMN = os.getenv("HF_INPUT_DATASET_URL_COLUMN")
# Output: Hub dataset repo and directory that receive the judgment files.
HF_OUTPUT_DATASET = os.getenv("HF_OUTPUT_DATASET")
HF_OUTPUT_DATASET_DIR = os.getenv("HF_OUTPUT_DATASET_DIR")
# Markdown rendered at the top of the app.
INSTRUCTIONS = """
# Pairwise Model Output Labeling Space
### How-to:
* Duplicate this Space
* Add a `HF_TOKEN` secret in the Space settings (it saves the results to a given dataset, so it requires write access)
* Set the environment variables in the [`.env`](https://huggingface.co/spaces/saridormi/labeling-template/blob/main/.env) file to point to correct datasets/columns/paths for your input and output data
* Adjust `INSTRUCTIONS` (it's what's rendered here) and `SAVE_EVERY_N_EXAMPLES` (how often should answers be saved to the output dataset) in the [`app.py`](https://huggingface.co/spaces/saridormi/labeling-template/blob/main/app.py) file
Done! 💛
### About:
Please compare the two model outputs shown below and select which one you think is better.
- Choose "A is better" if the output from Model A (left) is superior
- Choose "B is better" if the output from Model B (right) is superior
- Choose "Tie" if they are equally good or bad
- Choose "Can't choose" if you cannot make a determination
"""
# A session's pending judgments are uploaded each time this many have
# accumulated. Also used as the number of rows in the fallback sample data.
SAVE_EVERY_N_EXAMPLES = 5
class PairwiseLabeler:
    """Serve A/B output pairs to labeling sessions and persist their judgments.

    Each session (identified by an arbitrary string) gets its own shuffled
    copy of the input data and its own position in it. Judgments are buffered
    in memory and uploaded to the output dataset in batches.
    """

    def __init__(self):
        # Position of each session inside its shuffled copy of the data.
        self.current_index = defaultdict(int)
        # Judgments per session that have not been uploaded yet.
        self.results = defaultdict(list)
        self._df = self.read_hf_dataset()
        # Per-session shuffled copies of ``_df``, created on first access.
        self.df = {}
        # Normalised to a real bool; the raw ``and`` expression would leak
        # ``None``/``""`` when the URL column is not configured.
        self.has_url_column = bool(
            HF_INPUT_DATASET_URL_COLUMN
            and HF_INPUT_DATASET_URL_COLUMN in self._df.columns
        )

    def __len__(self):
        """Return the number of pairs in the input data."""
        return len(self._df)

    def read_hf_dataset(self) -> pd.DataFrame:
        """Download the input file from the Hub and load it as a DataFrame.

        Supports ``.json``, ``.jsonl``, ``.csv`` and ``.parquet`` files. On any
        failure (download error, unsupported extension, parse error) the cause
        is logged and a small generated sample frame is returned instead, so
        the app can still start.
        """
        try:
            local_file = hf_hub_download(
                repo_id=HF_INPUT_DATASET,
                repo_type="dataset",
                filename=HF_INPUT_DATASET_PATH,
            )
            if local_file.endswith(".json"):
                return pd.read_json(local_file)
            if local_file.endswith(".jsonl"):
                return pd.read_json(local_file, orient="records", lines=True)
            if local_file.endswith(".csv"):
                return pd.read_csv(local_file)
            if local_file.endswith(".parquet"):
                return pd.read_parquet(local_file)
            raise ValueError(f"Unsupported file type: {local_file}")
        except Exception:
            # Fallback to sample data if loading fails. logging.exception keeps
            # the traceback so the real cause is not lost.
            logging.exception(
                "Couldn't read HF dataset from %s. Using sample data instead.",
                HF_INPUT_DATASET_PATH,
            )

        sample_range = range(SAVE_EVERY_N_EXAMPLES)
        sample_data = {
            HF_INPUT_DATASET_ID_COLUMN: [f"sample_{i}" for i in sample_range],
            HF_INPUT_DATASET_COLUMN_A: [f"This is sample generation A {i}" for i in sample_range],
            HF_INPUT_DATASET_COLUMN_B: [f"This is sample generation B {i}" for i in sample_range],
        }
        # Add URL column to sample data if specified
        if HF_INPUT_DATASET_URL_COLUMN:
            sample_data[HF_INPUT_DATASET_URL_COLUMN] = [
                f"https://example.com/sample_{i}" for i in sample_range
            ]
        return pd.DataFrame(sample_data)

    def get_current_pair(self, session_id):
        """Return the pair the session is currently positioned on.

        Returns ``(item_id, left_text, right_text)``, with the URL appended as
        a fourth element when ``has_url_column`` is true. Every element is
        ``None`` once the session has gone through all pairs.
        """
        if session_id not in self.df:
            # First access: give the session its own random ordering.
            self.df[session_id] = self._df.sample(frac=1).reset_index(drop=True)

        position = self.current_index[session_id]
        if position >= len(self.df[session_id]):
            if self.has_url_column:
                return None, None, None, None
            return None, None, None

        item = self.df[session_id].iloc[position]
        item_id = item.get(HF_INPUT_DATASET_ID_COLUMN, f"item_{position}")
        left_text = item.get(HF_INPUT_DATASET_COLUMN_A, "")
        right_text = item.get(HF_INPUT_DATASET_COLUMN_B, "")
        if self.has_url_column:
            url = item.get(HF_INPUT_DATASET_URL_COLUMN, "")
            return item_id, left_text, right_text, url
        return item_id, left_text, right_text

    def submit_judgment(self, item_id, left_text, right_text, choice, session_id):
        """Record ``choice`` for ``item_id`` and advance the session.

        Returns ``(next_id, next_left, next_right, index)``, with the next URL
        inserted before ``index`` when ``has_url_column`` is true. When there
        is nothing to judge (``item_id`` is ``None`` or the session is already
        past the last pair) nothing is recorded and the inputs are echoed back.
        """
        exhausted = self.current_index[session_id] >= len(self)
        if item_id is None or exhausted:
            # Clicks after the last pair must not create records. Use them to
            # retry an upload that failed earlier (no-op if nothing is pending).
            self.save_results(session_id)
            if self.has_url_column:
                return item_id, left_text, right_text, None, self.current_index[session_id]
            return item_id, left_text, right_text, self.current_index[session_id]

        # Record the judgment
        self.results[session_id].append(
            {
                "item_id": item_id,
                "judgment": choice,
                "timestamp": datetime.datetime.now().isoformat(),
                "labeler_id": session_id,
            }
        )
        # Move to next item
        self.current_index[session_id] += 1

        # Save results periodically, and always after the last pair. Without
        # the second condition the final partial batch would never be uploaded.
        batch_full = len(self.results[session_id]) % SAVE_EVERY_N_EXAMPLES == 0
        finished = self.current_index[session_id] >= len(self)
        if batch_full or finished:
            self.save_results(session_id)

        # Get next pair; its arity already matches has_url_column.
        return (*self.get_current_pair(session_id), self.current_index[session_id])

    def save_results(self, session_id):
        """Upload the session's pending judgments as one JSONL file.

        Does nothing when there is nothing pending. Errors are logged and
        swallowed; the judgments then stay in memory so that a later call can
        try again.
        """
        pending = list(self.results[session_id])
        if not pending:
            return
        try:
            # Serialise in memory. A shared on-disk temp file would collide
            # between sessions and be left behind if the upload failed.
            payload = (
                pd.DataFrame(pending)
                .to_json(orient="records", lines=True)
                .encode("utf-8")
            )
            # Number the file after the ones already uploaded for this session.
            try:
                num_files = sum(
                    1
                    for entry in list_repo_tree(
                        repo_id=HF_OUTPUT_DATASET,
                        repo_type="dataset",
                        path_in_repo=HF_OUTPUT_DATASET_DIR,
                    )
                    if session_id in entry.path
                )
            except Exception:
                # Listing fails when the directory does not exist yet.
                num_files = 0
            # Use session_id in filename to avoid conflicts
            filename = f"results_{session_id}_{num_files + 1}.jsonl"
            # Repo paths always use "/"; also tolerate an unset directory.
            directory = (HF_OUTPUT_DATASET_DIR or "").rstrip("/")
            path_in_repo = f"{directory}/{filename}" if directory else filename
            upload_file(
                repo_id=HF_OUTPUT_DATASET,
                repo_type="dataset",
                path_in_repo=path_in_repo,
                path_or_fileobj=payload,
            )
            # Drop only what was uploaded, keeping anything appended meanwhile.
            del self.results[session_id][: len(pending)]
            logging.info(f"Saved results for session {session_id} to {HF_OUTPUT_DATASET}/{filename}")
        except Exception as e:
            # Keep results in memory to try saving again later
            logging.exception("Error saving results: %s", e)
# Initialize the labeler. Its constructor loads the input dataset (or the
# generated sample rows if loading fails), so this runs once at import time.
labeler = PairwiseLabeler()
# Create a default session ID
def create_default_session():
    """Return a random 8-character session identifier.

    The identifier is the first eight hexadecimal digits of a new UUID4.
    """
    return uuid.uuid4().hex[:8]
# Create a demo session ID for initial display
demo_session = create_default_session()
# get_current_pair returns (id, left, right) and appends the URL only when the
# dataset has a URL column. Padding with None and keeping the first four
# values leaves demo_url as None in the three-value case.
demo_item_id, demo_left_text, demo_right_text, demo_url = (
    *labeler.get_current_pair(demo_session),
    None,
)[:4]
with gr.Blocks(css="footer {visibility: hidden}") as app:
    # Per-browser state: the labeler session key and whether labeling started.
    session_id = gr.State(value=None)
    is_session_started = gr.State(value=False)

    # Header with instructions
    gr.Markdown(INSTRUCTIONS)

    # Session ID entry and the button that starts a session
    user_session_id = gr.Textbox(
        placeholder="Enter your unique ID or leave blank for random",
        label="Session ID",
        scale=2
    )
    start_btn = gr.Button("Start Session", variant="primary", scale=1)

    # URL display component - only created if the URL column is present
    if labeler.has_url_column:
        url_display = gr.HTML(label="Reference URL", value="", visible=True)
    # Spliced into the output lists below so that both layouts (with and
    # without URL) share the same handlers and wiring.
    url_outputs = [url_display] if labeler.has_url_column else []

    # Main content area (shown with placeholder text)
    with gr.Row():
        with gr.Column():
            left_output = gr.Textbox(
                value=demo_left_text,
                label="Model A Output",
                lines=10,
                interactive=False
            )
        with gr.Column():
            right_output = gr.Textbox(
                value=demo_right_text,
                label="Model B Output",
                lines=10,
                interactive=False
            )
    # Hidden field carrying the id of the pair currently on screen.
    item_id = gr.Textbox(value=demo_item_id, visible=False)

    # Buttons row (initially disabled)
    with gr.Row():
        left_btn = gr.Button("⬅️ A is better", variant="primary", interactive=False)
        right_btn = gr.Button("➡️ B is better", variant="primary", interactive=False)
        tie_btn = gr.Button("🤝 Tie", variant="primary", interactive=False)
        cant_choose_btn = gr.Button("🤔 Can't choose", interactive=False)
    judgment_buttons = [left_btn, right_btn, tie_btn, cant_choose_btn]

    # Progress slider (read-only)
    current_sample_sld = gr.Slider(
        minimum=0,
        maximum=len(labeler),
        step=1,
        value=0,
        interactive=False,
        label='Progress',
        show_label=False
    )

    def _url_to_html(url):
        """Render ``url`` as an HTML link, or "" when there is no URL."""
        from html import escape  # local import keeps this block self-contained

        # ``url != url`` is true only for NaN, which pandas uses for missing
        # cells; it is truthy and would otherwise be rendered as a "nan" link.
        if not url or (isinstance(url, float) and url != url):
            return ""
        # The URL comes from the dataset, so escape it before embedding it in
        # markup (both inside the attribute and as the link text).
        safe_url = escape(str(url), quote=True)
        return f'<a href="{safe_url}" target="_blank">{safe_url}</a>'

    # Initialize the session and get the first pair
    def init_session(entered_id):
        """Start a session and return values for every start-button output.

        Uses the entered ID, or a generated one when the field is blank. The
        returned tuple follows the order of the ``outputs`` list wired to
        ``start_btn`` below.
        """
        session_id_value = (entered_id or "").strip() or create_default_session()
        pair = labeler.get_current_pair(session_id_value)
        initial_id, initial_left, initial_right = pair[:3]
        # The fourth element exists only when the dataset has a URL column.
        url_values = [_url_to_html(pair[3])] if labeler.has_url_column else []
        return (
            session_id_value,                              # session_id state
            True,                                          # is_session_started state
            labeler.current_index[session_id_value],       # current_sample_sld
            *url_values,                                   # url_display (if present)
            initial_left,                                  # left_output
            initial_right,                                 # right_output
            initial_id,                                    # item_id
            *(gr.update(interactive=True) for _ in judgment_buttons),
        )

    # Connect the start button
    start_btn.click(
        init_session,
        inputs=[user_session_id],
        outputs=[
            session_id,
            is_session_started,
            current_sample_sld,
            *url_outputs,
            left_output,
            right_output,
            item_id,
            *judgment_buttons,
        ]
    )

    def judge(choice, session_id, item_id, left_text, right_text):
        """Record ``choice`` and return values for the judgment outputs.

        The result has exactly one value per component in ``judge_outputs``:
        item id, left text, right text, URL HTML (only when the dataset has a
        URL column) and the progress index.
        """
        if not session_id:
            # Session not initialized: leave the display untouched.
            untouched_url = [gr.update()] if labeler.has_url_column else []
            return (item_id, left_text, right_text, *untouched_url, 0)

        outcome = labeler.submit_judgment(
            item_id, left_text, right_text, choice, session_id
        )
        if labeler.has_url_column:
            new_id, new_left, new_right, new_url, new_index = outcome
            return new_id, new_left, new_right, _url_to_html(new_url), new_index
        # Without a URL column the labeler's tuple already matches the outputs.
        return outcome

    def judge_left(session_id, item_id, left_text, right_text):
        return judge("A is better", session_id, item_id, left_text, right_text)

    def judge_right(session_id, item_id, left_text, right_text):
        return judge("B is better", session_id, item_id, left_text, right_text)

    def judge_tie(session_id, item_id, left_text, right_text):
        return judge("Tie", session_id, item_id, left_text, right_text)

    def judge_cant_choose(session_id, item_id, left_text, right_text):
        return judge("Can't choose", session_id, item_id, left_text, right_text)

    # All four buttons share the same inputs and outputs and differ only in
    # the recorded choice.
    judge_inputs = [session_id, item_id, left_output, right_output]
    judge_outputs = [item_id, left_output, right_output, *url_outputs, current_sample_sld]
    for _button, _handler in (
        (left_btn, judge_left),
        (right_btn, judge_right),
        (tie_btn, judge_tie),
        (cant_choose_btn, judge_cant_choose),
    ):
        _button.click(_handler, inputs=judge_inputs, outputs=judge_outputs)
# Launch the Gradio app only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    app.launch()