Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import os | |
| import uuid | |
| import datetime | |
| import logging | |
| from huggingface_hub import hf_hub_download, upload_file, list_repo_tree | |
| from dotenv import load_dotenv | |
| from collections import defaultdict | |
| load_dotenv() | |
# Configuration
# All dataset settings are read from environment variables (populated from a
# `.env` file by `load_dotenv()`). `os.getenv` returns None for any variable
# that is not set.

# Input dataset: Hub repo id and the file inside it that holds the pairs.
HF_INPUT_DATASET = os.getenv("HF_INPUT_DATASET")
HF_INPUT_DATASET_PATH = os.getenv("HF_INPUT_DATASET_PATH")
# Column names inside the input file: item identifier, output A, output B.
HF_INPUT_DATASET_ID_COLUMN = os.getenv("HF_INPUT_DATASET_ID_COLUMN")
HF_INPUT_DATASET_COLUMN_A = os.getenv("HF_INPUT_DATASET_COLUMN_A")
HF_INPUT_DATASET_COLUMN_B = os.getenv("HF_INPUT_DATASET_COLUMN_B")
# Optional column with a reference URL; the URL widget is only shown when this
# is set and the column exists in the loaded data.
HF_INPUT_DATASET_URL_COLUMN = os.getenv("HF_INPUT_DATASET_URL_COLUMN")
# Output dataset: Hub repo id and directory that result files are uploaded to.
HF_OUTPUT_DATASET = os.getenv("HF_OUTPUT_DATASET")
HF_OUTPUT_DATASET_DIR = os.getenv("HF_OUTPUT_DATASET_DIR")
# Markdown rendered at the top of the app.
INSTRUCTIONS = """
# Pairwise Model Output Labeling Space
### How-to:
* Duplicate this Space
* Add a `HF_TOKEN` secret in the Space settings (it saves the results to a given dataset, so it requires write access)
* Set the environment variables in the [`.env`](https://huggingface.co/spaces/saridormi/labeling-template/blob/main/.env) file to point to correct datasets/columns/paths for your input and output data
* Adjust `INSTRUCTIONS` (it's what's rendered here) and `SAVE_EVERY_N_EXAMPLES` (how often should answers be saved to the output dataset) in the [`app.py`](https://huggingface.co/spaces/saridormi/labeling-template/blob/main/app.py) file
Done! 💛
### About:
Please compare the two model outputs shown below and select which one you think is better.
- Choose "A is better" if the output from Model A (left) is superior
- Choose "B is better" if the output from Model B (right) is superior
- Choose "Tie" if they are equally good or bad
- Choose "Can't choose" if you cannot make a determination
"""
# Number of pending judgments per session that triggers an upload; also the
# number of rows generated for the fallback sample dataset.
SAVE_EVERY_N_EXAMPLES = 5
class PairwiseLabeler:
    """Serve A/B output pairs to labeling sessions and persist their judgments.

    The input dataset is loaded once. Each session id gets its own shuffled
    copy of the data, its own position in that copy and its own list of
    judgments that have not been uploaded yet.
    """

    def __init__(self):
        # session id -> index of the next item to show (0 for new sessions)
        self.current_index = defaultdict(int)
        # session id -> judgments recorded but not yet uploaded
        self.results = defaultdict(list)
        # Source data shared by all sessions.
        self._df = self.read_hf_dataset()
        # session id -> shuffled copy of the source data
        self.df = {}
        # Always a real bool, even when the env variable is unset or empty.
        self.has_url_column = (
            bool(HF_INPUT_DATASET_URL_COLUMN)
            and HF_INPUT_DATASET_URL_COLUMN in self._df.columns
        )

    def __len__(self):
        """Return the number of items in the source dataset."""
        return len(self._df)

    def read_hf_dataset(self) -> pd.DataFrame:
        """Download the input file from the Hub and load it into a DataFrame.

        Supported extensions are .json, .jsonl, .csv and .parquet. If the
        download or parsing fails for any reason (including an unsupported
        extension), the error is logged with its traceback and a small
        synthetic dataset with SAVE_EVERY_N_EXAMPLES rows is returned instead.
        """
        try:
            local_file = hf_hub_download(
                repo_id=HF_INPUT_DATASET,
                repo_type="dataset",
                filename=HF_INPUT_DATASET_PATH,
            )
            if local_file.endswith(".json"):
                return pd.read_json(local_file)
            if local_file.endswith(".jsonl"):
                return pd.read_json(local_file, orient="records", lines=True)
            if local_file.endswith(".csv"):
                return pd.read_csv(local_file)
            if local_file.endswith(".parquet"):
                return pd.read_parquet(local_file)
            raise ValueError(f"Unsupported file type: {local_file}")
        except Exception:
            # Deliberate best-effort fallback so the app still starts; the
            # traceback is logged so the real cause is not lost.
            logging.exception(
                f"Couldn't read HF dataset from {HF_INPUT_DATASET_PATH}. Using sample data instead."
            )
            sample_rows = range(SAVE_EVERY_N_EXAMPLES)
            # NOTE(review): if the column-name env variables are unset they are
            # all None and these keys collapse into one column - confirm the
            # .env file is always populated.
            sample_data = {
                HF_INPUT_DATASET_ID_COLUMN: [f"sample_{i}" for i in sample_rows],
                HF_INPUT_DATASET_COLUMN_A: [f"This is sample generation A {i}" for i in sample_rows],
                HF_INPUT_DATASET_COLUMN_B: [f"This is sample generation B {i}" for i in sample_rows],
            }
            # Add URL column to sample data if specified
            if HF_INPUT_DATASET_URL_COLUMN:
                sample_data[HF_INPUT_DATASET_URL_COLUMN] = [
                    f"https://example.com/sample_{i}" for i in sample_rows
                ]
            return pd.DataFrame(sample_data)

    def get_current_pair(self, session_id):
        """Return the item the given session should label next.

        Returns ``(item_id, left_text, right_text)``, or
        ``(item_id, left_text, right_text, url)`` when the URL column is
        available. Every element is None once the session has gone through
        all items. A session seen for the first time gets a freshly shuffled
        copy of the dataset.
        """
        if session_id not in self.df:
            self.df[session_id] = self._df.sample(frac=1).reset_index(drop=True)
        position = self.current_index[session_id]
        if position >= len(self.df[session_id]):
            if self.has_url_column:
                return None, None, None, None
            return None, None, None
        item = self.df[session_id].iloc[position]
        item_id = item.get(HF_INPUT_DATASET_ID_COLUMN, f"item_{position}")
        left_text = item.get(HF_INPUT_DATASET_COLUMN_A, "")
        right_text = item.get(HF_INPUT_DATASET_COLUMN_B, "")
        if self.has_url_column:
            url = item.get(HF_INPUT_DATASET_URL_COLUMN, "")
            return item_id, left_text, right_text, url
        return item_id, left_text, right_text

    def submit_judgment(self, item_id, left_text, right_text, choice, session_id):
        """Record ``choice`` for ``item_id`` and advance the session.

        Returns the next pair (same shape as ``get_current_pair``) followed by
        the session's new index. When there is nothing to judge - ``item_id``
        is None or the session has already gone through every item - nothing
        is recorded and the passed-in values are echoed back (with None for
        the URL when the URL column is available).

        Pending judgments are uploaded every SAVE_EVERY_N_EXAMPLES judgments
        and also when the session reaches the end of the dataset, so the last
        partial batch is not left in memory only.
        """
        total = len(self._df)
        if item_id is None or self.current_index[session_id] >= total:
            # Guard against clicks after completion: without the index check a
            # stale or empty id would be recorded and the index would keep
            # growing past the end of the data.
            if self.has_url_column:
                return item_id, left_text, right_text, None, self.current_index[session_id]
            return item_id, left_text, right_text, self.current_index[session_id]
        # Record the judgment
        self.results[session_id].append(
            {
                "item_id": item_id,
                "judgment": choice,
                "timestamp": datetime.datetime.now().isoformat(),
                "labeler_id": session_id,
            }
        )
        # Move to next item
        self.current_index[session_id] += 1
        # Save periodically, and always once the last item has been judged.
        finished = self.current_index[session_id] >= total
        if finished or len(self.results[session_id]) % SAVE_EVERY_N_EXAMPLES == 0:
            self.save_results(session_id)
        # Get next pair and append the new position.
        next_pair = self.get_current_pair(session_id)
        return (*next_pair, self.current_index[session_id])

    def save_results(self, session_id):
        """Upload the session's pending judgments as a JSON-lines file.

        The file is named ``results_<session_id>_<n>.jsonl`` where ``n`` is one
        more than the number of files already present for this session in the
        output directory. On success the pending list is cleared; on any
        failure the error is logged and the judgments stay in memory so a
        later call can retry.
        """
        pending = self.results[session_id]
        if not pending:
            return
        try:
            # Serialize in memory: a shared on-disk temp file could be
            # overwritten by another session saving at the same time and would
            # be left behind if the upload failed.
            payload = pd.DataFrame(pending).to_json(orient="records", lines=True).encode("utf-8")
            # Use session_id in filename to avoid conflicts
            prefix = f"results_{session_id}_"
            try:
                # Match on the file-name prefix rather than a substring so
                # that session "ab" does not count files of session "abc".
                num_files = sum(
                    1
                    for entry in list_repo_tree(
                        repo_id=HF_OUTPUT_DATASET,
                        repo_type="dataset",
                        path_in_repo=HF_OUTPUT_DATASET_DIR,
                    )
                    if os.path.basename(entry.path).startswith(prefix)
                )
            except Exception:
                # Listing failed (e.g. the directory does not exist yet).
                # NOTE(review): a transient listing error would also land here
                # and restart numbering at 1 - confirm that is acceptable.
                num_files = 0
            filename = f"{prefix}{num_files + 1}.jsonl"
            # Push to Hugging Face Hub
            upload_file(
                repo_id=HF_OUTPUT_DATASET,
                repo_type="dataset",
                path_in_repo=os.path.join(HF_OUTPUT_DATASET_DIR, filename),
                path_or_fileobj=payload,
            )
            # Clear saved results
            self.results[session_id] = []
            logging.info(f"Saved results for session {session_id} to {HF_OUTPUT_DATASET}/{filename}")
        except Exception as e:
            # Keep results in memory to try saving again later
            logging.error(f"Error saving results: {e}")
