File size: 2,219 Bytes
37c7828
 
7195e76
 
 
 
e7f18a8
7195e76
 
 
 
 
 
 
 
 
 
82198c8
7195e76
 
 
 
 
 
82198c8
7195e76
 
265ed58
0364ee7
7195e76
 
 
fc20c03
82198c8
 
 
 
 
 
 
 
 
 
 
 
 
 
7195e76
 
 
feabc9a
7195e76
 
 
fc8f71e
7195e76
 
9024a26
7195e76
feabc9a
7195e76
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
import spaces
import random

# Device selection: use CUDA with half precision when a GPU is visible,
# otherwise fall back to the CPU with full float32 precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the SD 3.5 Large model card loads this pipeline in bfloat16;
# float16 is kept here to preserve existing behavior — confirm it yields valid
# (non-NaN) images on the target GPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Hugging Face access token read from the environment (e.g. a Space secret).
# It is forwarded to ``from_pretrained`` below; presumably required because the
# SD 3.5 Large repo is license-gated on the Hub. When HF_TOKEN is unset this is
# None and the library's default credential lookup applies.
token = os.getenv("HF_TOKEN")

# Model ID for SD 3.5 Large
model_repo_id = "stabilityai/stable-diffusion-3.5-large"

# Load the Stable Diffusion 3 pipeline once at import time and move it to `device`.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch_dtype,
    use_safetensors=True,  # Use safetensors format if supported
    token=token,
).to(device)

# Load the LoRA trained weights once at the start, if the file is present.
lora_path = "lora_trained_model.pt"  # Ensure this file is uploaded in the Space
if os.path.exists(lora_path):
    # weights_only=True restricts unpickling to tensors and plain containers,
    # the safe way to read a .pt checkpoint. The tensors are kept on the CPU
    # (the pipeline copies them into its adapter layers), and the same dict is
    # handed to the pipeline so the file is read only once.
    lora_state_dict = torch.load(lora_path, map_location="cpu", weights_only=True)
    pipeline.load_lora_weights(lora_state_dict)  # Load LoRA weights into the pipeline
    print("✅ LoRA weights loaded successfully!")
else:
    print("⚠️ LoRA file not found! Running base model.")

# Ensure GPU allocation in Hugging Face Spaces; `duration=65` is forwarded to
# the `spaces.GPU` decorator (presumably a per-call budget in seconds — confirm).
@spaces.GPU(duration=65)
def generate_image(prompt: str, seed: "int | float | None" = None):
    """Generate one image for *prompt* with the (optionally LoRA-adapted) SD 3.5 pipeline.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        seed: Optional RNG seed. ``None`` picks a random seed in
            ``[0, 100000]``. A float (as Gradio's ``gr.Number`` may deliver)
            is truncated to ``int`` before seeding.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    if seed is None:
        seed = random.randint(0, 100000)

    # Use a private, per-call generator rather than ``torch.manual_seed``:
    # the latter reseeds the process-global RNG and returns the shared default
    # generator, so concurrent requests could interleave draws and break
    # reproducibility. A fresh CPU generator seeded with the same value yields
    # the same random stream, so the image for a given seed is unchanged.
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    # Generate the image using the module-level pipeline.
    image = pipeline(prompt, generator=generator).images[0]
    return image

# Gradio interface: a prompt box and an optional seed feed generate_image,
# whose result is shown in an image component.
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt",
            value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.",
        )
        # precision=0 makes Gradio deliver the seed as an int (or None when the
        # box is left empty) instead of a float, since seeds are integers.
        seed_input = gr.Number(label="Seed (optional)", value=None, precision=0)

    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")

    # Clicking the button calls generate_image(prompt, seed) and shows the result.
    generate_btn.click(
        generate_image,
        inputs=[prompt_input, seed_input],
        outputs=output_image,
    )

# Launch Gradio App
demo.launch()