|
import os |
|
import gradio as gr |
|
import torch |
|
import spaces |
|
import random |
|
from diffusers import StableDiffusion3Pipeline |
|
from diffusers.loaders import SD3LoraLoaderMixin |
|
from safetensors.torch import load_file, save_file |
|
|
|
|
|
def _resolve_seed(seed: "float | None" = None, low: int = 0, high: int = 100000) -> int:
    """Return an integer seed suitable for ``torch.Generator.manual_seed``.

    Args:
        seed: Seed coming from the UI. ``gr.Number`` may deliver a float
            (e.g. ``42.0``) or ``None`` when the field is left empty.
        low: Inclusive lower bound for a randomly drawn seed.
        high: Inclusive upper bound for a randomly drawn seed.

    Returns:
        ``int(seed)`` when a seed was supplied, otherwise a random integer
        in ``[low, high]``.
    """
    # Compare against None explicitly: ``seed or ...`` would silently replace
    # a deliberately chosen seed of 0 with a random one.
    if seed is None:
        return random.randint(low, high)
    # manual_seed() needs an integer, not the float gr.Number may return.
    return int(seed)


def _has_lora_layers(modules) -> bool:
    """Return True if any module has a parameter whose name contains "lora".

    Args:
        modules: Iterable of ``torch.nn.Module`` objects; ``None`` entries
            (pipeline components that are absent) are skipped.
    """
    return any(
        "lora" in name.lower()
        for module in modules
        if module is not None
        for name, _ in module.named_parameters()
    )


def _convert_lora_to_safetensors(pt_path: str, safetensors_path: str) -> None:
    """Convert a LoRA state dict saved with ``torch.save`` into a .safetensors file.

    Raises:
        TypeError: if ``pt_path`` does not contain a dict of tensors.
    """
    print("🔄 Converting LoRA .pt to .safetensors...")
    # NOTE(security): torch.load unpickles the file. weights_only=True restricts
    # loading to tensors and plain containers instead of arbitrary objects.
    state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
    if not isinstance(state_dict, dict) or not all(
        isinstance(value, torch.Tensor) for value in state_dict.values()
    ):
        raise TypeError(
            f"{pt_path} must contain a dict of tensors to be converted to safetensors"
        )
    # safetensors refuses non-contiguous tensors (e.g. transposed views).
    save_file({key: value.contiguous() for key, value in state_dict.items()}, safetensors_path)
    print(f"✅ LoRA saved as {safetensors_path}")


def _load_lora(pipeline, safetensors_path: str) -> bool:
    """Load and fuse LoRA weights into ``pipeline`` on a best-effort basis.

    Returns:
        True if the weights were loaded and fused, False if the file is
        missing or loading failed (the base model is used in that case).
    """
    if not os.path.exists(safetensors_path):
        print("⚠️ LoRA file not found! Running base model.")
        return False
    try:
        pipeline.load_lora_weights(safetensors_path)
        pipeline.fuse_lora()
    except Exception as e:  # deliberate best-effort: fall back to the base model
        print(f"❌ Error loading LoRA: {e}")
        return False
    print("✅ LoRA weights loaded and fused successfully!")
    return True


def main():
    """Build the SD 3.5 Large pipeline (optionally with a fine-tuned LoRA) and serve it via Gradio.

    Reads ``HF_TOKEN`` from the environment for model access. If only
    ``lora_trained_model.pt`` exists, it is converted to
    ``lora_trained_model.safetensors`` first. Blocks in ``demo.launch()``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 on GPU only; CPU runs fall back to float32.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    token = os.getenv("HF_TOKEN")
    model_repo_id = "stabilityai/stable-diffusion-3.5-large"

    lora_pt_path = "lora_trained_model.pt"
    lora_safetensors_path = "lora_trained_model.safetensors"

    if os.path.exists(lora_pt_path) and not os.path.exists(lora_safetensors_path):
        _convert_lora_to_safetensors(lora_pt_path, lora_safetensors_path)

    pipeline = StableDiffusion3Pipeline.from_pretrained(
        model_repo_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        token=token,  # pass the HF_TOKEN read above explicitly
    ).to(device)

    lora_loaded = _load_lora(pipeline, lora_safetensors_path)

    # Inspect every component a LoRA may be injected into; LoRAs for SD3 are
    # commonly applied to the transformer, not only the text encoder.
    applied_lora = lora_loaded and _has_lora_layers(
        getattr(pipeline, component, None)
        for component in ("transformer", "text_encoder", "text_encoder_2", "text_encoder_3")
    )
    print(f"✅ LoRA Applied: {applied_lora}")

    @spaces.GPU(duration=65)
    def generate_image(prompt: str, seed: "float | None" = None):
        """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning.

        Args:
            prompt: Text prompt for the pipeline.
            seed: Optional seed; a random one is drawn when it is None.
        """
        generator = torch.Generator(device).manual_seed(_resolve_seed(seed))
        return pipeline(prompt, generator=generator).images[0]

    with gr.Blocks() as demo:
        gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")

        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter Prompt",
                value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.",
            )
            # precision=0 makes the component return an int (or None when empty).
            seed_input = gr.Number(label="Seed (optional)", value=None, precision=0)

        generate_btn = gr.Button("Generate Image")
        output_image = gr.Image(label="Generated Image")

        generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)

    demo.launch()


if __name__ == "__main__":
    main()
|
|