# Initialize the labeler
# Module-level singleton shared by every event handler below; constructing it
# loads the input dataset (or the sample fallback).
labeler = PairwiseLabeler()
def create_default_session():
    """Return a random session id: the first 8 hex characters of a UUID4."""
    return uuid.uuid4().hex[:8]
# Throwaway session used only to pre-fill the text boxes before the user
# starts a real session.
demo_session = create_default_session()
_demo_pair = labeler.get_current_pair(demo_session)
if labeler.has_url_column:
    # Four-element result: the pair plus its reference URL.
    demo_item_id, demo_left_text, demo_right_text, demo_url = _demo_pair
else:
    demo_item_id, demo_left_text, demo_right_text = _demo_pair
    demo_url = None
with gr.Blocks(css="footer {visibility: hidden}") as app:
    # State for the session ID; stays None until "Start Session" is clicked.
    session_id = gr.State(value=None)
    is_session_started = gr.State(value=False)

    # Header with instructions
    gr.Markdown(INSTRUCTIONS)

    # Session id entry and start button
    user_session_id = gr.Textbox(
        placeholder="Enter your unique ID or leave blank for random",
        label="Session ID",
        scale=2
    )
    start_btn = gr.Button("Start Session", variant="primary", scale=1)

    # URL display component - only created if the URL column is available
    if labeler.has_url_column:
        url_display = gr.HTML(label="Reference URL", value="", visible=True)

    # Main content area (pre-filled with the demo pair)
    with gr.Row():
        with gr.Column():
            left_output = gr.Textbox(
                value=demo_left_text,
                label="Model A Output",
                lines=10,
                interactive=False
            )
        with gr.Column():
            right_output = gr.Textbox(
                value=demo_right_text,
                label="Model B Output",
                lines=10,
                interactive=False
            )
    # Hidden field that carries the id of the item currently on screen.
    item_id = gr.Textbox(value=demo_item_id, visible=False)

    # Buttons row (disabled until a session is started)
    with gr.Row():
        left_btn = gr.Button("⬅️ A is better", variant="primary", interactive=False)
        right_btn = gr.Button("➡️ B is better", variant="primary", interactive=False)
        tie_btn = gr.Button("🤝 Tie", variant="primary", interactive=False)
        cant_choose_btn = gr.Button("🤔 Can't choose", interactive=False)

    # Progress slider (read-only)
    current_sample_sld = gr.Slider(
        minimum=0,
        maximum=len(labeler),
        step=1,
        value=0,
        interactive=False,
        label='Progress',
        show_label=False
    )

    judgment_buttons = [left_btn, right_btn, tie_btn, cant_choose_btn]

    def _render_url_link(url):
        """Return an HTML link for ``url``, or "" when there is no URL.

        The URL comes from the dataset, so it is HTML-escaped before being
        placed in the attribute and in the link text.
        """
        import html

        if not url:
            return ""
        safe_url = html.escape(str(url), quote=True)
        return f'<a href="{safe_url}" target="_blank" rel="noopener noreferrer">{safe_url}</a>'

    # Initialize the session and get the first pair
    def init_session(entered_id):
        """Start (or resume) a session and return values for the start outputs.

        Output order: session id, started flag, slider position, URL html
        (only when the URL column is available), left text, right text,
        item id, then one "enable" update per judgment button.
        """
        # Use entered ID or generate a new one if empty
        session_id_value = (entered_id or "").strip() or create_default_session()
        pair = labeler.get_current_pair(session_id_value)
        initial_id, initial_left, initial_right = pair[:3]
        shown = [initial_left, initial_right, initial_id]
        if labeler.has_url_column:
            shown.insert(0, _render_url_link(pair[3]))
        return (
            session_id_value,
            True,
            labeler.current_index[session_id_value],
            *shown,
            *(gr.update(interactive=True) for _ in judgment_buttons),
        )

    # Connect the start button
    start_outputs = [session_id, is_session_started, current_sample_sld]
    if labeler.has_url_column:
        start_outputs.append(url_display)
    start_outputs += [left_output, right_output, item_id, *judgment_buttons]
    start_btn.click(init_session, inputs=[user_session_id], outputs=start_outputs)

    def judge(choice, session_id, item_id, left_text, right_text):
        """Record ``choice`` and return values for the judgment outputs.

        Output order: item id, left text, right text, URL html (only when the
        URL column is available), slider position. The number of returned
        values always matches the outputs wired to the buttons.
        """
        if not session_id:
            # Session not initialized: leave everything as it is.
            if labeler.has_url_column:
                return item_id, left_text, right_text, gr.update(), 0
            return item_id, left_text, right_text, 0
        outcome = labeler.submit_judgment(item_id, left_text, right_text, choice, session_id)
        if labeler.has_url_column:
            new_id, new_left, new_right, new_url, new_index = outcome
            return new_id, new_left, new_right, _render_url_link(new_url), new_index
        new_id, new_left, new_right, new_index = outcome
        return new_id, new_left, new_right, new_index

    def judge_left(session_id, item_id, left_text, right_text):
        return judge("A is better", session_id, item_id, left_text, right_text)

    def judge_right(session_id, item_id, left_text, right_text):
        return judge("B is better", session_id, item_id, left_text, right_text)

    def judge_tie(session_id, item_id, left_text, right_text):
        return judge("Tie", session_id, item_id, left_text, right_text)

    def judge_cant_choose(session_id, item_id, left_text, right_text):
        return judge("Can't choose", session_id, item_id, left_text, right_text)

    # Wire every judgment button to its handler with the same inputs/outputs.
    judge_outputs = [item_id, left_output, right_output]
    if labeler.has_url_column:
        judge_outputs.append(url_display)
    judge_outputs.append(current_sample_sld)
    judge_handlers = [judge_left, judge_right, judge_tie, judge_cant_choose]
    for button, handler in zip(judgment_buttons, judge_handlers):
        button.click(
            handler,
            inputs=[session_id, item_id, left_output, right_output],
            outputs=judge_outputs
        )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()