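"""Gradio demo: Stable Diffusion 3.5 Large with optional LoRA fine-tuning.

Converts a LoRA checkpoint from .pt to .safetensors if needed, loads and
fuses it into the pipeline, and serves a simple text-to-image UI.
"""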
import os
import gradio as gr
import torch
import spaces
import random
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import save_file

# All setup and the Gradio UI live in main(); the GPU itself is requested
# per request via the @spaces.GPU decorator on generate_image below.
def main():
    # Device selection
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # Load Hugging Face token securely
    token = os.getenv("HF_TOKEN")

    # Model ID for SD 3.5 Large
    model_repo_id = "stabilityai/stable-diffusion-3.5-large"

    # Convert .pt to .safetensors if needed
    lora_pt_path = "lora_trained_model.pt"
    lora_safetensors_path = "lora_trained_model.safetensors"

    if os.path.exists(lora_pt_path) and not os.path.exists(lora_safetensors_path):
        print("πŸ”„ Converting LoRA .pt to .safetensors...")
        lora_weights = torch.load(lora_pt_path, map_location="cpu")
        save_file(lora_weights, lora_safetensors_path)
        print(f"βœ… LoRA saved as {lora_safetensors_path}")

    # Load the Stable Diffusion 3.5 pipeline (gated repo, so pass the HF token)
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        model_repo_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        token=token,
    ).to(device)

    # Load and fuse LoRA weights via the pipeline's built-in loader
    lora_applied = False
    if os.path.exists(lora_safetensors_path):
        try:
            pipeline.load_lora_weights(lora_safetensors_path)
            pipeline.fuse_lora()
            lora_applied = True
            print("βœ… LoRA weights loaded and fused successfully!")
        except Exception as e:
            print(f"❌ Error loading LoRA: {e}")
    else:
        print("⚠️ LoRA file not found! Running base model.")

    # After fuse_lora() the adapter weights are merged into the base layers,
    # so parameter names no longer contain "lora"; track success with a flag instead.
    print(f"βœ… LoRA Applied: {lora_applied}")

    # Image generation function; @spaces.GPU requests a GPU for each call
    @spaces.GPU(duration=65)
    def generate_image(prompt: str, seed=None):
        """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning."""
        # gr.Number passes a float (or None), so normalize to an integer seed
        seed = int(seed) if seed is not None else random.randint(0, 100000)
        generator = torch.Generator(device).manual_seed(seed)
        return pipeline(prompt, generator=generator).images[0]

    # Gradio Interface
    with gr.Blocks() as demo:
        gr.Markdown("# πŸ–ΌοΈ LoRA Fine-Tuned SD 3.5 Image Generator")

        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter Prompt",
                value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed."
            )
            seed_input = gr.Number(label="Seed (optional)", value=None)

        generate_btn = gr.Button("Generate Image")
        output_image = gr.Image(label="Generated Image")

        generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)

    # Launch Gradio app
    demo.launch()

if __name__ == "__main__":
    main()