import gradio as gr import numpy as np import spaces import torch import random import json import os from PIL import Image from diffusers import FluxKontextPipeline from diffusers.utils import load_image, peft_utils from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard from safetensors.torch import load_file import requests import re # Load the base model MAX_SEED = np.iinfo(np.int32).max pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda") try: # Temporary workaround for diffusers LoRA loading issue from diffusers.utils.peft_utils import _derive_exclude_modules def new_derive_exclude_modules(*args, **kwargs): exclude_modules = _derive_exclude_modules(*args, **kwargs) if exclude_modules is not None: exclude_modules = [n for n in exclude_modules if "proj_out" not in n] return exclude_modules peft_utils._derive_exclude_modules = new_derive_exclude_modules except: pass # Load LoRA configurations from JSON with open("lora_configs.json", "r") as file: data = json.load(file) lora_configs = [ { "image": item["image"], "title": item["title"], "repo": item["repo"], "trigger_word": item.get("trigger_word", ""), "trigger_position": item.get("trigger_position", "prepend"), "weights": item.get("weights", "pytorch_lora_weights.safetensors"), } for item in data ] print(f"Loaded {len(lora_configs)} LoRAs from JSON") # Global variables for adapter management active_lora_adapter = None lora_cache = {} def load_lora_weights(repo_id, weights_filename): """Load adapter weights from HuggingFace""" try: if repo_id not in lora_cache: lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename) lora_cache[repo_id] = lora_path return lora_cache[repo_id] except Exception as e: print(f"Error loading adapter from {repo_id}: {e}") return None def on_lora_select(selected_state: gr.SelectData, lora_configs): """Update UI when an adapter is selected""" if selected_state.index >= len(lora_configs): return "### No adapter selected", gr.update(), None lora_repo = lora_configs[selected_state.index]["repo"] trigger_word = lora_configs[selected_state.index]["trigger_word"] updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})" new_placeholder = f"optional description, e.g. 'a man with glasses and a beard'" return updated_text, gr.update(placeholder=new_placeholder), selected_state.index def fetch_lora_from_hf(link): """Retrieve adapter from HuggingFace link""" split_link = link.split("/") if len(split_link) == 2: try: model_card = ModelCard.load(link) trigger_word = model_card.data.get("instance_prompt", "") fs = HfFileSystem() list_of_files = fs.ls(link, detail=False) safetensors_file = None for file in list_of_files: if file.endswith(".safetensors") and "lora" in file.lower(): safetensors_file = file.split("/")[-1] break if not safetensors_file: safetensors_file = "pytorch_lora_weights.safetensors" return split_link[1], safetensors_file, trigger_word except Exception as e: raise Exception(f"Error loading adapter: {e}") else: raise Exception("Invalid HuggingFace repository format") def load_user_lora(link): """Load a user-provided adapter""" if not link: return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on an adapter in the gallery to select it", None try: repo_name, weights_file, trigger_word = fetch_lora_from_hf(link) card = f'''
"+trigger_word+"
as trigger word" if trigger_word else "No trigger word found"}
Edit images using custom style adapters. Fast generation with minimal steps.
""", ) selected_state = gr.State(value=None) user_lora = gr.State(value=None) with gr.Row(elem_id="app_container"): with gr.Column(scale=4, elem_id="left_panel"): with gr.Group(elem_id="lora_selection"): input_image = gr.Image(label="Upload a picture", type="pil", height=300) gallery = gr.Gallery( label="Pick an Adapter", allow_preview=False, columns=3, elem_id="lora_gallery", show_share_button=False, height=400 ) user_lora_input = gr.Textbox( label="Or enter a custom HuggingFace adapter", placeholder="e.g., username/adapter-name", visible=True ) user_lora_card = gr.HTML(visible=False) unload_user_lora_button = gr.Button("Remove custom adapter", visible=True) with gr.Column(scale=5): with gr.Row(): prompt = gr.Textbox( label="Editing Prompt", show_label=False, lines=1, max_lines=1, placeholder="optional description, e.g. 'colorize and stylize, leave all else as is'", elem_id="edit_prompt" ) run_button = gr.Button("Generate", elem_id="generate_button") result = gr.Image(label="Generated Image", interactive=False) reuse_button = gr.Button("Reuse this image", visible=False) with gr.Accordion("Advanced Settings", open=True): lora_scale = gr.Slider( label="Adapter Scale", minimum=0, maximum=2, step=0.1, value=1.5, info="Controls the strength of the adapter effect" ) seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, ) steps = gr.Slider( label="Steps", minimum=1, maximum=40, value=10, step=1 ) width = gr.Slider( label="Width", minimum=128, maximum=2560, step=1, value=960, ) height = gr.Slider( label="Height", minimum=128, maximum=2560, step=1, value=1280, ) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) guidance_scale = gr.Slider( label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=2.8, ) prompt_title = gr.Markdown( value="### Click on an adapter in the gallery to select it", visible=True, elem_id="lora_info", ) # Event handlers user_lora_input.input( fn=load_user_lora, inputs=[user_lora_input], outputs=[user_lora_card, user_lora_card, unload_user_lora_button, user_lora, gallery, prompt_title, selected_state], ) unload_user_lora_button.click( fn=unload_user_lora, outputs=[user_lora_input, unload_user_lora_button, user_lora_card, user_lora, selected_state] ) gallery.select( fn=on_lora_select, inputs=[gr_lora_configs], outputs=[prompt_title, prompt, selected_state], show_progress=False ) gr.on( triggers=[run_button.click, prompt.submit], fn=generate_image_wrapper, inputs=[input_image, prompt, selected_state, user_lora, seed, randomize_seed, steps, guidance_scale, lora_scale, width, height, gr_lora_configs], outputs=[result, seed, reuse_button] ) reuse_button.click( fn=lambda image: image, inputs=[result], outputs=[input_image] ) # Initialize the gallery demo.load( fn=sort_lora_gallery, inputs=[gr_lora_configs], outputs=[gallery, gr_lora_configs] ) demo.queue(default_concurrency_limit=None) demo.launch()