File size: 3,120 Bytes
37c7828 7195e76 a7b5445 839eb81 bd9bf48 839eb81 20271b2 7195e76 20271b2 7195e76 20271b2 7195e76 20271b2 bd9bf48 20271b2 bd9bf48 20271b2 7195e76 20271b2 fc20c03 20271b2 94d851c 20271b2 82198c8 20271b2 82198c8 20271b2 feabc9a 20271b2 9024a26 20271b2 ee8ab11 20271b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
import gradio as gr
import torch
import spaces
import random
from diffusers import StableDiffusion3Pipeline
from diffusers.loaders import SD3LoraLoaderMixin
from safetensors.torch import load_file, save_file
# Ensure GPU allocation for image generation (moved here)
def main():
    """Build and launch the Gradio demo for LoRA fine-tuned Stable Diffusion 3.5.

    Steps, in order:
      1. Pick device and dtype (CUDA + float16 when available, else CPU + float32).
      2. Convert ``lora_trained_model.pt`` to ``lora_trained_model.safetensors``
         when only the ``.pt`` file exists.
      3. Load the ``stabilityai/stable-diffusion-3.5-large`` pipeline, forwarding
         the ``HF_TOKEN`` environment variable as the access token.
      4. Load and fuse the LoRA weights if the ``.safetensors`` file exists;
         a loading error is printed and startup continues.
      5. Build the Gradio UI and call ``demo.launch()``.
    """
    # Device selection
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # Hugging Face access token; None when the variable is unset.
    token = os.getenv("HF_TOKEN")

    # Model ID for SD 3.5 Large
    model_repo_id = "stabilityai/stable-diffusion-3.5-large"

    # Convert .pt to .safetensors if needed
    lora_pt_path = "lora_trained_model.pt"
    lora_safetensors_path = "lora_trained_model.safetensors"
    if os.path.exists(lora_pt_path) and not os.path.exists(lora_safetensors_path):
        print("🔄 Converting LoRA .pt to .safetensors...")
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered .pt file cannot run arbitrary code while loading.
        lora_weights = torch.load(
            lora_pt_path, map_location="cpu", weights_only=True
        )
        save_file(lora_weights, lora_safetensors_path)
        print(f"✅ LoRA saved as {lora_safetensors_path}")

    # Load the Stable Diffusion 3.5 pipeline. The token was previously read but
    # never used; forward it so gated or private repos can be downloaded.
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        model_repo_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        token=token,
    ).to(device)

    # Load and fuse LoRA weights (best effort: errors are reported, not raised).
    if os.path.exists(lora_safetensors_path):
        try:
            pipeline.load_lora_weights(lora_safetensors_path)
            pipeline.fuse_lora()
        except Exception as e:
            print(f"❌ Error loading LoRA: {e}")
        else:
            print("✅ LoRA weights loaded and fused successfully!")
    else:
        print("⚠️ LoRA file not found! Running base model.")

    # Report whether any LoRA parameters are present. Check every component the
    # SD3 LoRA loader can target, not just the first text encoder: a LoRA that
    # only touches the transformer would otherwise be reported as missing.
    lora_targets = (
        getattr(pipeline, attr, None)
        for attr in ("transformer", "text_encoder", "text_encoder_2")
    )
    applied_lora = any(
        "lora" in name.lower()
        for component in lora_targets
        if component is not None
        for name, _ in component.named_parameters()
    )
    print(f"✅ LoRA Applied: {applied_lora}")

    # Image generation function with GPU decorator
    @spaces.GPU(duration=65)
    def generate_image(prompt: str, seed: "int | float | None" = None):
        """Generate one image for ``prompt`` with the loaded pipeline.

        Args:
            prompt: Text prompt passed to the pipeline.
            seed: Seed for the random generator. ``None`` (an empty seed box)
                picks a random seed in ``[0, 100000]``. ``0`` is a valid seed.

        Returns:
            The first image of the pipeline output.
        """
        # `seed or ...` would discard an explicit seed of 0, so test for None.
        if seed is None:
            seed = random.randint(0, 100000)
        # gr.Number may deliver a float; manual_seed only accepts an int.
        generator = torch.Generator(device).manual_seed(int(seed))
        return pipeline(prompt, generator=generator).images[0]

    # Gradio Interface
    with gr.Blocks() as demo:
        gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter Prompt",
                value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.",
            )
            seed_input = gr.Number(label="Seed (optional)", value=None)
        generate_btn = gr.Button("Generate Image")
        output_image = gr.Image(label="Generated Image")
        generate_btn.click(
            generate_image,
            inputs=[prompt_input, seed_input],
            outputs=output_image,
        )

    # Launch Gradio app
    demo.launch()
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|