import functools
import glob
import os
import queue
import random
import sys
import threading
import time
import tkinter as tk
from tkinter import filedialog
from typing import Union

from PIL import Image, ImageTk
import numpy as np
import customtkinter as ctk
import torch

# Add the directory containing LightDiffusion.py to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from modules.AutoDetailer import SAM, ADetailer, bbox, SEGS
from modules.AutoEncoders import VariationalAE
from modules.clip import Clip
from modules.sample import sampling
from modules.Utilities import util
from modules.UltimateSDUpscale import USDU_upscaler, UltimateSDUpscale
from modules.FileManaging import Downloader, ImageSaver, Loader
from modules.Model import LoRas
from modules.Utilities import Enhancer, Latent, upscale
from modules.Quantize import Quantizer
from modules.WaveSpeed import fbcache_nodes
from modules.hidiffusion import msw_msa_attention
from modules.AutoHDR import ahdr

Downloader.CheckAndDownload()

files = glob.glob("./_internal/checkpoints/*.safetensors")
loras = glob.glob("./_internal/loras/*.safetensors")
loras += glob.glob("./_internal/loras/*.pt")


def debounce(wait):
    """Decorator to rate-limit resize events.

    Despite its name this is a leading-edge throttle: the first call in a
    ``wait``-second window runs immediately and later calls inside the same
    window are dropped (not deferred). The wrapped call's return value is
    discarded.

    Args:
        wait (float): Minimum number of seconds between two executions.
    """

    def decorator(fn):
        # Mutable cell so the closure can update the timestamp.
        last_call = [float("-inf")]

        @functools.wraps(fn)
        def debounced(*args, **kwargs):
            # Monotonic clock: immune to wall-clock adjustments.
            current_time = time.monotonic()
            if current_time - last_call[0] >= wait:
                fn(*args, **kwargs)
                last_call[0] = current_time

        return debounced

    return decorator


