import gradio as gr
import numpy as np
import spaces
import torch
import random
import json
import os
from PIL import Image
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image, peft_utils
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
from safetensors.torch import load_file
import requests
import re

# Load the base model
MAX_SEED = np.iinfo(np.int32).max

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

try:
    # Temporary workaround for a diffusers LoRA loading issue: keep "proj_out"
    # modules out of the exclusion list so their LoRA weights still get applied.
    from diffusers.utils.peft_utils import _derive_exclude_modules

    def new_derive_exclude_modules(*args, **kwargs):
        exclude_modules = _derive_exclude_modules(*args, **kwargs)
        if exclude_modules is not None:
            exclude_modules = [n for n in exclude_modules if "proj_out" not in n]
        return exclude_modules

    peft_utils._derive_exclude_modules = new_derive_exclude_modules
except Exception:
    pass

# Load LoRA configurations from JSON
with open("lora_configs.json", "r") as file:
    data = json.load(file)
    lora_configs = [
        {
            "image": item["image"],
            "title": item["title"],
            "repo": item["repo"],
            "trigger_word": item.get("trigger_word", ""),
            "trigger_position": item.get("trigger_position", "prepend"),
            "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
        }
        for item in data
    ]

print(f"Loaded {len(lora_configs)} LoRAs from JSON")

# Global variables for adapter management
active_lora_adapter = None
lora_cache = {}


def load_lora_weights(repo_id, weights_filename):
    """Load adapter weights from HuggingFace, caching downloads per repo."""
    try:
        if repo_id not in lora_cache:
            lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
            lora_cache[repo_id] = lora_path
        return lora_cache[repo_id]
    except Exception as e:
        print(f"Error loading adapter from {repo_id}: {e}")
        return None


def on_lora_select(selected_state: gr.SelectData, lora_configs):
    """Update the UI when an adapter is selected in the gallery."""
    if selected_state.index >= len(lora_configs):
        return "### No adapter selected", gr.update(), None

    lora_repo = lora_configs[selected_state.index]["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
    new_placeholder = "optional description, e.g. 'a man with glasses and a beard'"
    return updated_text, gr.update(placeholder=new_placeholder), selected_state.index


def fetch_lora_from_hf(link):
    """Retrieve adapter metadata from a HuggingFace 'username/repo' link."""
    split_link = link.split("/")
    if len(split_link) != 2:
        raise Exception("Invalid HuggingFace repository format")
    try:
        model_card = ModelCard.load(link)
        trigger_word = model_card.data.get("instance_prompt", "")
        fs = HfFileSystem()
        list_of_files = fs.ls(link, detail=False)
        # Prefer a .safetensors file whose name mentions "lora"; otherwise fall
        # back to the conventional default filename.
        safetensors_file = None
        for file in list_of_files:
            if file.endswith(".safetensors") and "lora" in file.lower():
                safetensors_file = file.split("/")[-1]
                break
        if not safetensors_file:
            safetensors_file = "pytorch_lora_weights.safetensors"
        return split_link[1], safetensors_file, trigger_word
    except Exception as e:
        raise Exception(f"Error loading adapter: {e}")


def load_user_lora(link):
    """Load a user-provided adapter."""
    if not link:
        return (
            gr.update(visible=False),
            "",
            gr.update(visible=False),
            None,
            gr.Gallery(selected_index=None),
            "### Click on an adapter in the gallery to select it",
            None,
        )
    try:
        repo_name, weights_file, trigger_word = fetch_lora_from_hf(link)
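        # The 7-tuple returned from this handler must stay positionally aligned
        # with the outputs wired to `user_lora_input.input` below:
        # [user_lora_card, user_lora_card, unload_user_lora_button, user_lora,
        #  gallery, prompt_title, selected_state].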
        card = f'''
        <div class="user_lora_card">
            <span>Loaded custom adapter:</span>
            <h4>{repo_name}</h4>
            <small>{"Using: " + trigger_word + " as trigger word" if trigger_word else "No trigger word found"}</small>
        </div>
        '''
        user_lora_data = {
            "repo": link,
            "weights": weights_file,
            "trigger_word": trigger_word,
        }
        return (
            gr.update(visible=True),
            card,
            gr.update(visible=True),
            user_lora_data,
            gr.Gallery(selected_index=None),
            f"Custom: {repo_name}",
            None,
        )
    except Exception as e:
        return (
            gr.update(visible=True),
            f"Error: {str(e)}",
            gr.update(visible=False),
            None,
            gr.update(),
            "### Click on an adapter in the gallery to select it",
            None,
        )


def unload_user_lora():
    """Remove the user-provided adapter."""
    return "", gr.update(visible=False), gr.update(visible=False), None, None


def sort_lora_gallery(lora_configs):
    """Sort the adapter gallery by likes."""
    sorted_gallery = sorted(lora_configs, key=lambda x: x.get("likes", 0), reverse=True)
    return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery


def generate_image_wrapper(input_image, prompt, selected_index, user_lora, seed=42,
                           randomize_seed=False, steps=28, guidance_scale=2.5,
                           lora_scale=1.75, width=960, height=1280,
                           lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Wrapper for image generation to handle state."""
    return generate_image(input_image, prompt, selected_index, user_lora, seed,
                          randomize_seed, steps, guidance_scale, lora_scale,
                          width, height, lora_configs, progress)


@spaces.GPU
def generate_image(input_image, prompt, selected_index, user_lora, seed=42,
                   randomize_seed=False, steps=28, guidance_scale=2.5,
                   lora_scale=1.0, width=960, height=1280,
                   lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Generate an image using the selected adapter."""
    global active_lora_adapter, pipe

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Select the adapter to use: a custom user adapter takes priority over a
    # gallery selection.
    lora_to_use = None
    if user_lora:
        lora_to_use = user_lora
    elif selected_index is not None and lora_configs and selected_index < len(lora_configs):
        lora_to_use = lora_configs[selected_index]

    # Load the adapter if necessary
    if lora_to_use and lora_to_use != active_lora_adapter:
        try:
            if active_lora_adapter:
                pipe.unload_lora_weights()
            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
            if lora_path:
                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
                pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
                print(f"Loaded: {lora_path} with scale {lora_scale}")
            active_lora_adapter = lora_to_use
        except Exception as e:
            print(f"Error loading adapter: {e}")
    elif lora_to_use:
        print(f"Using already loaded adapter: {lora_to_use}")

    input_image = input_image.convert("RGB")

    # Modify the prompt based on the adapter's trigger word
    trigger_word = lora_to_use["trigger_word"] if lora_to_use else ""
    if trigger_word == ", How2Draw":
        prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
    elif trigger_word == "__ ":
        prompt = f" {prompt}. Accurately render the toolimpact logo and any tool impact iconography. The toolimpact logo begins with a two-line-tall drop-cap capital letter T with a dot in the center of its top bar."
    else:
        prompt = f" {prompt}. convert the style of this photo or image to {trigger_word}. Maintain the facial identity of any persons and the general features of the image!"
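
    # The Kontext pipeline below takes the reference image plus the edit prompt.
    # A freshly seeded torch.Generator keeps results reproducible for a given
    # seed. `max_area` is passed on the assumption that this diffusers version's
    # FluxKontextPipeline uses it to cap total output pixels while preserving
    # aspect ratio; drop the argument if your version does not accept it.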
    try:
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
            width=width,
            height=height,
            max_area=width * height,
        ).images[0]
        return image, seed, gr.update(visible=True)
    except Exception as e:
        print(f"Error during generation: {e}")
        return None, seed, gr.update(visible=False)


# CSS styling
css = """
#app_container { display: flex; gap: 20px; }
#left_panel { min-width: 400px; }
#lora_info { color: #2563eb; font-weight: bold; }
#edit_prompt { flex-grow: 1; }
#generate_button {
    background: linear-gradient(45deg, #2563eb, #3b82f6);
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 6px;
    font-weight: bold;
}
.user_lora_card {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    padding: 12px;
    margin: 8px 0;
}
#lora_gallery { overflow: scroll !important }
"""

# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as demo:
    gr_lora_configs = gr.State(value=lora_configs)
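    # gr.State keeps a per-session copy of the adapter list, so the sorting done
    # in sort_lora_gallery affects only the current user's session rather than
    # the module-level lora_configs list.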
    title = gr.HTML(
        """
        <h1>Flux Kontext DLC😍</h1>
        """,
    )

    selected_state = gr.State(value=None)
    user_lora = gr.State(value=None)

    with gr.Row(elem_id="app_container"):
        with gr.Column(scale=4, elem_id="left_panel"):
            with gr.Group(elem_id="lora_selection"):
                input_image = gr.Image(label="Upload a picture", type="pil", height=300)
                gallery = gr.Gallery(
                    label="Pick an Adapter",
                    allow_preview=False,
                    columns=3,
                    elem_id="lora_gallery",
                    show_share_button=False,
                    height=400,
                )
                user_lora_input = gr.Textbox(
                    label="Or enter a custom HuggingFace adapter",
                    placeholder="e.g., username/adapter-name",
                    visible=True,
                )
                user_lora_card = gr.HTML(visible=False)
                unload_user_lora_button = gr.Button("Remove custom adapter", visible=True)

        with gr.Column(scale=5):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Editing Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="optional description, e.g. 'colorize and stylize, leave all else as is'",
                    elem_id="edit_prompt",
                )
                run_button = gr.Button("Generate", elem_id="generate_button")
            result = gr.Image(label="Generated Image", interactive=False)
            reuse_button = gr.Button("Reuse this image", visible=False)

            with gr.Accordion("Advanced Settings", open=True):
                lora_scale = gr.Slider(
                    label="Adapter Scale",
                    minimum=0,
                    maximum=2,
                    step=0.1,
                    value=1.5,
                    info="Controls the strength of the adapter effect",
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=40,
                    value=10,
                    step=1,
                )
                width = gr.Slider(
                    label="Width",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=960,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=1280,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=2.8,
                )

            prompt_title = gr.Markdown(
                value="### Click on an adapter in the gallery to select it",
                visible=True,
                elem_id="lora_info",
            )

    # Event handlers
    user_lora_input.input(
        fn=load_user_lora,
        inputs=[user_lora_input],
        outputs=[user_lora_card, user_lora_card, unload_user_lora_button,
                 user_lora, gallery, prompt_title, selected_state],
    )

    unload_user_lora_button.click(
        fn=unload_user_lora,
        outputs=[user_lora_input, unload_user_lora_button, user_lora_card,
                 user_lora, selected_state],
    )

    gallery.select(
        fn=on_lora_select,
        inputs=[gr_lora_configs],
        outputs=[prompt_title, prompt, selected_state],
        show_progress=False,
    )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate_image_wrapper,
        inputs=[input_image, prompt, selected_state, user_lora, seed,
                randomize_seed, steps, guidance_scale, lora_scale, width,
                height, gr_lora_configs],
        outputs=[result, seed, reuse_button],
    )

    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image],
    )

    # Initialize the gallery
    demo.load(
        fn=sort_lora_gallery,
        inputs=[gr_lora_configs],
        outputs=[gallery, gr_lora_configs],
    )

demo.queue(default_concurrency_limit=None)
demo.launch()