import gradio as gr
import torch
import os
import random
import numpy as np
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from spaces import GPU # Remove if not in HF Space
# 1. Model and LoRA Loading (Before Gradio)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
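# float16 halves memory use on CUDA; CPU inference needs full float32 precision.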
token = os.getenv("HF_TOKEN")
model_repo_id = "stabilityai/stable-diffusion-3.5-large"
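# Note: stable-diffusion-3.5-large is a gated repo, so HF_TOKEN should belong
# to an account that has accepted the model license.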
try:
    # `token` may be None for public repos; `use_auth_token` is deprecated in
    # recent diffusers releases in favor of `token`.
    pipe = DiffusionPipeline.from_pretrained(
        model_repo_id, torch_dtype=torch_dtype, token=token
    )
    pipe = pipe.to(device)

    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)
    if os.path.exists(lora_path):
        # Load the raw tensors and merge them into the text encoder.
        # strict=False skips any keys that do not match the encoder's state
        # dict, so a mismatched file fails silently; if the file is a
        # diffusers-format LoRA adapter, pipe.load_lora_weights(lora_path)
        # is the more robust route.
        lora_weights = load_file(lora_path)
        pipe.text_encoder.load_state_dict(lora_weights, strict=False)
        print(f"LoRA loaded successfully from: {lora_path}")
    else:
        print(f"Error: LoRA file not found at: {lora_path}")
        raise SystemExit(1)  # Stop if the LoRA is not found
    print("Stable Diffusion model and LoRA loaded successfully!")
except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    raise SystemExit(1)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
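# MAX_SEED keeps seeds within int32 range; MAX_IMAGE_SIZE presumably caps the
# width/height sliders in the elided UI code referenced below.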
@GPU(duration=65)  # Only if in HF Space
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # A CPU generator is fine here; diffusers creates the initial noise on the
    # generator's device and moves it to the pipeline's device.
    generator = torch.Generator().manual_seed(seed)
    try:
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during image generation: {e}")  # Log for debugging
        # Raise a Gradio error instead of returning a string, which a
        # gr.Image output cannot render.
        raise gr.Error(f"Image generation failed: {e}")
# ... (rest of your Gradio code - examples, CSS, etc. - same as before)
# 4. Image generation function (now decorated)
@GPU(duration=65)  # Only if in HF Space
def generate_image(prompt):
    # `pipe` is the module-level pipeline loaded above; no `global` statement
    # is needed for a read-only reference.
    if pipe is None:
        print("Error: Pipeline is None (model not loaded)")
        raise gr.Error("Model not loaded!")
    try:
        print("Starting image generation...")
        image = pipe(prompt).images[0]
        print("Image generated successfully!")
        return image
    except RuntimeError as re:
        # Catch the narrower RuntimeError first; placed after a bare
        # `except Exception`, this branch would be unreachable.
        error_message = f"Runtime Error during image generation: {type(re).__name__}: {re}"
        print(f"Full Runtime Error Details:\n{error_message}")
        raise gr.Error(error_message)
    except Exception as e:
        error_message = f"Error during image generation: {type(e).__name__}: {e}"
        print(f"Full Error Details:\n{error_message}")
        raise gr.Error(error_message)
# 5. Gradio interface
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

demo.launch()
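# Note: in a Space under load, demo.queue().launch() (instead of the bare
# launch() above) serializes GPU requests and avoids concurrent out-of-memory
# errors.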