class App(tk.Tk):
    """Main application class for the LightDiffusion GUI."""

    def __init__(self):
        """Build the widgets, restore saved parameters and start background workers."""
        super().__init__()
        self.title("LightDiffusion")
        self.geometry("900x750")

        # Main window: sidebar in column 0, stretching display in column 1.
        self.grid_columnconfigure(1, weight=1)
        self.grid_rowconfigure(0, weight=1)

        file_names = [os.path.basename(file) for file in files]
        lora_names = [os.path.basename(lora) for lora in loras]

        # Sidebar frame holding every control.
        self.sidebar = tk.Frame(self, bg="#FBFBFB", padx=10, pady=10)
        self.sidebar.grid(row=0, column=0, sticky="nsew")
        self.sidebar.grid_columnconfigure(0, weight=1)
        for i in range(8):
            self.sidebar.grid_rowconfigure(i, weight=1)

        # Prompt / negative prompt textboxes.
        self.prompt_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.prompt_frame.grid(row=0, column=0, sticky="nsew", pady=(0, 5))
        self.prompt_frame.grid_columnconfigure(0, weight=1)
        self.prompt_frame.grid_rowconfigure(0, weight=2)
        self.prompt_frame.grid_rowconfigure(1, weight=1)

        self.prompt_entry = ctk.CTkTextbox(
            self.prompt_frame,
            height=150,
            fg_color="#E8F9FF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.prompt_entry.grid(row=0, column=0, sticky="nsew")

        self.neg = ctk.CTkTextbox(
            self.prompt_frame,
            height=75,
            fg_color="#E8F9FF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.neg.grid(row=1, column=0, sticky="nsew", pady=(5, 0))

        # Model dropdown; "flux" is always offered and selects the Flux pipeline.
        model_values = (file_names if file_names else ["No models found"]) + ["flux"]
        self.dropdown = ctk.CTkOptionMenu(
            self.sidebar,
            values=model_values,
            fg_color="#F5EFFF",
            text_color="black",
            command=self.on_model_selected,
        )
        self.dropdown.grid(row=2, column=0, sticky="ew")

        # LoRA selection (placeholder entry when no LoRA files exist; loading it
        # fails and the generation code falls back to the bare checkpoint).
        self.lora_selection = ctk.CTkOptionMenu(
            self.sidebar,
            values=lora_names if lora_names else ["No LoRAs found"],
            fg_color="#F5EFFF",
            text_color="black",
        )
        self.lora_selection.grid(row=3, column=0, sticky="ew", pady=5)

        # Display frame: image (row 0), previewer checkbox (row 1), progress (row 2).
        self.display = tk.Frame(self, bg="#FBFBFB")
        self.display.grid(row=0, column=1, sticky="nsew", padx=10, pady=10)
        self.display.grid_columnconfigure(0, weight=1)
        self.img = None
        self.display.grid_rowconfigure(0, weight=1)
        self.display.grid_rowconfigure(1, weight=0)

        self.image_label = tk.Label(self.display, bg="#FBFBFB")
        self.image_label.grid(row=0, column=0, sticky="nsew")

        self.previewer_var = tk.BooleanVar()
        self.previewer_checkbox = ctk.CTkCheckBox(
            self.display,
            text="Previewer",
            variable=self.previewer_var,
            command=self.print_previewer,
            text_color="black",
        )
        self.previewer_checkbox.grid(row=1, column=0, pady=10)

        self.progress = ctk.CTkProgressBar(self.display, fg_color="#FBFBFB")
        self.progress.grid(row=2, column=0, sticky="ew", pady=10, padx=10)
        self.progress.set(0)

        # Sliders frame: one row each for width, height, CFG and batch size.
        self.sliders_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.sliders_frame.grid(row=4, column=0, sticky="nsew", pady=5)
        self.sliders_frame.grid_columnconfigure(1, weight=1)
        for i in range(4):
            self.sliders_frame.grid_rowconfigure(i, weight=1)

        # Checkbox frame: two columns, three rows.
        self.checkbox_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.checkbox_frame.grid(row=5, column=0, sticky="nsew", pady=10)
        for col in range(2):
            self.checkbox_frame.grid_columnconfigure(col, weight=1)
        for row in range(3):
            self.checkbox_frame.grid_rowconfigure(row, weight=1)

        # Button frame for img2img / interrupt.
        self.button_frame = tk.Frame(self.sidebar, bg="#FBFBFB")
        self.button_frame.grid(row=7, column=0, sticky="nsew", pady=10)
        self.button_frame.grid_columnconfigure(0, weight=1)
        self.button_frame.grid_columnconfigure(1, weight=1)

        self.width_slider, self.width_label = self._add_slider(0, "Width:", 1, 2048, 32)
        self.height_slider, self.height_label = self._add_slider(1, "Height:", 1, 2048, 32)
        self.cfg_slider, self.cfg_label = self._add_slider(2, "CFG:", 1, 15, 14)
        self.batch_slider, self.batch_label = self._add_slider(3, "Batch Size:", 1, 10, 9)

        # Feature checkboxes.
        self.hires_fix_var = tk.BooleanVar()
        self.hires_fix_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Hires Fix",
            variable=self.hires_fix_var,
            command=self.print_hires_fix,
            text_color="black",
        )
        self.hires_fix_checkbox.grid(row=0, column=0, padx=(75, 5), pady=5, sticky="nsew")

        self.adetailer_var = tk.BooleanVar()
        self.adetailer_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Adetailer",
            variable=self.adetailer_var,
            command=self.print_adetailer,
            text_color="black",
        )
        self.adetailer_checkbox.grid(row=0, column=1, padx=5, pady=5, sticky="nsew")

        self.stable_fast_var = tk.BooleanVar()
        self.stable_fast_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Stable Fast",
            variable=self.stable_fast_var,
            text_color="black",
        )
        self.stable_fast_checkbox.grid(row=1, column=0, padx=(75, 5), pady=5, sticky="nsew")

        self.enhancer_var = tk.BooleanVar()
        self.enhancer_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Prompt enhancer",
            variable=self.enhancer_var,
            text_color="black",
        )
        self.enhancer_checkbox.grid(row=1, column=1, padx=5, pady=5, sticky="nsew")

        self.prioritize_speed_var = tk.BooleanVar()
        self.prioritize_speed_checkbox = ctk.CTkCheckBox(
            self.checkbox_frame,
            text="Prioritize Speed",
            variable=self.prioritize_speed_var,
            text_color="black",
        )
        self.prioritize_speed_checkbox.grid(
            row=2, column=0, padx=(75, 5), pady=5, sticky="nsew"
        )

        # Button to launch the generation.
        self.generate_button = ctk.CTkButton(
            self.sidebar,
            text="Generate",
            command=self.generate_image,
            fg_color="#C4D9FF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.generate_button.grid(row=6, column=0, pady=10, sticky="ew")

        # Preload the selected checkpoint on another thread. The lock
        # serializes this preload with _prep() calls from generation threads.
        self.ckpt = None
        self._prep_lock = threading.Lock()
        self.is_generating = False
        threading.Thread(target=self._prep, daemon=True).start()

        self.img2img_button = ctk.CTkButton(
            self.button_frame,
            text="img2img",
            command=self.img2img,
            fg_color="#F5EFFF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.img2img_button.grid(row=0, column=0, padx=5, sticky="ew")

        self.generation_threads = []
        self.interrupt_flag = False
        self.interrupt_button = ctk.CTkButton(
            self.button_frame,
            text="Interrupt",
            command=self.interrupt_generation,
            fg_color="#F5EFFF",
            text_color="black",
            border_color="gray",
            border_width=2,
        )
        self.interrupt_button.grid(row=0, column=1, padx=5, sticky="ew")

        # Restore the last-used prompt/size/CFG.
        prompt, neg, width, height, cfg = util.load_parameters_from_file()
        self.prompt_entry.insert(tk.END, prompt)
        self.neg.insert(tk.END, neg)
        self.width_slider.set(width)
        self.height_slider.set(height)
        self.cfg_slider.set(cfg)
        self.batch_slider.set(1)

        # Keep the numeric labels next to the sliders in sync while dragging
        # and after a click. CTk widgets' bind() appends, so several handlers
        # on the same sequence coexist.
        sliders = (self.width_slider, self.height_slider, self.cfg_slider, self.batch_slider)
        for slider in sliders:
            slider.bind("<B1-Motion>", lambda event: self.update_labels())
            slider.bind("<ButtonRelease-1>", lambda event: self.update_labels())
        self.update_labels()

        # Persist the parameters whenever the user edits them.
        self.prompt_entry.bind("<KeyRelease>", self._save_parameters)
        self.neg.bind("<KeyRelease>", self._save_parameters)
        for slider in (self.width_slider, self.height_slider, self.cfg_slider):
            slider.bind("<ButtonRelease-1>", self._save_parameters)

        # Resize handling state (consumed by the resize worker thread).
        self._resize_queue = queue.Queue()
        self._resize_thread = None
        self._resize_event = threading.Event()
        self._resize_lock = threading.Lock()
        self._resize_running = True
        self._last_resize_time = 0
        self._resize_delay = 0.1
        self._image_cache = {}
        self._current_image = None

        self._start_resize_worker()
        self.bind("<Configure>", self._queue_resize)
        self.protocol("WM_DELETE_WINDOW", self._cleanup)

        self.display_most_recent_image_flag = False
        self.display_most_recent_image()
        self.sampler = self._selected_sampler()

    def _add_slider(self, row: int, text: str, from_: int, to: int, steps: int):
        """Create a captioned slider row in ``self.sliders_frame``.

        Returns:
            tuple: ``(slider, value_label)`` for the new row.
        """
        tk.Label(self.sliders_frame, text=text, bg="#FBFBFB").grid(
            row=row, column=0, padx=(0, 5)
        )
        slider = ctk.CTkSlider(
            self.sliders_frame,
            from_=from_,
            to=to,
            number_of_steps=steps,
            fg_color="#F5EFFF",
        )
        slider.grid(row=row, column=1, sticky="ew")
        value_label = ctk.CTkLabel(self.sliders_frame, text="")
        value_label.grid(row=row, column=2, padx=(5, 0))
        return slider, value_label

    def _save_parameters(self, event=None) -> None:
        """Write the current prompt, negative prompt, size and CFG to disk."""
        util.write_parameters_to_file(
            self.prompt_entry.get("1.0", tk.END),
            self.neg.get("1.0", tk.END),
            self.width_slider.get(),
            self.height_slider.get(),
            self.cfg_slider.get(),
        )

    def _selected_sampler(self) -> str:
        """Return the sampler name matching the "Prioritize Speed" checkbox."""
        return "dpmpp_2m_cfgpp" if self.prioritize_speed_var.get() else "dpmpp_sde_cfgpp"

    @staticmethod
    def _random_seed() -> int:
        """Return a random seed in ``[1, 2**64 - 1]`` (fits an unsigned 64-bit int)."""
        return random.randint(1, 2**64 - 1)

    def _set_title(self, text: str) -> None:
        """Set the window title on a best-effort basis.

        Title updates are issued from worker threads; Tk errors raised there
        (or after the window has been destroyed) are ignored.
        """
        try:
            self.title(text)
        except (tk.TclError, RuntimeError):
            pass

    def _enhance_prompt(self, prompt: str) -> str:
        """Return the enhanced prompt when the enhancer is enabled, else ``prompt``.

        Falls back to the original prompt if the enhancer raises or returns None.
        """
        if not self.enhancer_var.get():
            return prompt
        try:
            enhanced = Enhancer.enhance_prompt(prompt)
        except Exception as e:
            print(f"Prompt enhancer failed, using the original prompt: {e}")
            return prompt
        return enhanced if enhanced is not None else prompt

    def _patch_model(self, lora: tuple) -> tuple:
        """Apply StableFast (when enabled) and the first-block cache to a model.

        Args:
            lora (tuple): Loader output whose element 0 is the diffusion model.

        Returns:
            tuple: Patched loader output; element 0 is the patched model.
        """
        if self.stable_fast_var.get() is True:
            from modules.StableFast import StableFast

            self._set_title("LightDiffusion - Generating StableFast model")
            patched = StableFast.ApplyStableFastUnet().apply_stable_fast(
                enable_cuda_graph=False,
                model=lora[0],
            )
        else:
            patched = lora
        return fbcache_nodes.ApplyFBCacheOnModel().patch(
            patched, "diffusion_model", 0.120
        )

    def _img2img(self, file_path: str) -> None:
        """Run Ultimate SD Upscale (img2img) on the selected image.

        Runs on a worker thread started by :meth:`img2img`. UI state
        (``is_generating``, the img2img button, the window title) is always
        restored, even when the pipeline fails.

        Args:
            file_path (str): The path to the selected image.
        """
        global generated
        self.is_generating = True
        self.img2img_button.configure(state="disabled")
        self.display_most_recent_image_flag = False
        self.interrupt_flag = False
        self.sampler = self._selected_sampler()
        prompt = self.prompt_entry.get("1.0", tk.END)
        neg = self.neg.get("1.0", tk.END)

        try:
            # Force 3-channel RGB so RGBA/greyscale/palette files still give an
            # (H, W, 3) array, and close the file once it has been decoded.
            with Image.open(file_path) as source:
                img = source.convert("RGB")
            img_tensor = torch.from_numpy(np.array(img)).float().to("cpu") / 255.0
            img_tensor = img_tensor.unsqueeze(0)  # add a batch dimension

            with torch.inference_mode():
                (
                    checkpoint,
                    cliptextencode,
                    _emptylatentimage,
                    _ksampler_instance,
                    _vaedecode,
                    _latentupscale,
                    upscalemodelloader,
                    ultimatesdupscale,
                ) = self._prep()

                try:
                    lora = LoRas.LoraLoader().load_lora(
                        lora_name="add_detail.safetensors",
                        strength_model=2,
                        strength_clip=2,
                        model=checkpoint[0],
                        clip=checkpoint[1],
                    )
                except Exception:
                    # Best effort: continue with the bare checkpoint.
                    lora = checkpoint

                model = self._patch_model(lora)

                clip_skip = Clip.CLIPSetLastLayer().set_last_layer(
                    stop_at_clip_layer=-2, clip=lora[1]
                )
                positive = cliptextencode.encode(text=prompt, clip=clip_skip[0])
                negative = cliptextencode.encode(text=neg, clip=clip_skip[0])
                upscale_model = upscalemodelloader.load_model("RealESRGAN_x4plus.pth")

                self._set_title("LightDiffusion - Upscaling")
                upscaled = ultimatesdupscale.upscale(
                    upscale_by=2,
                    seed=self._random_seed(),
                    steps=8,
                    cfg=6,
                    sampler_name=self.sampler,
                    scheduler="karras",
                    denoise=0.3,
                    mode_type="Linear",
                    tile_width=512,
                    tile_height=512,
                    mask_blur=16,
                    tile_padding=32,
                    seam_fix_mode="Half Tile",
                    seam_fix_denoise=0.2,
                    seam_fix_width=64,
                    seam_fix_mask_blur=16,
                    seam_fix_padding=32,
                    force_uniform_tiles="enable",
                    image=img_tensor,
                    model=model[0],
                    positive=positive[0],
                    negative=negative[0],
                    vae=checkpoint[2],
                    upscale_model=upscale_model[0],
                )
                # Decode/save/display the upscaled result (also sets self.img).
                self.update_from_decode(upscaled[0], "LD-I2I")

            # NOTE(review): this stores the *input* image, not the upscaled
            # result — confirm which one readers of `generated` expect.
            generated = img
            self.display_most_recent_image_flag = True
            self._set_title("LightDiffusion")
        except Exception as e:
            print(f"Img2Img error: {e}")
            self._set_title(f"LightDiffusion - Error: {str(e)}")
        finally:
            self.is_generating = False
            self.img2img_button.configure(state="normal")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def img2img(self) -> None:
        """Open the file selector and run img2img on the selected image."""
        if self.is_generating:
            return
        file_path = filedialog.askopenfilename()
        if file_path:
            # Claim the flag on the UI thread so a second click cannot start a
            # concurrent run before the worker thread gets going.
            self.is_generating = True
            threading.Thread(
                target=self._img2img, args=(file_path,), daemon=True
            ).start()

    def print_hires_fix(self) -> None:
        """Print the status of the hires fix checkbox."""
        print("Hires fix is ON" if self.hires_fix_var.get() is True else "Hires fix is OFF")

    def print_adetailer(self) -> None:
        """Print the status of the adetailer checkbox."""
        print("Adetailer is ON" if self.adetailer_var.get() is True else "Adetailer is OFF")

    def print_previewer(self) -> None:
        """Print the status of the previewer checkbox."""
        print("Previewer is ON" if self.previewer_var.get() is True else "Previewer is OFF")

    def generate_image(self) -> None:
        """Start the image generation process on a worker thread."""
        if self.is_generating:
            return
        target = (
            self._generate_image_flux
            if self.dropdown.get() == "flux"
            else self._generate_image
        )
        # Claim the flag on the UI thread so a double click cannot start two runs.
        self.is_generating = True
        self.generate_thread = threading.Thread(target=target, daemon=True)
        self.generate_thread.start()

    def _prep(self) -> tuple:
        """Load (or reuse) the selected SD checkpoint and pipeline nodes.

        The checkpoint is (re)loaded only when the dropdown selection differs
        from the last successfully loaded one; selecting "flux" keeps the
        previously loaded SD checkpoint. Serialized with ``_prep_lock`` because
        it runs both from the startup preload thread and from generation threads.

        Returns:
            tuple: (checkpoint, CLIPTextEncode, EmptyLatentImage, KSampler,
            VAEDecode, LatentUpscale, UpscaleModelLoader, UltimateSDUpscale).

        Raises:
            RuntimeError: If no SD checkpoint has been loaded yet.
        """
        with self._prep_lock:
            selected = self.dropdown.get()
            if selected != self.ckpt and selected != "flux":
                with torch.inference_mode():
                    self.checkpointloadersimple = Loader.CheckpointLoaderSimple()
                    self.checkpointloadersimple_241 = (
                        self.checkpointloadersimple.load_checkpoint(
                            ckpt_name="./_internal/checkpoints/" + selected
                        )
                    )
                    self.cliptextencode = Clip.CLIPTextEncode()
                    self.emptylatentimage = Latent.EmptyLatentImage()
                    self.ksampler_instance = sampling.KSampler()
                    self.vaedecode = VariationalAE.VAEDecode()
                    self.latent_upscale = upscale.LatentUpscale()
                    self.upscalemodelloader = USDU_upscaler.UpscaleModelLoader()
                    self.ultimatesdupscale = UltimateSDUpscale.UltimateSDUpscale()
                # Record the name only after a successful load so a failed
                # load is retried on the next call.
                self.ckpt = selected

            if not hasattr(self, "checkpointloadersimple_241"):
                raise RuntimeError("No Stable Diffusion checkpoint has been loaded yet")

            return (
                self.checkpointloadersimple_241,
                self.cliptextencode,
                self.emptylatentimage,
                self.ksampler_instance,
                self.vaedecode,
                self.latent_upscale,
                self.upscalemodelloader,
                self.ultimatesdupscale,
            )

    def _generate_image(self) -> None:
        """Run the Stable Diffusion txt2img pipeline on a worker thread.

        Optional stages are driven by the sidebar checkboxes: LoRA, StableFast,
        hires fix and a two-pass ADetailer (body, then face). ``interrupt_flag``
        is polled between stages; UI state is restored in ``finally``.
        """
        self.is_generating = True
        self.generate_button.configure(state="disabled")
        current_thread = threading.current_thread()
        self.generation_threads.append(current_thread)
        self.interrupt_flag = False
        self.sampler = self._selected_sampler()

        try:
            self.display_most_recent_image_flag = False
            self.progress.set(0)
            if self.interrupt_flag:
                return

            prompt = self._enhance_prompt(self.prompt_entry.get("1.0", tk.END))
            neg = self.neg.get("1.0", tk.END)
            w = int(self.width_slider.get())
            h = int(self.height_slider.get())
            cfg = int(self.cfg_slider.get())

            with torch.inference_mode():
                components = self._prep()
                if self.interrupt_flag:
                    return
                (
                    checkpoint,
                    cliptextencode,
                    emptylatentimage,
                    ksampler_instance,
                    vaedecode,
                    latentupscale,
                    _upscalemodelloader,
                    _ultimatesdupscale,
                ) = components

                lora_name = self.lora_selection.get().replace("./_internal/loras/", "")
                try:
                    lora = LoRas.LoraLoader().load_lora(
                        lora_name=lora_name,
                        strength_model=0.7,
                        strength_clip=0.7,
                        model=checkpoint[0],
                        clip=checkpoint[1],
                    )
                    print("loading", lora_name)
                except Exception as e:
                    # Best effort: continue with the bare checkpoint.
                    print(f"LoRA not loaded ({e}); using the base checkpoint")
                    lora = checkpoint

                clip_skip = Clip.CLIPSetLastLayer().set_last_layer(
                    stop_at_clip_layer=-2, clip=lora[1]
                )
                self.progress.set(0.2)

                model = self._patch_model(lora)
                hidiffoptimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()

                positive = cliptextencode.encode(text=prompt, clip=clip_skip[0])
                negative = cliptextencode.encode(text=neg, clip=clip_skip[0])
                latent = emptylatentimage.generate(
                    width=w, height=h, batch_size=int(self.batch_slider.get())
                )
                samples = ksampler_instance.sample(
                    seed=self._random_seed(),
                    steps=20,
                    cfg=cfg,
                    sampler_name=self.sampler,
                    scheduler="karras",
                    denoise=1,
                    model=hidiffoptimizer.go(model_type="auto", model=model[0])[0],
                    positive=positive[0],
                    negative=negative[0],
                    latent_image=latent[0],
                )
                self.progress.set(0.4)

                if self.hires_fix_var.get() is True:
                    upscaled_latent = latentupscale.upscale(
                        width=w * 2,
                        height=h * 2,
                        samples=samples[0],
                    )
                    refined = ksampler_instance.sample(
                        seed=self._random_seed(),
                        steps=10,
                        cfg=8,
                        sampler_name="euler_ancestral_cfgpp",
                        scheduler="normal",
                        denoise=0.45,
                        model=hidiffoptimizer.go(model_type="auto", model=model[0])[0],
                        positive=positive[0],
                        negative=negative[0],
                        latent_image=upscaled_latent[0],
                    )
                    decoded = vaedecode.decode(samples=refined[0], vae=checkpoint[2])
                    self.update_from_decode(decoded[0], "LD-HF")
                else:
                    decoded = vaedecode.decode(samples=samples[0], vae=checkpoint[2])
                    self.update_from_decode(decoded[0], "LD")

                if self.interrupt_flag:
                    return
                self.progress.set(0.6)

                if self.adetailer_var.get() is True:
                    self._run_adetailer(
                        decoded[0], model[0], checkpoint, lora[1], negative[0], cliptextencode
                    )
                self.progress.set(0.8)
        except Exception as e:
            print(f"Generation error: {e}")
            self._set_title(f"LightDiffusion - Error: {str(e)}")
        finally:
            # Reset state when done
            self.is_generating = False
            self.generate_button.configure(state="normal")
            if current_thread in self.generation_threads:
                self.generation_threads.remove(current_thread)
            self.progress.set(0)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def _run_adetailer(self, image, model, checkpoint, lora_clip, negative, cliptextencode) -> None:
        """Run the body pass (person segmentation), then the face pass.

        Each pass's output is saved/displayed through :meth:`update_from_decode`.
        Stops early when SAM produces no mask.
        """
        detail_prompt = cliptextencode.encode(
            text="royal, detailed, magnificient, beautiful, seducing",
            clip=lora_clip,
        )
        sam_model = SAM.SAMLoader().load_model(
            model_name="sam_vit_b_01ec64.pth", device_mode="AUTO"
        )

        person_detector = bbox.UltralyticsDetectorProvider().doit(
            model_name="person_yolov8m-seg.pt"
        )
        body = self._adetailer_pass(
            image, person_detector[0], sam_model[0], model, checkpoint,
            detail_prompt[0], negative,
        )
        if body is None:
            return
        self.update_from_decode(body[0], "LD-body")

        face_detector = bbox.UltralyticsDetectorProvider().doit(
            model_name="face_yolov9c.pt"
        )
        face = self._adetailer_pass(
            body[0], face_detector[0], sam_model[0], model, checkpoint,
            detail_prompt[0], negative,
        )
        if face is None:
            return
        self.update_from_decode(face[0], "LD-head")

    def _adetailer_pass(self, image, bbox_detector, sam_model, model, checkpoint, positive, negative):
        """Run one detect → SAM-mask → inpaint ADetailer pass over ``image``.

        Returns:
            The DetailerForEachTest output tuple, or None when SAM found no mask.
        """
        segs = bbox.BboxDetectorForEach().doit(
            threshold=0.5,
            dilation=10,
            crop_factor=2,
            drop_size=10,
            labels="all",
            bbox_detector=bbox_detector,
            image=image,
        )
        sam_mask = SAM.SAMDetectorCombined().doit(
            detection_hint="center-1",
            dilation=0,
            threshold=0.93,
            bbox_expansion=0,
            mask_hint_threshold=0.7,
            mask_hint_use_negative="False",
            sam_model=sam_model,
            segs=segs,
            image=image,
        )
        if sam_mask[0] is None:
            return None
        masked_segs = SEGS.SegsBitwiseAndMask().doit(segs=segs, mask=sam_mask[0])
        return ADetailer.DetailerForEachTest().doit(
            guide_size=512,
            guide_size_for=False,
            max_size=768,
            seed=self._random_seed(),
            steps=20,
            cfg=6.5,
            sampler_name=self.sampler,
            scheduler="karras",
            denoise=0.5,
            feather=5,
            noise_mask=True,
            force_inpaint=True,
            wildcard="",
            cycle=1,
            inpaint_model=False,
            noise_mask_feather=20,
            image=image,
            segs=masked_segs[0],
            model=model,
            clip=checkpoint[1],
            vae=checkpoint[2],
            positive=positive,
            negative=negative,
        )

    def _generate_image_flux(self) -> None:
        """Generate an image using the quantized Flux model on a worker thread."""
        self.is_generating = True
        self.generate_button.configure(state="disabled")
        current_thread = threading.current_thread()
        self.generation_threads.append(current_thread)
        self.display_most_recent_image_flag = False
        w = int(self.width_slider.get())
        h = int(self.height_slider.get())
        prompt = self.prompt_entry.get("1.0", tk.END)

        try:
            prompt = self._enhance_prompt(prompt)
            self.interrupt_flag = False
            Downloader.CheckAndDownloadFlux()
            with torch.inference_mode():
                dualcliploadergguf = Quantizer.DualCLIPLoaderGGUF()
                emptylatentimage = Latent.EmptyLatentImage()
                vaeloader = VariationalAE.VAELoader()
                unetloadergguf = Quantizer.UnetLoaderGGUF()
                cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
                conditioningzeroout = Quantizer.ConditioningZeroOut()
                ksampler = sampling.KSampler()
                vaedecode = VariationalAE.VAEDecode()

                unet = unetloadergguf.load_unet(unet_name="flux1-dev-Q8_0.gguf")
                vae = vaeloader.load_vae(vae_name="ae.safetensors")
                clip = dualcliploadergguf.load_clip(
                    clip_name1="clip_l.safetensors",
                    clip_name2="t5-v1_1-xxl-encoder-Q8_0.gguf",
                    type="flux",
                )
                latent = emptylatentimage.generate(
                    width=w, height=h, batch_size=int(self.batch_slider.get())
                )
                positive = cliptextencodeflux.encode(
                    clip_l=prompt,
                    t5xxl=prompt,
                    guidance=3.0,
                    clip=clip[0],
                    flux_enabled=True,
                )
                negative = conditioningzeroout.zero_out(conditioning=positive[0])
                unet = fbcache_nodes.ApplyFBCacheOnModel().patch(
                    unet, "diffusion_model", 0.120
                )
                samples = ksampler.sample(
                    seed=self._random_seed(),
                    steps=20,
                    cfg=1,
                    sampler_name="euler_cfgpp",
                    scheduler="beta",
                    denoise=1,
                    model=unet[0],
                    positive=positive[0],
                    negative=negative[0],
                    latent_image=latent[0],
                    flux=True,
                )
                decoded = vaedecode.decode(samples=samples[0], vae=vae[0], flux=True)
                self.update_from_decode(decoded[0], "LD-Flux")
        except Exception as e:
            print(f"Generation error: {e}")
            self._set_title(f"LightDiffusion - Error: {str(e)}")
        finally:
            # Reset state when done
            self.is_generating = False
            self.generate_button.configure(state="normal")
            if current_thread in self.generation_threads:
                self.generation_threads.remove(current_thread)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def on_model_selected(self, *args):
        """Disable controls the Flux pipeline ignores; re-enable them otherwise."""
        state = "disabled" if self.dropdown.get() == "flux" else "normal"
        # Public configure() updates both behavior and appearance of CTk widgets.
        for widget in (
            self.adetailer_checkbox,
            self.hires_fix_checkbox,
            self.stable_fast_checkbox,
            self.lora_selection,
            self.cfg_slider,
        ):
            widget.configure(state=state)

    @staticmethod
    def _split_batch(batch) -> list:
        """Return the per-image items of ``batch``.

        Lists/tuples are returned item by item, a 4-D tensor is split along its
        first (batch) dimension, anything else is treated as a single image.
        """
        if isinstance(batch, (list, tuple)):
            return list(batch)
        if getattr(batch, "ndim", 0) == 4:
            return list(batch)
        return [batch]

    @staticmethod
    def _tensor_to_pil(img_tensor) -> Image.Image:
        """Convert one image tensor with values in [0, 1] to a PIL image."""
        img_array = 255.0 * img_tensor.cpu().numpy()
        if img_array.ndim == 4:
            img_array = np.squeeze(img_array)
            img_array = img_array.reshape(-1, img_array.shape[-2], img_array.shape[-1])
        return Image.fromarray(np.clip(img_array, 0, 255).astype(np.uint8))

    def _handle_decoded_image(self, decoded, prefix: str) -> None:
        """Apply HDR effects to a decoded batch, save it and display it.

        Args:
            decoded: Decoded image tensor, or a tuple whose first item is one.
            prefix: Prefix for saved files.
        """
        try:
            saveimage = ImageSaver.SaveImage()
            hdr = ahdr.HDREffects()

            tensor_image = decoded[0] if isinstance(decoded, tuple) else decoded
            processed = hdr.apply_hdr2(tensor_image)
            batch = processed[0] if isinstance(processed, tuple) else processed

            saveimage.save_images(filename_prefix=prefix, images=batch)

            images = [self._tensor_to_pil(t) for t in self._split_batch(batch)]

            # Update display if not interrupted
            if not self.interrupt_flag:
                self.progress.set(1.0)
                if images:
                    self.img = images[0]
                    self.update_image(images)
                    self.display_most_recent_image_flag = True
        except Exception as e:
            print(f"Image processing error: {e}")
            self._set_title(f"LightDiffusion - Error: {str(e)}")

    def update_from_decode(self, decoded, prefix: str) -> None:
        """Process, save and display a decoded image batch.

        Args:
            decoded: The decoded image tensor, or a tuple containing it.
            prefix (str): Prefix for saved files.
        """
        try:
            self._handle_decoded_image(decoded, prefix)
        except Exception as e:
            print(f"Decode error: {e}")
            self._set_title(f"LightDiffusion - Error: {str(e)}")
        finally:
            # Release cached GPU memory between pipeline stages.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

    def update_labels(self) -> None:
        """Update the labels for the sliders."""
        self.width_label.configure(text=f"{int(self.width_slider.get())}")
        self.height_label.configure(text=f"{int(self.height_slider.get())}")
        self.cfg_label.configure(text=f"{int(self.cfg_slider.get())}")
        self.batch_label.configure(text=f"{int(self.batch_slider.get())}")

    def create_image_grid(self, images: list[Image.Image]) -> Image.Image:
        """Arrange images in a near-square grid (row-major, ceil(sqrt(n)) columns).

        Args:
            images (list[Image.Image]): Images to arrange; cells are sized to
                the largest width/height in the list.

        Returns:
            Image.Image: The single input image, or a new RGB grid image.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError("create_image_grid() needs at least one image")
        n = len(images)
        if n == 1:
            return images[0]

        cols = int(np.ceil(np.sqrt(n)))
        rows = int(np.ceil(n / cols))
        w_max = max(img.width for img in images)
        h_max = max(img.height for img in images)

        grid = Image.new("RGB", (w_max * cols, h_max * rows))
        for idx, img in enumerate(images):
            row, col = divmod(idx, cols)
            grid.paste(img, (col * w_max, row * h_max))
        return grid

    def update_image(self, images: Union[Image.Image, list[Image.Image]]) -> None:
        """Scale image(s) to ~4/7 of the window and show them unless a preview is pinned.

        Args:
            images: Single image or list of images to display.
        """
        if isinstance(images, Image.Image):
            images = [images]

        grid_img = self.create_image_grid(images)
        aspect_ratio = grid_img.width / grid_img.height

        label_width = int(4 * self.winfo_width() / 7)
        label_height = int(4 * self.winfo_height() / 7)

        # Before the window is mapped winfo_* report 1 px; resizing to 0 px
        # would raise, so keep the full-size grid until a real size is known.
        if label_width >= 1 and label_height >= 1:
            if label_width / aspect_ratio <= label_height:
                new_width = label_width
                new_height = int(label_width / aspect_ratio)
            else:
                new_height = label_height
                new_width = int(label_height * aspect_ratio)
            # Extreme aspect ratios can still round one side down to 0.
            new_size = (max(1, new_width), max(1, new_height))
            try:
                grid_img = grid_img.resize(new_size, Image.LANCZOS)
            except RecursionError:
                pass

        self.img = grid_img
        if self.display_most_recent_image_flag is False:
            self._update_image_label(grid_img)

    def _update_image_label(self, img: Image.Image) -> None:
        """Update the image label with the provided image.

        Args:
            img (Image.Image): The image to display.
        """
        tk_image = ImageTk.PhotoImage(img)
        self.image_label.config(image=tk_image)
        # Keep a reference so the PhotoImage is not garbage collected.
        self.image_label.image = tk_image

    def display_most_recent_image(self) -> None:
        """Display the most recent image batch from the output directories.

        Files whose modification time is within one second of the newest file
        are treated as one batch and shown together as a grid.
        """
        image_files = []
        for folder in ("Classic", "Adetailer", "Flux", "HiresFix", "Img2Img"):
            image_files += glob.glob(f"./_internal/output/{folder}/*")

        # Read each mtime once; skip files that disappear mid-scan.
        stamped = []
        for path in image_files:
            try:
                stamped.append((os.path.getmtime(path), path))
            except OSError:
                continue
        if not stamped:
            return

        # Newest first (stable for equal timestamps).
        stamped.sort(key=lambda item: item[0], reverse=True)
        latest_time = stamped[0][0]

        batch_images = []
        for mtime, path in stamped:
            if latest_time - mtime >= 1.0:
                break  # sorted newest-first, so everything after is older
            try:
                # Copy into memory so the file handle is closed immediately.
                with Image.open(path) as opened:
                    batch_images.append(opened.copy())
            except Exception:
                continue

        if batch_images:
            self.update_image(batch_images)

    def _start_resize_worker(self):
        """Start the resize worker thread"""
        self._resize_thread = threading.Thread(target=self._resize_worker, daemon=True)
        self._resize_thread.start()

    def _resize_worker(self):
        """Drain queued resize events until ``_resize_running`` is cleared."""
        while self._resize_running:
            if self._resize_queue.qsize() > 0:
                try:
                    event = self._resize_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                try:
                    self._do_resize(event)
                except Exception as e:
                    # Keep the worker alive: one failed resize must not stop later ones.
                    print(f"Resize error: {e}")
                finally:
                    self._resize_queue.task_done()
            else:
                self._resize_event.wait(timeout=0.1)
                self._resize_event.clear()

    def _queue_resize(self, event):
        """Queue a resize event"""
current_time = time.time() # Debounce resize events if current_time - self._last_resize_time > self._resize_delay: self._last_resize_time = current_time self._resize_queue.put(event) self._resize_event.set() def _do_resize(self, event): """Handle resize operation in worker thread""" width = self.winfo_width() height = self.winfo_height() # Update UI components in main thread self.after(0, lambda: self._update_components(width, height)) # Update image if exists if hasattr(self, "img"): self._update_image_threaded(self.img) def _update_components(self, width, height): """Update UI components sizes""" # Update component sizes based on window dimensions width = self.winfo_width() height = self.winfo_height() # Scale text boxes prompt_height = int(height * 0.25) neg_height = int(height * 0.15) self.prompt_entry.configure(height=prompt_height) self.neg.configure(height=neg_height) def _update_image_threaded(self, img): """Thread-safe image update with caching""" if img is None: return with self._resize_lock: # Calculate dimensions aspect_ratio = img.width / img.height label_width = int(4 * self.winfo_width() / 7) label_height = int(4 * self.winfo_height() / 7) if label_width / aspect_ratio <= label_height: new_width = label_width new_height = int(label_width / aspect_ratio) else: new_height = label_height new_width = int(label_height * aspect_ratio) # Check cache cache_key = (new_width, new_height) if cache_key in self._image_cache: resized_img = self._image_cache[cache_key] else: try: resized_img = img.resize((new_width, new_height), Image.LANCZOS) self._image_cache[cache_key] = resized_img # Limit cache size if len(self._image_cache) > 5: self._image_cache.pop(next(iter(self._image_cache))) except: return if not self.display_most_recent_image_flag: # Update image in main thread self.after(0, lambda: self._update_image_label_safe(resized_img)) def _update_image_label_safe(self, img): """Thread-safe image label update""" if not self.display_most_recent_image_flag: 
self._current_image = ImageTk.PhotoImage(img) self.image_label.configure(image=self._current_image) def _cleanup(self): """Clean up threads before closing""" self._resize_running = False self._resize_event.set() if self._resize_thread: self._resize_thread.join(timeout=1.0) self.destroy() def interrupt_generation(self) -> None: """Interrupt ongoing image generation process.""" if not self.is_generating: return # Set interrupt flag first self.interrupt_flag = True # Clear CUDA cache and release memory if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() # Stop and cleanup threads for thread in self.generation_threads[:]: if thread and thread.is_alive(): thread.join(timeout=1.0) if thread in self.generation_threads: self.generation_threads.remove(thread) # Reset UI state self.progress.set(0) self.title("LightDiffusion") self.generate_button.configure(state="normal") self.display_most_recent_image_flag = True # Clear any pending resize tasks with self._resize_lock: self._resize_queue.queue.clear() # Reset model state if needed if hasattr(self, "checkpointloadersimple_241"): del self.checkpointloadersimple_241 self.ckpt = None # Always reset flags self.generation_threads.clear() # Reset generation state self.is_generating = False self.generate_button.configure(state="normal") if __name__ == "__main__": from modules.user.app_instance import app app.mainloop()