# Source: Aatricks — uploaded via huggingface_hub (commit 1264e6e, verified)
import functools
import glob
import os
import queue
import random
import sys
import threading
import time
import tkinter as tk
from tkinter import filedialog
from typing import Union

import customtkinter as ctk
import numpy as np
import torch
from PIL import Image, ImageTk
# Add the directory containing LightDiffusion.py to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from modules.AutoDetailer import SAM, ADetailer, bbox, SEGS
from modules.AutoEncoders import VariationalAE
from modules.clip import Clip
from modules.sample import sampling
from modules.Utilities import util
from modules.UltimateSDUpscale import USDU_upscaler, UltimateSDUpscale
from modules.FileManaging import Downloader, ImageSaver, Loader
from modules.Model import LoRas
from modules.Utilities import Enhancer, Latent, upscale
from modules.Quantize import Quantizer
from modules.WaveSpeed import fbcache_nodes
from modules.hidiffusion import msw_msa_attention
from modules.AutoHDR import ahdr
Downloader.CheckAndDownload()
files = glob.glob("./_internal/checkpoints/*.safetensors")
loras = glob.glob("./_internal/loras/*.safetensors")
loras += glob.glob("./_internal/loras/*.pt")
def debounce(wait):
    """Decorator factory that rate-limits calls to the wrapped function.

    Despite the name, this is a leading-edge throttle rather than a true
    debounce: the wrapped function runs immediately, and any call made less
    than ``wait`` seconds after the last *executed* call is dropped.

    Args:
        wait (float): Minimum number of seconds between executed calls.

    Returns:
        Callable: The decorator to apply to the target function.
    """

    def decorator(fn):
        # Mutable cell so the closure can update the timestamp of the
        # last executed call.
        last_call = [0]

        @functools.wraps(fn)  # keep fn's name/docstring for debugging
        def debounced(*args, **kwargs):
            now = time.time()
            if now - last_call[0] >= wait:
                fn(*args, **kwargs)
                last_call[0] = now

        return debounced

    return decorator
