import torch import gradio as gr from diffusers import FluxPipeline, FluxTransformer2DModel import gc import random from PIL import Image import os import time import spaces DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {DEVICE}") DEFAULT_HEIGHT = 1024 DEFAULT_WIDTH = 1024 DEFAULT_GUIDANCE_SCALE = 3.5 DEFAULT_NUM_INFERENCE_STEPS = 15 DEFAULT_MAX_SEQUENCE_LENGTH = 512 GENERATION_SEED = 0 # could use a random number generator to set this, for more variety HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN") CACHED_PIPES = {} def load_bf16_pipeline(): """Loads the original FLUX.1-dev pipeline in BF16 precision.""" print("Loading BF16 pipeline...") MODEL_ID = "black-forest-labs/FLUX.1-dev" if MODEL_ID in CACHED_PIPES: return CACHED_PIPES[MODEL_ID] start_time = time.time() try: pipe = FluxPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, token=HF_TOKEN ) pipe.to(DEVICE) # pipe.enable_model_cpu_offload() end_time = time.time() mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0 print(f"BF16 Pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB") CACHED_PIPES[MODEL_ID] = pipe return pipe except Exception as e: print(f"Error loading BF16 pipeline: {e}") raise # Re-raise exception to be caught in generate_images def load_bnb_8bit_pipeline(): """Loads the FLUX.1-dev pipeline with 8-bit quantized components.""" print("Loading 8-bit BNB pipeline...") MODEL_ID = "derekl35/FLUX.1-dev-bnb-8bit" if MODEL_ID in CACHED_PIPES: return CACHED_PIPES[MODEL_ID] start_time = time.time() try: pipe = FluxPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16 ) pipe.to(DEVICE) # pipe.enable_model_cpu_offload() end_time = time.time() mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0 print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB") CACHED_PIPES[MODEL_ID] = pipe return pipe except Exception as e: print(f"Error loading 8-bit BNB pipeline: {e}") raise def load_bnb_4bit_pipeline(): """Loads the FLUX.1-dev pipeline with 4-bit quantized components.""" print("Loading 4-bit BNB pipeline...") MODEL_ID = "derekl35/FLUX.1-dev-nf4" if MODEL_ID in CACHED_PIPES: return CACHED_PIPES[MODEL_ID] start_time = time.time() try: pipe = FluxPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16 ) pipe.to(DEVICE) # pipe.enable_model_cpu_offload() end_time = time.time() mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0 print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB") CACHED_PIPES[MODEL_ID] = pipe return pipe except Exception as e: print(f"4-bit BNB pipeline: {e}") raise @spaces.GPU(duration=240) def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)): """Loads original and selected quantized model, generates one image each, shuffles results.""" if not prompt: return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None) if not quantization_choice: # Return updates for all outputs to clear them or show warning return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None) # Determine which quantized model to load if quantization_choice == "8-bit": quantized_load_func = load_bnb_8bit_pipeline quantized_label = "Quantized (8-bit)" elif quantization_choice == "4-bit": quantized_load_func = load_bnb_4bit_pipeline quantized_label = "Quantized (4-bit)" else: # Should not happen with Radio choices, but good practice return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None) model_configs = [ ("Original", load_bf16_pipeline), (quantized_label, quantized_load_func), # Use the specific label here ] results = [] pipe_kwargs = { "prompt": prompt, "height": DEFAULT_HEIGHT, "width": DEFAULT_WIDTH, "guidance_scale": DEFAULT_GUIDANCE_SCALE, "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS, "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH, } current_pipe = None # Keep track of the current pipe for cleanup seed = random.getrandbits(64) print(f"Using seed: {seed}") for i, (label, load_func) in enumerate(model_configs): progress(i / len(model_configs), desc=f"Loading {label} model...") print(f"\n--- Loading {label} Model ---") load_start_time = time.time() try: current_pipe = load_func() load_end_time = time.time() print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.") progress((i + 0.5) / len(model_configs), desc=f"Generating with {label} model...") print(f"--- Generating with {label} Model ---") gen_start_time = time.time() image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(seed)).images image = image_list[0] # image.save(f"{load_start_time}.png") gen_end_time = time.time() results.append({"label": label, "image": image}) print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---") mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0 print(f"Memory reserved: {mem_reserved:.2f} GB") except Exception as e: print(f"Error during {label} model processing: {e}") # Return error state to Gradio - update all outputs return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None) # No finally block needed here, cleanup happens before next load or after loop if len(results) != len(model_configs): print("Generation did not complete for all models.") # Update all outputs return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None) # Shuffle the results for display shuffled_results = results.copy() random.shuffle(shuffled_results) # Create the gallery data: [(image, caption), (image, caption)] shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)] # Create the mapping: display_index -> correct_label (e.g., {0: 'Original', 1: 'Quantized (8-bit)'}) correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)} print("Correct mapping (hidden):", correct_mapping) guess_radio_update = gr.update(choices=["Image 1", "Image 2"], value=None, interactive=True) # Return shuffled images, the correct mapping state, status message, and update the guess radio return shuffled_data_for_gallery, correct_mapping, gr.update(value="Generation complete! Make your guess.", interactive=False), guess_radio_update # --- Guess Verification Function --- def check_guess(user_guess, correct_mapping_state): """Compares the user's guess with the correct mapping stored in the state.""" if not isinstance(correct_mapping_state, dict) or not correct_mapping_state: return "Please generate images first (state is empty or invalid)." if user_guess is None: return "Please select which image you think is quantized." # Find which display index (0 or 1) corresponds to the quantized image quantized_image_index = -1 quantized_label_actual = "" for index, label in correct_mapping_state.items(): if "Quantized" in label: # Check if the label indicates quantization quantized_image_index = index quantized_label_actual = label # Store the full label e.g. "Quantized (8-bit)" break if quantized_image_index == -1: # This shouldn't happen if generation was successful return "Error: Could not find the quantized image in the mapping data." # Determine what the user *should* have selected based on the index correct_guess_label = f"Image {quantized_image_index + 1}" # "Image 1" or "Image 2" if user_guess == correct_guess_label: feedback = f"Correct! {correct_guess_label} used the {quantized_label_actual} model." else: feedback = f"Incorrect. The quantized image ({quantized_label_actual}) was {correct_guess_label}." return feedback with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo: gr.Markdown("# FLUX Model Quantization Challenge") gr.Markdown( "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). " "Enter a prompt, choose the quantization method, and generate two images. " "The images will be shuffled, can you spot which one was quantized?" ) with gr.Row(): prompt_input = gr.Textbox(label="Enter Prompt", scale=3) quantization_choice_radio = gr.Radio( choices=["8-bit", "4-bit"], label="Select Quantization", value="8-bit", # Default choice scale=1 ) generate_button = gr.Button("Generate & Compare", variant="primary", scale=1) output_gallery = gr.Gallery( label="Generated Images", columns=2, height=512, object_fit="contain", allow_preview=True, show_label=True, # Shows "Image 1", "Image 2" captions we provide ) gr.Markdown("### Which image used the selected quantization method?") with gr.Row(): image1_btn = gr.Button("Image 1") image2_btn = gr.Button("Image 2") feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1) # Hidden state to store the correct mapping after shuffling # e.g., {0: 'Original', 1: 'Quantized (8-bit)'} or {0: 'Quantized (4-bit)', 1: 'Original'} correct_mapping_state = gr.State({}) generate_button.click( fn=generate_images, inputs=[prompt_input, quantization_choice_radio], outputs=[output_gallery, correct_mapping_state], ).then(lambda: "", outputs=[feedback_box]) # clear feedback on new run # helper wrappers so we can supply the fixed choice string def choose_img1(mapping): return check_guess("Image 1", mapping) def choose_img2(mapping): return check_guess("Image 2", mapping) image1_btn.click(choose_img1, inputs=[correct_mapping_state], outputs=[feedback_box]) image2_btn.click(choose_img2, inputs=[correct_mapping_state], outputs=[feedback_box]) if __name__ == "__main__": demo.launch(share=True) demo.launch()