import gradio as gr
import torch
import os
import random
import numpy as np
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from spaces import GPU  # Remove if not running in a Hugging Face Space

# 1. Model and LoRA loading (before building the Gradio UI)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
token = os.getenv("HF_TOKEN")
model_repo_id = "stabilityai/stable-diffusion-3.5-large"

try:
    # `token` may be None; diffusers falls back to anonymous access in that case.
    pipe = DiffusionPipeline.from_pretrained(
        model_repo_id, torch_dtype=torch_dtype, token=token
    )
    pipe = pipe.to(device)

    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)

    if os.path.exists(lora_path):
        # Load the raw tensors and merge any matching keys into the text encoder;
        # strict=False ignores keys that do not exist in the text encoder.
        lora_weights = load_file(lora_path)
        pipe.text_encoder.load_state_dict(lora_weights, strict=False)
        print(f"LoRA loaded successfully from: {lora_path}")
    else:
        print(f"Error: LoRA file not found at: {lora_path}")
        raise SystemExit(1)  # Stop if the LoRA is not found

    print("Stable Diffusion model and LoRA loaded successfully!")
except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    raise SystemExit(1)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@GPU(duration=65)  # Only if running in a Hugging Face Space
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    try:
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during image generation: {e}")  # Log for debugging
        raise gr.Error(f"Error during image generation: {e}")  # Surface the error in the UI


# ... (rest of your Gradio code - examples, CSS, etc. - same as before)


# 4. Image generation function (now decorated)
@GPU(duration=65)  # Only if running in a Hugging Face Space
def generate_image(prompt):
    if pipe is None:
        print("Error: pipe is None (model not loaded)")  # Log this specifically
        raise gr.Error("Model not loaded!")
    try:
        print("Starting image generation...")  # Log before generation begins
        image = pipe(prompt).images[0]
        print("Image generated successfully!")
        return image
    except RuntimeError as e:  # Must precede the generic Exception clause, or it is unreachable
        error_message = f"Runtime error during image generation: {type(e).__name__}: {e}"
        print(f"Full Runtime Error Details:\n{error_message}")
        raise gr.Error(error_message)
    except Exception as e:
        error_message = f"Error during image generation: {type(e).__name__}: {e}"  # Include the exception type
        print(f"Full Error Details:\n{error_message}")
        raise gr.Error(error_message)


# 5. Gradio interface
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")

    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

demo.launch()
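
# Note on LoRA loading: the load_state_dict(strict=False) call above silently
# drops any keys that do not match the text encoder's own parameter names, so a
# LoRA saved in the usual diffusers/PEFT format may not actually be applied.
# A hedged alternative sketch, assuming the file was saved in a
# diffusers-compatible LoRA format:
#
#     pipe.load_lora_weights(".", weight_name="lora_trained_model.safetensors")
#
# load_lora_weights() is diffusers' standard entry point for attaching LoRA
# adapters to a pipeline; whether it accepts this particular file depends on
# how the LoRA was trained and exported.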