class App(tk.Tk):
"""Main application class for the LightDiffusion GUI."""
    def __init__(self):
        """Build the main LightDiffusion window and all of its widgets.

        Lays out a two-column UI (control sidebar + image display area),
        restores the last-used parameters from disk, wires the persistence
        and resize handlers, and starts checkpoint loading on a background
        thread so the window stays responsive.
        """
        super().__init__()
        self.title("LightDiffusion")
        self.geometry("900x750")
        # Configure main window grid: column 1 (the display) absorbs
        # any extra space.
        self.grid_columnconfigure(1, weight=1)
        self.grid_rowconfigure(0, weight=1)
        # Bare file names for the dropdowns (the module-level globs hold
        # the full paths).
        file_names = [os.path.basename(file) for file in files]
        lora_names = [os.path.basename(lora) for lora in loras]
        selected_file = tk.StringVar()
        selected_lora = tk.StringVar()
        if file_names:
            selected_file.set(file_names[0])
        if lora_names:
            selected_lora.set(lora_names[0])
        # Create main sidebar frame with padding and grid
        self.sidebar = tk.Frame(self, bg="#FBFBFB", padx=10, pady=10)
        self.sidebar.grid(row=0, column=0, sticky="nsew")
        self.sidebar.grid_columnconfigure(0, weight=1)
        # Configure sidebar grid rows so the sections share vertical space.
        for i in range(8):
            self.sidebar.grid_rowconfigure(i, weight=1)
        # Text input frames with expansion
        self.prompt_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.prompt_frame.grid(row=0, column=0, sticky="nsew", pady=(0, 5))
        self.prompt_frame.grid_columnconfigure(0, weight=1)
        self.prompt_frame.grid_rowconfigure(0, weight=2)
        self.prompt_frame.grid_rowconfigure(1, weight=1)
        # Positive-prompt textbox with expansion
        self.prompt_entry = ctk.CTkTextbox(
            self.prompt_frame,
            height=150,
            fg_color="#E8F9FF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.prompt_entry.grid(row=0, column=0, sticky="nsew")
        # Negative prompt textbox with expansion
        self.neg = ctk.CTkTextbox(
            self.prompt_frame,
            height=75,
            fg_color="#E8F9FF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.neg.grid(row=1, column=0, sticky="nsew", pady=(5, 0))
        # Model dropdown values; "flux" is always offered as a pseudo-model
        # that switches to the quantized Flux pipeline.
        model_values = (file_names if file_names else ["No models found"]) + ["flux"]
        # Model dropdown and Flux checkbox
        self.dropdown = ctk.CTkOptionMenu(
            self.sidebar,
            values=model_values,
            fg_color="#F5EFFF",
            text_color="black",
            command=self.on_model_selected,
        )
        self.dropdown.grid(row=2, column=0, sticky="ew")
        # LoRA selection
        self.lora_selection = ctk.CTkOptionMenu(
            self.sidebar, values=lora_names, fg_color="#F5EFFF", text_color="black"
        )
        self.lora_selection.grid(row=3, column=0, sticky="ew", pady=5)
        # Display frame (right column) with expansion
        self.display = tk.Frame(self, bg="#FBFBFB")
        self.display.grid(row=0, column=1, sticky="nsew", padx=10, pady=10)
        self.display.grid_columnconfigure(0, weight=1)
        # Most recently displayed PIL image (None until something is shown).
        self.img = None
        # Add row configuration for both image and checkbox
        self.display.grid_rowconfigure(0, weight=1)  # For image
        self.display.grid_rowconfigure(1, weight=0)  # For checkbox
        # Image label with expansion
        self.image_label = tk.Label(self.display, bg="#FBFBFB")
        self.image_label.grid(row=0, column=0, sticky="nsew")
        # Previewer checkbox - changed from pack to grid
        self.previewer_var = tk.BooleanVar()
        self.previewer_checkbox = ctk.CTkCheckBox(
            self.display,
            text="Previewer",
            variable=self.previewer_var,
            command=self.print_previewer,
            text_color="black",
        )
        self.previewer_checkbox.grid(row=1, column=0, pady=10)
        # Progress Bar
        self.progress = ctk.CTkProgressBar(self.display, fg_color="#FBFBFB")
        self.progress.grid(row=2, column=0, sticky="ew", pady=10, padx=10)
        self.progress.set(0)
        # Make sliders frame expand
        self.sliders_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.sliders_frame.grid(row=4, column=0, sticky="nsew", pady=5)
        self.sliders_frame.grid_columnconfigure(1, weight=1)
        # Configure slider weights
        for i in range(3):
            self.sliders_frame.grid_rowconfigure(i, weight=1)
        # Make checkbox frame expand
        self.checkbox_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.checkbox_frame.grid(row=5, column=0, sticky="nsew", pady=10)
        self.checkbox_frame.grid_columnconfigure(0, weight=1)
        self.checkbox_frame.grid_columnconfigure(1, weight=1)
        # Make button frame expand
        self.button_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.button_frame.grid(row=7, column=0, sticky="nsew", pady=10)
        self.button_frame.grid_columnconfigure(0, weight=1)
        self.button_frame.grid_columnconfigure(1, weight=1)
        # Width slider
        tk.Label(self.sliders_frame, text="Width:", bg="#FBFBFB").grid(
            row=0, column=0, padx=(0, 5)
        )
        self.width_slider = ctk.CTkSlider(
            self.sliders_frame, from_=1, to=2048, number_of_steps=32, fg_color="#F5EFFF"
        )
        self.width_slider.grid(row=0, column=1, sticky="ew")
        self.width_label = ctk.CTkLabel(self.sliders_frame, text="")
        self.width_label.grid(row=0, column=2, padx=(5, 0))
        # Height slider
        tk.Label(self.sliders_frame, text="Height:", bg="#FBFBFB").grid(
            row=1, column=0, padx=(0, 5)
        )
        self.height_slider = ctk.CTkSlider(
            self.sliders_frame, from_=1, to=2048, number_of_steps=32, fg_color="#F5EFFF"
        )
        self.height_slider.grid(row=1, column=1, sticky="ew")
        self.height_label = ctk.CTkLabel(self.sliders_frame, text="")
        self.height_label.grid(row=1, column=2, padx=(5, 0))
        # CFG slider
        tk.Label(self.sliders_frame, text="CFG:", bg="#FBFBFB").grid(
            row=2, column=0, padx=(0, 5)
        )
        self.cfg_slider = ctk.CTkSlider(
            self.sliders_frame, from_=1, to=15, number_of_steps=14, fg_color="#F5EFFF"
        )
        self.cfg_slider.grid(row=2, column=1, sticky="ew")
        self.cfg_label = ctk.CTkLabel(self.sliders_frame, text="")
        self.cfg_label.grid(row=2, column=2, padx=(5, 0))
        # Batch size slider
        tk.Label(self.sliders_frame, text="Batch Size:", bg="#FBFBFB").grid(
            row=3, column=0, padx=(0, 5)
        )
        self.batch_slider = ctk.CTkSlider(
            self.sliders_frame, from_=1, to=10, number_of_steps=9, fg_color="#F5EFFF"
        )
        self.batch_slider.grid(row=3, column=1, sticky="ew")
        self.batch_label = ctk.CTkLabel(self.sliders_frame, text="")
        self.batch_label.grid(row=3, column=2, padx=(5, 0))
        # Configure grid columns and rows to distribute space evenly
        self.checkbox_frame.grid_columnconfigure(0, weight=1)
        self.checkbox_frame.grid_columnconfigure(1, weight=1)
        self.checkbox_frame.grid_rowconfigure(0, weight=1)
        self.checkbox_frame.grid_rowconfigure(1, weight=1)
        self.checkbox_frame.grid_rowconfigure(2, weight=1)
        # checkbox for hiresfix
        self.hires_fix_var = tk.BooleanVar()
        self.hires_fix_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Hires Fix",
            variable=self.hires_fix_var,
            command=self.print_hires_fix,
            text_color="black",
        )
        self.hires_fix_checkbox.grid(
            row=0, column=0, padx=(75, 5), pady=5, sticky="nsew"
        )
        # checkbox for Adetailer
        self.adetailer_var = tk.BooleanVar()
        self.adetailer_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Adetailer",
            variable=self.adetailer_var,
            command=self.print_adetailer,
            text_color="black",
        )
        self.adetailer_checkbox.grid(row=0, column=1, padx=5, pady=5, sticky="nsew")
        # checkbox to enable stable-fast optimization
        self.stable_fast_var = tk.BooleanVar()
        self.stable_fast_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Stable Fast",
            variable=self.stable_fast_var,
            text_color="black",
        )
        self.stable_fast_checkbox.grid(
            row=1, column=0, padx=(75, 5), pady=5, sticky="nsew"
        )
        # checkbox to enable prompt enhancer
        self.enhancer_var = tk.BooleanVar()
        self.enhancer_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Prompt enhancer",
            variable=self.enhancer_var,
            text_color="black",
        )
        self.enhancer_checkbox.grid(row=1, column=1, padx=5, pady=5, sticky="nsew")
        # checkbox to prefer a faster (lower quality) sampler
        self.prioritize_speed_var = tk.BooleanVar()
        self.prioritize_speed_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Prioritize Speed",
            variable=self.prioritize_speed_var,
            text_color="black",
        )
        self.prioritize_speed_checkbox.grid(
            row=2, column=0, padx=(75, 5), pady=5, sticky="nsew"
        )
        # Button to launch the generation
        self.generate_button = ctk.CTkButton(
            self.sidebar,
            text="Generate",
            command=self.generate_image,
            fg_color="#C4D9FF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.generate_button.grid(
            row=6, column=0, pady=10, sticky="ew"
        )  # Changed from pack to grid
        # Name of the currently loaded checkpoint (None until _prep runs).
        self.ckpt = None
        # Load the checkpoint on another thread so the UI stays responsive.
        threading.Thread(target=self._prep, daemon=True).start()
        # img2img button
        self.img2img_button = ctk.CTkButton(
            self.button_frame,
            text="img2img",
            command=self.img2img,
            fg_color="#F5EFFF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.img2img_button.grid(row=0, column=0, padx=5, sticky="ew")
        # interrupt button and generation-thread bookkeeping
        self.generation_threads = []
        self.interrupt_flag = False
        self.interrupt_button = ctk.CTkButton(
            self.button_frame,
            text="Interrupt",
            command=self.interrupt_generation,
            fg_color="#F5EFFF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.interrupt_button.grid(row=0, column=1, padx=5, sticky="ew")
        # Restore the last-used parameters from disk.
        prompt, neg, width, height, cfg = util.load_parameters_from_file()
        self.prompt_entry.insert(tk.END, prompt)
        self.neg.insert(tk.END, neg)
        self.width_slider.set(width)
        self.height_slider.set(height)
        self.cfg_slider.set(cfg)
        self.batch_slider.set(1)
        # Keep numeric labels in sync while a slider is dragged.
        self.width_slider.bind("<B1-Motion>", lambda event: self.update_labels())
        self.height_slider.bind("<B1-Motion>", lambda event: self.update_labels())
        self.cfg_slider.bind("<B1-Motion>", lambda event: self.update_labels())
        self.batch_slider.bind("<B1-Motion>", lambda event: self.update_labels())
        self.update_labels()
        # Persist parameters on every edit / slider release.
        self.prompt_entry.bind(
            "<KeyRelease>",
            lambda event: util.write_parameters_to_file(
                self.prompt_entry.get("1.0", tk.END),
                self.neg.get("1.0", tk.END),
                self.width_slider.get(),
                self.height_slider.get(),
                self.cfg_slider.get(),
            ),
        )
        self.neg.bind(
            "<KeyRelease>",
            lambda event: util.write_parameters_to_file(
                self.prompt_entry.get("1.0", tk.END),
                self.neg.get("1.0", tk.END),
                self.width_slider.get(),
                self.height_slider.get(),
                self.cfg_slider.get(),
            ),
        )
        self.width_slider.bind(
            "<ButtonRelease-1>",
            lambda event: util.write_parameters_to_file(
                self.prompt_entry.get("1.0", tk.END),
                self.neg.get("1.0", tk.END),
                self.width_slider.get(),
                self.height_slider.get(),
                self.cfg_slider.get(),
            ),
        )
        self.height_slider.bind(
            "<ButtonRelease-1>",
            lambda event: util.write_parameters_to_file(
                self.prompt_entry.get("1.0", tk.END),
                self.neg.get("1.0", tk.END),
                self.width_slider.get(),
                self.height_slider.get(),
                self.cfg_slider.get(),
            ),
        )
        self.cfg_slider.bind(
            "<ButtonRelease-1>",
            lambda event: util.write_parameters_to_file(
                self.prompt_entry.get("1.0", tk.END),
                self.neg.get("1.0", tk.END),
                self.width_slider.get(),
                self.height_slider.get(),
                self.cfg_slider.get(),
            ),
        )
        # Add resize handling variables (serviced by a worker thread).
        self._resize_queue = queue.Queue()
        self._resize_thread = None
        self._resize_event = threading.Event()
        self._resize_lock = threading.Lock()
        self._resize_running = True
        self._last_resize_time = 0
        self._resize_delay = 0.1
        self._image_cache = {}
        self._current_image = None
        # Start resize worker thread
        self._start_resize_worker()
        # Bind resize event
        self.bind("<Configure>", self._queue_resize)
        # Bind cleanup
        self.protocol("WM_DELETE_WINDOW", self._cleanup)
        self.display_most_recent_image_flag = False
        self.display_most_recent_image()
        self.is_generating = False
        # Speed mode trades sample quality for a faster sampler.
        self.sampler = (
            "dpmpp_sde_cfgpp"
            if not self.prioritize_speed_var.get()
            else "dpmpp_2m_cfgpp"
        )
def _img2img(self, file_path: str) -> None:
"""Perform img2img on the selected image.
Args:
file_path (str): The path to the selected image.
"""
self.is_generating = True
self.img2img_button.configure(state="disabled")
self.display_most_recent_image_flag = False
prompt = self.prompt_entry.get("1.0", tk.END)
neg = self.neg.get("1.0", tk.END)
img = Image.open(file_path)
img_array = np.array(img)
img_tensor = torch.from_numpy(img_array).float().to("cpu") / 255.0
img_tensor = img_tensor.unsqueeze(0)
self.interrupt_flag = False
self.sampler = (
"dpmpp_sde_cfgpp"
if not self.prioritize_speed_var.get()
else "dpmpp_2m_cfgpp"
)
with torch.inference_mode():
(
checkpointloadersimple_241,
cliptextencode,
emptylatentimage,
ksampler_instance,
vaedecode,
latentupscale,
upscalemodelloader,
ultimatesdupscale,
) = self._prep()
try:
loraloader = LoRas.LoraLoader()
loraloader_274 = loraloader.load_lora(
lora_name="add_detail.safetensors",
strength_model=2,
strength_clip=2,
model=checkpointloadersimple_241[0],
clip=checkpointloadersimple_241[1],
)
except:
loraloader_274 = checkpointloadersimple_241
if self.stable_fast_var.get() is True:
from modules.StableFast import StableFast
try:
app.title("LigtDiffusion - Generating StableFast model")
except:
pass
applystablefast = StableFast.ApplyStableFastUnet()
applystablefast_158 = applystablefast.apply_stable_fast(
enable_cuda_graph=False,
model=loraloader_274[0],
)
else:
applystablefast_158 = loraloader_274
fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
applystablefast_158 = fb_cache.patch(
applystablefast_158, "diffusion_model", 0.120
)
clipsetlastlayer = Clip.CLIPSetLastLayer()
clipsetlastlayer_257 = clipsetlastlayer.set_last_layer(
stop_at_clip_layer=-2, clip=loraloader_274[1]
)
cliptextencode_242 = cliptextencode.encode(
text=prompt,
clip=clipsetlastlayer_257[0],
)
cliptextencode_243 = cliptextencode.encode(
text=neg,
clip=clipsetlastlayer_257[0],
)
upscalemodelloader_244 = upscalemodelloader.load_model(
"RealESRGAN_x4plus.pth"
)
try:
app.title("LightDiffusion - Upscaling")
except:
pass
ultimatesdupscale_250 = ultimatesdupscale.upscale(
upscale_by=2,
seed=random.randint(1, 2**64),
steps=8,
cfg=6,
sampler_name=self.sampler,
scheduler="karras",
denoise=0.3,
mode_type="Linear",
tile_width=512,
tile_height=512,
mask_blur=16,
tile_padding=32,
seam_fix_mode="Half Tile",
seam_fix_denoise=0.2,
seam_fix_width=64,
seam_fix_mask_blur=16,
seam_fix_padding=32,
force_uniform_tiles="enable",
image=img_tensor,
model=applystablefast_158[0],
positive=cliptextencode_242[0],
negative=cliptextencode_243[0],
vae=checkpointloadersimple_241[2],
upscale_model=upscalemodelloader_244[0],
)
self.update_from_decode(ultimatesdupscale_250[0], "LD-I2I")
self.update_image(img)
global generated
generated = img
self.display_most_recent_image_flag = True
try:
app.title("LightDiffusion")
except:
pass
self.is_generating = False
self.img2img_button.configure(state="normal")
def img2img(self) -> None:
"""Open the file selector and run img2img on the selected image."""
if self.is_generating:
return
file_path = filedialog.askopenfilename()
if file_path:
threading.Thread(
target=self._img2img, args=(file_path,), daemon=True
).start()
def print_hires_fix(self) -> None:
"""Print the status of the hires fix checkbox."""
if self.hires_fix_var.get() is True:
print("Hires fix is ON")
else:
print("Hires fix is OFF")
def print_adetailer(self) -> None:
"""Print the status of the adetailer checkbox."""
if self.adetailer_var.get() is True:
print("Adetailer is ON")
else:
print("Adetailer is OFF")
def print_previewer(self) -> None:
"""Print the status of the previewer checkbox."""
if self.previewer_var.get() is True:
print("Previewer is ON")
else:
print("Previewer is OFF")
def generate_image(self) -> None:
"""Start the image generation process."""
if self.is_generating:
return
if self.dropdown.get() == "flux":
self.generate_thread = threading.Thread(
target=self._generate_image_flux, daemon=True
).start()
else:
self.generate_thread = threading.Thread(
target=self._generate_image, daemon=True
).start()
    def _prep(self) -> tuple:
        """Prepare (and cache) the components needed for image generation.

        The loaders are re-created only when the selected model differs from
        the one loaded last time (and is not "flux"); otherwise the cached
        instances stored on ``self`` are returned, so repeated calls are cheap.

        Returns:
            tuple: (checkpoint, clip_text_encoder, empty_latent, ksampler,
            vae_decoder, latent_upscaler, upscale_model_loader,
            ultimate_sd_upscaler).
        """
        if self.dropdown.get() != self.ckpt and self.dropdown.get() != "flux":
            self.ckpt = self.dropdown.get()
            with torch.inference_mode():
                self.checkpointloadersimple = Loader.CheckpointLoaderSimple()
                self.checkpointloadersimple_241 = (
                    self.checkpointloadersimple.load_checkpoint(
                        ckpt_name="./_internal/checkpoints/" + self.ckpt
                    )
                )
                self.cliptextencode = Clip.CLIPTextEncode()
                self.emptylatentimage = Latent.EmptyLatentImage()
                self.ksampler_instance = sampling.KSampler()
                self.vaedecode = VariationalAE.VAEDecode()
                self.latent_upscale = upscale.LatentUpscale()
                self.upscalemodelloader = USDU_upscaler.UpscaleModelLoader()
                self.ultimatesdupscale = UltimateSDUpscale.UltimateSDUpscale()
        return (
            self.checkpointloadersimple_241,
            self.cliptextencode,
            self.emptylatentimage,
            self.ksampler_instance,
            self.vaedecode,
            self.latent_upscale,
            self.upscalemodelloader,
            self.ultimatesdupscale,
        )
    def _generate_image(self) -> None:
        """Run the full txt2img pipeline for the selected SD checkpoint.

        Handles optional prompt enhancement, LoRA loading, StableFast
        compilation, FBCache, the hires-fix second pass and the two
        ADetailer passes (body, then face), updating the progress bar and
        the display as it goes.  Runs on a worker thread and checks
        ``self.interrupt_flag`` between stages.
        """
        self.is_generating = True
        self.generate_button.configure(state="disabled")
        # Track this worker thread so interrupt handling can see it.
        current_thread = threading.current_thread()
        self.generation_threads.append(current_thread)
        self.interrupt_flag = False
        # Speed mode trades sample quality for a faster sampler.
        self.sampler = (
            "dpmpp_sde_cfgpp"
            if not self.prioritize_speed_var.get()
            else "dpmpp_2m_cfgpp"
        )
        try:
            # Disable generate button during generation
            self.generate_button.configure(state="disabled")
            self.display_most_recent_image_flag = False
            self.progress.set(0)
            # Early interrupt check
            if self.interrupt_flag:
                return
            # Get generation parameters from the UI widgets.
            prompt = self.prompt_entry.get("1.0", tk.END)
            neg = self.neg.get("1.0", tk.END)
            w = int(self.width_slider.get())
            h = int(self.height_slider.get())
            cfg = int(self.cfg_slider.get())
            try:
                if self.enhancer_var.get() is True:
                    prompt = Enhancer.enhance_prompt(prompt)
                    # NOTE(review): busy-wait; presumably enhance_prompt
                    # returns None only transiently — confirm upstream.
                    while prompt is None:
                        pass
            except:
                pass
            # Main generation with proper interrupt handling
            with torch.inference_mode():
                # NOTE(review): _prep() is called again just below; the
                # loaders are cached so the second call is cheap, but
                # `components` itself is unused.
                components = self._prep()
                if self.interrupt_flag:
                    return
                (
                    checkpointloadersimple_241,
                    cliptextencode,
                    emptylatentimage,
                    ksampler_instance,
                    vaedecode,
                    latentupscale,
                    upscalemodelloader,
                    ultimatesdupscale,
                ) = self._prep()
                try:
                    loraloader = LoRas.LoraLoader()
                    loraloader_274 = loraloader.load_lora(
                        lora_name=self.lora_selection.get().replace(
                            "./_internal/loras/", ""
                        ),
                        strength_model=0.7,
                        strength_clip=0.7,
                        model=checkpointloadersimple_241[0],
                        clip=checkpointloadersimple_241[1],
                    )
                    print(
                        "loading",
                        self.lora_selection.get().replace("./_internal/loras/", ""),
                    )
                except:
                    # Best effort: fall back to the bare checkpoint when the
                    # selected LoRA cannot be loaded.
                    loraloader_274 = checkpointloadersimple_241
                try:
                    # Components used by the optional ADetailer passes below.
                    # If anything here fails the names stay undefined and the
                    # ADetailer branch will raise into the outer handler.
                    cliptextencode_124 = cliptextencode.encode(
                        text="royal, detailed, magnificient, beautiful, seducing",
                        clip=loraloader_274[1],
                    )
                    ultralyticsdetectorprovider = bbox.UltralyticsDetectorProvider()
                    ultralyticsdetectorprovider_151 = ultralyticsdetectorprovider.doit(
                        # model_name="face_yolov8m.pt"
                        model_name="person_yolov8m-seg.pt"
                    )
                    bboxdetectorsegs = bbox.BboxDetectorForEach()
                    samdetectorcombined = SAM.SAMDetectorCombined()
                    impactsegsandmask = SEGS.SegsBitwiseAndMask()
                    detailerforeachdebug = ADetailer.DetailerForEachTest()
                except:
                    pass
                clipsetlastlayer = Clip.CLIPSetLastLayer()
                clipsetlastlayer_257 = clipsetlastlayer.set_last_layer(
                    stop_at_clip_layer=-2, clip=loraloader_274[1]
                )
                self.progress.set(0.2)
                if self.stable_fast_var.get() is True:
                    from modules.StableFast import StableFast

                    try:
                        self.title("LightDiffusion - Generating StableFast model")
                    except:
                        pass
                    applystablefast = StableFast.ApplyStableFastUnet()
                    applystablefast_158 = applystablefast.apply_stable_fast(
                        enable_cuda_graph=False,
                        model=loraloader_274[0],
                    )
                else:
                    applystablefast_158 = loraloader_274
                # First-block caching speeds up repeated UNet evaluations.
                # NOTE(review): indentation was lost in this copy; applied
                # unconditionally here (matching the flux path) — confirm it
                # was not meant to apply only when StableFast is disabled.
                fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
                applystablefast_158 = fb_cache.patch(
                    applystablefast_158, "diffusion_model", 0.120
                )
                hidiffoptimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()
                cliptextencode_242 = cliptextencode.encode(
                    text=prompt,
                    clip=clipsetlastlayer_257[0],
                )
                cliptextencode_243 = cliptextencode.encode(
                    text=neg,
                    clip=clipsetlastlayer_257[0],
                )
                emptylatentimage_244 = emptylatentimage.generate(
                    width=w, height=h, batch_size=int(self.batch_slider.get())
                )
                # Base sampling pass.
                ksampler_239 = ksampler_instance.sample(
                    seed=random.randint(1, 2**64),
                    steps=20,
                    cfg=cfg,
                    sampler_name=self.sampler,
                    scheduler="karras",
                    denoise=1,
                    model=hidiffoptimizer.go(
                        model_type="auto", model=applystablefast_158[0]
                    )[0],
                    positive=cliptextencode_242[0],
                    negative=cliptextencode_243[0],
                    latent_image=emptylatentimage_244[0],
                )
                self.progress.set(0.4)
                if self.hires_fix_var.get() is True:
                    # Hires fix: upscale the latent 2x, then resample at a
                    # lower denoise to add detail.
                    latentupscale_254 = latentupscale.upscale(
                        width=w * 2,
                        height=h * 2,
                        samples=ksampler_239[0],
                    )
                    ksampler_253 = ksampler_instance.sample(
                        seed=random.randint(1, 2**64),
                        steps=10,
                        cfg=8,
                        sampler_name="euler_ancestral_cfgpp",
                        scheduler="normal",
                        denoise=0.45,
                        model=hidiffoptimizer.go(
                            model_type="auto", model=applystablefast_158[0]
                        )[0],
                        positive=cliptextencode_242[0],
                        negative=cliptextencode_243[0],
                        latent_image=latentupscale_254[0],
                    )
                    vaedecode_240 = vaedecode.decode(
                        samples=ksampler_253[0],
                        vae=checkpointloadersimple_241[2],
                    )
                    self.update_from_decode(vaedecode_240[0], "LD-HF")
                else:
                    vaedecode_240 = vaedecode.decode(
                        samples=ksampler_239[0],
                        vae=checkpointloadersimple_241[2],
                    )
                    self.update_from_decode(vaedecode_240[0], "LD")
                if self.interrupt_flag:
                    return
                self.progress.set(0.6)
                if self.adetailer_var.get() is True:
                    # First ADetailer pass: detect and refine person/body
                    # regions using SAM-assisted masks.
                    samloader = SAM.SAMLoader()
                    samloader_87 = samloader.load_model(
                        model_name="sam_vit_b_01ec64.pth", device_mode="AUTO"
                    )
                    bboxdetectorsegs_132 = bboxdetectorsegs.doit(
                        threshold=0.5,
                        dilation=10,
                        crop_factor=2,
                        drop_size=10,
                        labels="all",
                        bbox_detector=ultralyticsdetectorprovider_151[0],
                        image=vaedecode_240[0],
                    )
                    samdetectorcombined_139 = samdetectorcombined.doit(
                        detection_hint="center-1",
                        dilation=0,
                        threshold=0.93,
                        bbox_expansion=0,
                        mask_hint_threshold=0.7,
                        mask_hint_use_negative="False",
                        sam_model=samloader_87[0],
                        segs=bboxdetectorsegs_132,
                        image=vaedecode_240[0],
                    )
                    if samdetectorcombined_139[0] is None:
                        return
                    impactsegsandmask_152 = impactsegsandmask.doit(
                        segs=bboxdetectorsegs_132,
                        mask=samdetectorcombined_139[0],
                    )
                    detailerforeachdebug_145 = detailerforeachdebug.doit(
                        guide_size=512,
                        guide_size_for=False,
                        max_size=768,
                        seed=random.randint(1, 2**64),
                        steps=20,
                        cfg=6.5,
                        sampler_name=self.sampler,
                        scheduler="karras",
                        denoise=0.5,
                        feather=5,
                        noise_mask=True,
                        force_inpaint=True,
                        wildcard="",
                        cycle=1,
                        inpaint_model=False,
                        noise_mask_feather=20,
                        image=vaedecode_240[0],
                        segs=impactsegsandmask_152[0],
                        model=applystablefast_158[0],
                        clip=checkpointloadersimple_241[1],
                        vae=checkpointloadersimple_241[2],
                        positive=cliptextencode_124[0],
                        negative=cliptextencode_243[0],
                    )
                    self.update_from_decode(detailerforeachdebug_145[0], "LD-body")
                    # Second ADetailer pass: same flow with a face detector.
                    ultralyticsdetectorprovider = bbox.UltralyticsDetectorProvider()
                    ultralyticsdetectorprovider_151 = ultralyticsdetectorprovider.doit(
                        model_name="face_yolov9c.pt"
                    )
                    bboxdetectorsegs_132 = bboxdetectorsegs.doit(
                        threshold=0.5,
                        dilation=10,
                        crop_factor=2,
                        drop_size=10,
                        labels="all",
                        bbox_detector=ultralyticsdetectorprovider_151[0],
                        image=detailerforeachdebug_145[0],
                    )
                    samdetectorcombined_139 = samdetectorcombined.doit(
                        detection_hint="center-1",
                        dilation=0,
                        threshold=0.93,
                        bbox_expansion=0,
                        mask_hint_threshold=0.7,
                        mask_hint_use_negative="False",
                        sam_model=samloader_87[0],
                        segs=bboxdetectorsegs_132,
                        image=detailerforeachdebug_145[0],
                    )
                    impactsegsandmask_152 = impactsegsandmask.doit(
                        segs=bboxdetectorsegs_132,
                        mask=samdetectorcombined_139[0],
                    )
                    detailerforeachdebug_145 = detailerforeachdebug.doit(
                        guide_size=512,
                        guide_size_for=False,
                        max_size=768,
                        seed=random.randint(1, 2**64),
                        steps=20,
                        cfg=6.5,
                        sampler_name=self.sampler,
                        scheduler="karras",
                        denoise=0.5,
                        feather=5,
                        noise_mask=True,
                        force_inpaint=True,
                        wildcard="",
                        cycle=1,
                        inpaint_model=False,
                        noise_mask_feather=20,
                        image=detailerforeachdebug_145[0],
                        segs=impactsegsandmask_152[0],
                        model=applystablefast_158[0],
                        clip=checkpointloadersimple_241[1],
                        vae=checkpointloadersimple_241[2],
                        positive=cliptextencode_124[0],
                        negative=cliptextencode_243[0],
                    )
                    self.update_from_decode(detailerforeachdebug_145[0], "LD-head")
                self.progress.set(0.8)
        except Exception as e:
            print(f"Generation error: {e}")
            self.title(f"LightDiffusion - Error: {str(e)}")
        finally:
            # Reset state when done
            self.is_generating = False
            self.generate_button.configure(state="normal")
            if current_thread in self.generation_threads:
                self.generation_threads.remove(current_thread)
            self.progress.set(0)
            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    def _generate_image_flux(self) -> None:
        """Generate an image using the quantized Flux (GGUF) pipeline.

        Loads the GGUF UNet, dual CLIP encoders and VAE, applies FBCache,
        samples with cfg=1 (the negative conditioning is zeroed out), and
        displays/saves the decoded result.  Runs on a worker thread.
        """
        self.is_generating = True
        self.generate_button.configure(state="disabled")
        # Add current thread to list at start
        current_thread = threading.current_thread()
        self.generation_threads.append(current_thread)
        self.display_most_recent_image_flag = False
        w = int(self.width_slider.get())
        h = int(self.height_slider.get())
        prompt = self.prompt_entry.get("1.0", tk.END)
        try:
            if self.enhancer_var.get() is True:
                prompt = Enhancer.enhance_prompt(prompt)
                # NOTE(review): busy-wait; presumably enhance_prompt returns
                # None only transiently — confirm upstream.
                while prompt is None:
                    pass
            self.interrupt_flag = False
            # Fetch the Flux weights on first use.
            Downloader.CheckAndDownloadFlux()
            with torch.inference_mode():
                dualcliploadergguf = Quantizer.DualCLIPLoaderGGUF()
                emptylatentimage = Latent.EmptyLatentImage()
                vaeloader = VariationalAE.VAELoader()
                unetloadergguf = Quantizer.UnetLoaderGGUF()
                cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
                conditioningzeroout = Quantizer.ConditioningZeroOut()
                ksampler = sampling.KSampler()
                vaedecode = VariationalAE.VAEDecode()
                unetloadergguf_10 = unetloadergguf.load_unet(
                    unet_name="flux1-dev-Q8_0.gguf"
                )
                vaeloader_11 = vaeloader.load_vae(vae_name="ae.safetensors")
                dualcliploadergguf_19 = dualcliploadergguf.load_clip(
                    clip_name1="clip_l.safetensors",
                    clip_name2="t5-v1_1-xxl-encoder-Q8_0.gguf",
                    type="flux",
                )
                emptylatentimage_5 = emptylatentimage.generate(
                    width=w, height=h, batch_size=int(self.batch_slider.get())
                )
                # The same prompt feeds both text encoders; the negative
                # conditioning is a zeroed copy (cfg is 1 for Flux).
                cliptextencodeflux_15 = cliptextencodeflux.encode(
                    clip_l=prompt,
                    t5xxl=prompt,
                    guidance=3.0,
                    clip=dualcliploadergguf_19[0],
                    flux_enabled=True,
                )
                conditioningzeroout_16 = conditioningzeroout.zero_out(
                    conditioning=cliptextencodeflux_15[0]
                )
                # First-block caching speeds up repeated UNet evaluations.
                fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
                unetloadergguf_10 = fb_cache.patch(
                    unetloadergguf_10, "diffusion_model", 0.120
                )
                # try:
                #     import triton
                #     compiler = misc_nodes.EnhancedCompileModel()
                #     unetloadergguf_10 = compiler.patch(unetloadergguf_10, True, "diffusion_model", "torch.compile", False, False, None, None, False, "inductor")
                # except ImportError:
                #     print("Triton not found, skipping compilation")
                ksampler_3 = ksampler.sample(
                    seed=random.randint(1, 2**64),
                    steps=20,
                    cfg=1,
                    sampler_name="euler_cfgpp",
                    scheduler="beta",
                    denoise=1,
                    model=unetloadergguf_10[0],
                    positive=cliptextencodeflux_15[0],
                    negative=conditioningzeroout_16[0],
                    latent_image=emptylatentimage_5[0],
                    flux=True,
                )
                vaedecode_8 = vaedecode.decode(
                    samples=ksampler_3[0],
                    vae=vaeloader_11[0],
                    flux=True,
                )
                self.update_from_decode(vaedecode_8[0], "LD-Flux")
        finally:
            # Reset state when done
            self.is_generating = False
            self.generate_button.configure(state="normal")
            if current_thread in self.generation_threads:
                self.generation_threads.remove(current_thread)
def on_model_selected(self, *args):
"""Handle model selection changes"""
if self.dropdown.get() == "flux":
# Disable incompatible controls
self.adetailer_checkbox._state = tk.DISABLED
self.hires_fix_checkbox._state = tk.DISABLED
self.stable_fast_checkbox._state = tk.DISABLED
self.lora_selection._state = tk.DISABLED
self.cfg_slider._state = tk.DISABLED
else:
# Enable controls
self.adetailer_checkbox._state = tk.NORMAL
self.hires_fix_checkbox._state = tk.NORMAL
self.stable_fast_checkbox._state = tk.NORMAL
self.lora_selection._state = tk.NORMAL
self.cfg_slider._state = tk.NORMAL
    def _handle_decoded_image(self, decoded, prefix: str) -> None:
        """Apply HDR post-processing, save and display a decoded batch.

        Args:
            decoded: Decoded tensor image, or a tuple whose first element
                is the tensor (node-style return).
            prefix (str): Filename prefix for the saved images.
        """
        try:
            # Initialize components
            saveimage = ImageSaver.SaveImage()
            hdr = ahdr.HDREffects()
            images = []
            # Unwrap node-style (tensor,) tuples.
            if isinstance(decoded, tuple):
                tensor_image = decoded[0]
            else:
                tensor_image = decoded
            # Apply HDR as a single batch operation.
            processed = hdr.apply_hdr2(tensor_image)
            # Save images with prefix (apply_hdr2 may also return a tuple).
            saveimage.save_images(
                filename_prefix=prefix,
                images=processed[0] if isinstance(processed, tuple) else processed,
            )
            # Convert processed tensors to PIL images for display.
            for img_tensor in (
                processed[0] if isinstance(processed, tuple) else [processed]
            ):
                # Scale from [0, 1] floats to 8-bit pixel values.
                img_array = 255.0 * img_tensor.cpu().numpy()
                # Collapse a 4-D (batch-shaped) array down to an image.
                # NOTE(review): the reshape assumes the squeezed layout is
                # compatible with (-1, H, W) — confirm for batch sizes > 1.
                if img_array.ndim == 4:
                    img_array = np.squeeze(img_array)
                    img_array = img_array.reshape(
                        -1, img_array.shape[-2], img_array.shape[-1]
                    )
                # Convert to PIL image
                img = Image.fromarray(np.clip(img_array, 0, 255).astype(np.uint8))
                images.append(img)
            # Update display if not interrupted
            if not self.interrupt_flag:
                self.progress.set(1.0)
                if images:
                    self.img = images[0]
                    self.update_image(images)
                    self.display_most_recent_image_flag = True
        except Exception as e:
            print(f"Image processing error: {e}")
            self.title(f"LightDiffusion - Error: {str(e)}")
def update_from_decode(self, decoded: Image.Image, prefix: str) -> None:
"""Update the image from the decode function.
Args:
decoded (Image.Image): The decoded image tensor/tuple
prefix (str): Prefix for saved files
"""
try:
# Handle image processing in separate function
self._handle_decoded_image(decoded, prefix)
except Exception as e:
print(f"Decode error: {e}")
self.title(f"LightDiffusion - Error: {str(e)}")
finally:
# Ensure cleanup
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def update_labels(self) -> None:
"""Update the labels for the sliders."""
self.width_label.configure(text=f"{int(self.width_slider.get())}")
self.height_label.configure(text=f"{int(self.height_slider.get())}")
self.cfg_label.configure(text=f"{int(self.cfg_slider.get())}")
self.batch_label.configure(text=f"{int(self.batch_slider.get())}")
def create_image_grid(self, images: list[Image.Image]) -> Image.Image:
"""Create a grid of images.
Args:
images (list[Image.Image]): List of images to arrange in grid
Returns:
Image.Image: Combined grid image
"""
# Calculate grid dimensions
n = len(images)
if n <= 1:
return images[0]
cols = int(np.ceil(np.sqrt(n)))
rows = int(np.ceil(n / cols))
# Get max dimensions
w_max = max(img.width for img in images)
h_max = max(img.height for img in images)
# Create output image
grid = Image.new("RGB", (w_max * cols, h_max * rows))
# Paste images into grid
for idx, img in enumerate(images):
i = idx // cols
j = idx % cols
grid.paste(img, (j * w_max, i * h_max))
return grid
    def update_image(self, images: Union[Image.Image, list[Image.Image]]) -> None:
        """Resize and display one image or a grid of a batch.

        Args:
            images: Single image or list of images to display.
        """
        # Convert single image to list
        if isinstance(images, Image.Image):
            images = [images]
        # Create grid of all images
        grid_img = self.create_image_grid(images)
        # Calculate the aspect ratio of the grid
        aspect_ratio = grid_img.width / grid_img.height
        # Target area is 4/7 of the window in each dimension.
        label_width = int(4 * self.winfo_width() / 7)
        label_height = int(4 * self.winfo_height() / 7)
        # Fit inside the target box while preserving the aspect ratio.
        if label_width / aspect_ratio <= label_height:
            new_width = label_width
            new_height = int(label_width / aspect_ratio)
        else:
            new_height = label_height
            new_width = int(label_height * aspect_ratio)
        # Resize the grid image.
        # NOTE(review): the RecursionError guard suggests this can be
        # re-entered via Tk <Configure> events — confirm before removing.
        try:
            grid_img = grid_img.resize((new_width, new_height), Image.LANCZOS)
        except RecursionError:
            pass
        self.img = grid_img
        # Skip the label update while the "most recent image" view owns it.
        if self.display_most_recent_image_flag is False:
            self._update_image_label(grid_img)
def _update_image_label(self, img: Image.Image) -> None:
    """Render *img* inside the preview label widget.

    Args:
        img (Image.Image): The already-sized image to show.
    """
    photo = ImageTk.PhotoImage(img)
    self.image_label.config(image=photo)
    # Tk keeps no reference of its own; stash one on the widget so the
    # PhotoImage is not garbage collected while it is on screen.
    self.image_label.image = photo
def display_most_recent_image(self) -> None:
    """Load and display the most recent batch of generated images.

    Scans every output sub-directory, finds the newest file, collects
    all files written within one second of it (treated as a single
    generation batch) and shows them as one image or a grid.
    """
    # Gather candidates from every output category.
    output_dirs = ("Classic", "Adetailer", "Flux", "HiresFix", "Img2Img")
    image_files = []
    for name in output_dirs:
        image_files += glob.glob("./_internal/output/" + name + "/*")
    if not image_files:
        return

    # Sort files by modification time, newest first.
    image_files.sort(key=os.path.getmtime, reverse=True)
    latest_time = os.path.getmtime(image_files[0])

    # Files written within 1 second of the newest belong to the batch.
    batch_images = []
    for file in image_files:
        if abs(os.path.getmtime(file) - latest_time) < 1.0:
            try:
                batch_images.append(Image.open(file))
            except Exception:
                # Skip unreadable or non-image files instead of aborting
                # the refresh (was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit).
                continue
    if not batch_images:
        return

    # Display single image or grid of batch
    if len(batch_images) == 1:
        self.update_image(batch_images[0])
    else:
        self.update_image(batch_images)
def _start_resize_worker(self):
"""Start the resize worker thread"""
self._resize_thread = threading.Thread(target=self._resize_worker, daemon=True)
self._resize_thread.start()
def _resize_worker(self):
    """Worker thread loop for handling resize operations.

    Runs until ``self._resize_running`` is cleared (see ``_cleanup``).
    When the queue has a pending event it is processed via ``_do_resize``;
    otherwise the loop parks briefly on ``self._resize_event`` so it wakes
    promptly when ``_queue_resize`` signals new work, without busy-waiting.
    """
    while self._resize_running:
        try:
            # Wait for resize event or timeout
            if self._resize_queue.qsize() > 0:
                event = self._resize_queue.get(timeout=0.1)
                self._do_resize(event)
                self._resize_queue.task_done()
            else:
                # No pending work: sleep until signalled or 100 ms elapse.
                self._resize_event.wait(timeout=0.1)
                self._resize_event.clear()
        except queue.Empty:
            # qsize() raced with another consumer and get() timed out;
            # loop around and try again.
            continue
def _queue_resize(self, event):
"""Queue a resize event"""
current_time = time.time()
# Debounce resize events
if current_time - self._last_resize_time > self._resize_delay:
self._last_resize_time = current_time
self._resize_queue.put(event)
self._resize_event.set()
def _do_resize(self, event):
"""Handle resize operation in worker thread"""
width = self.winfo_width()
height = self.winfo_height()
# Update UI components in main thread
self.after(0, lambda: self._update_components(width, height))
# Update image if exists
if hasattr(self, "img"):
self._update_image_threaded(self.img)
def _update_components(self, width, height):
"""Update UI components sizes"""
# Update component sizes based on window dimensions
width = self.winfo_width()
height = self.winfo_height()
# Scale text boxes
prompt_height = int(height * 0.25)
neg_height = int(height * 0.15)
self.prompt_entry.configure(height=prompt_height)
self.neg.configure(height=neg_height)
def _update_image_threaded(self, img):
    """Thread-safe preview update with a small resize cache.

    Args:
        img: PIL image to fit into the preview area, or ``None`` to skip.
    """
    if img is None:
        return
    with self._resize_lock:
        # Fit the image into 4/7 of the window, preserving aspect ratio.
        aspect_ratio = img.width / img.height
        label_width = int(4 * self.winfo_width() / 7)
        label_height = int(4 * self.winfo_height() / 7)
        if label_width / aspect_ratio <= label_height:
            new_width = label_width
            new_height = int(label_width / aspect_ratio)
        else:
            new_height = label_height
            new_width = int(label_height * aspect_ratio)

        # Reuse a previous resize for the same target dimensions.
        cache_key = (new_width, new_height)
        if cache_key in self._image_cache:
            resized_img = self._image_cache[cache_key]
        else:
            try:
                resized_img = img.resize((new_width, new_height), Image.LANCZOS)
                self._image_cache[cache_key] = resized_img
                # Evict the oldest entry once the cache grows past 5.
                if len(self._image_cache) > 5:
                    self._image_cache.pop(next(iter(self._image_cache)))
            except Exception:
                # Resizing can fail on degenerate sizes; keep the current
                # display instead of crashing (was a bare `except:`, which
                # also swallowed KeyboardInterrupt/SystemExit).
                return
        if not self.display_most_recent_image_flag:
            # Hand the actual widget update back to the Tk main loop.
            self.after(0, lambda: self._update_image_label_safe(resized_img))
def _update_image_label_safe(self, img):
    """Swap the preview label's image; must run on the Tk main thread."""
    if self.display_most_recent_image_flag:
        return
    # Keep a reference on self so the PhotoImage survives garbage collection.
    self._current_image = ImageTk.PhotoImage(img)
    self.image_label.configure(image=self._current_image)
def _cleanup(self):
"""Clean up threads before closing"""
self._resize_running = False
self._resize_event.set()
if self._resize_thread:
self._resize_thread.join(timeout=1.0)
self.destroy()
def interrupt_generation(self) -> None:
    """Interrupt ongoing image generation process.

    Sets the interrupt flag, frees GPU memory, joins any live generation
    threads, resets the UI back to its idle state, drops the loaded
    checkpoint (if any) and clears the generation bookkeeping.
    """
    # No-op unless a generation is actually running.
    if not self.is_generating:
        return
    # Set interrupt flag first so generation loops can observe it.
    self.interrupt_flag = True
    # Clear CUDA cache and release memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    # Stop and cleanup threads (iterate over a copy — list is mutated below).
    for thread in self.generation_threads[:]:
        if thread and thread.is_alive():
            thread.join(timeout=1.0)
        if thread in self.generation_threads:
            self.generation_threads.remove(thread)
    # Reset UI state
    self.progress.set(0)
    self.title("LightDiffusion")
    self.generate_button.configure(state="normal")
    self.display_most_recent_image_flag = True
    # Clear any pending resize tasks
    with self._resize_lock:
        self._resize_queue.queue.clear()
    # Reset model state if needed (forces a checkpoint reload next run).
    if hasattr(self, "checkpointloadersimple_241"):
        del self.checkpointloadersimple_241
        self.ckpt = None
    # Always reset flags
    self.generation_threads.clear()
    # Reset generation state
    self.is_generating = False
    # NOTE(review): duplicate of the configure call above; kept as-is.
    self.generate_button.configure(state="normal")
if __name__ == "__main__":
    # Launch the GUI using the shared singleton app instance.
    from modules.user.app_instance import app
    app.mainloop()