# Testing2 / appCODE.py
import os
import gradio as gr
import torch
import spaces
import random
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import save_file
# Device selection
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
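# Note: bfloat16 is also commonly used with SD 3.5 Large on recent GPUs; float16 works but can be less numerically stable.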
# Load Hugging Face token securely
token = os.getenv("HF_TOKEN")
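# Assumption: HF_TOKEN is set as a Space secret; SD 3.5 Large is a gated model, so the token's account must have accepted the license.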
# Model ID for SD 3.5 Large
model_repo_id = "stabilityai/stable-diffusion-3.5-large"
# Convert .pt to .safetensors if needed
lora_pt_path = "lora_trained_model.pt"
lora_safetensors_path = "lora_trained_model.safetensors"
if os.path.exists(lora_pt_path) and not os.path.exists(lora_safetensors_path):
    print("πŸ”„ Converting LoRA .pt to .safetensors...")
    # Assumes the .pt checkpoint is a flat state dict of tensors, as required by save_file
    lora_weights = torch.load(lora_pt_path, map_location="cpu")
    save_file(lora_weights, lora_safetensors_path)
    print(f"βœ… LoRA saved as {lora_safetensors_path}")
# Load the Stable Diffusion 3.5 pipeline (the model is gated, so pass the HF token)
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch_dtype,
    use_safetensors=True,  # Prefer safetensors weights when available
    token=token,
).to(device)
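# Optional (assumes accelerate is installed): on GPUs with limited VRAM,
# pipeline.enable_model_cpu_offload() can replace the .to(device) call above.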
# Load and fuse the LoRA weights if they are present
if os.path.exists(lora_safetensors_path):
    try:
        pipeline.load_lora_weights(".", weight_name=lora_safetensors_path)
        pipeline.fuse_lora()  # Merge the LoRA weights into the base model
        print("βœ… LoRA weights loaded and fused successfully!")
    except Exception as e:
        print(f"❌ Error loading LoRA: {e}")
else:
    print("⚠️ LoRA file not found! Running base model.")
# Sanity check: list LoRA parameters (SD3 LoRAs usually target the transformer rather than the text encoder)
for module in (pipeline.transformer, pipeline.text_encoder):
    for name, param in module.named_parameters():
        if "lora" in name.lower():
            print(f"βœ… LoRA applied to: {name}, requires_grad={param.requires_grad}")
# Request a GPU for this call when running in Hugging Face Spaces (ZeroGPU)
@spaces.GPU(duration=65)
def generate_image(prompt: str, seed: int | None = None):
    """Generates an image using Stable Diffusion 3.5 with the fused LoRA weights."""
    if seed is None:
        seed = random.randint(0, 100000)
    # Gradio's Number component may return a float, so cast before seeding
    generator = torch.Generator("cpu").manual_seed(int(seed))
    # Generate the image using the pipeline
    image = pipeline(prompt, generator=generator).images[0]
    return image
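# Hypothetical direct call, bypassing the Gradio UI:
# generate_image("A lighthouse on a cliff at dawn", seed=42)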
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# πŸ–ΌοΈ LoRA Fine-Tuned SD 3.5 Image Generator")
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt",
            value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.",
        )
        seed_input = gr.Number(label="Seed (optional)", value=None, precision=0)  # precision=0 yields an integer seed
    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)
# Launch the Gradio app
demo.launch()
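# When running locally, demo.launch(share=True) can be used instead to expose a temporary public URL.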