File size: 1,717 Bytes
feabc9a
bafa3e3
d985a28
37c7828
 
d985a28
b2f9a59
 
d985a28
b2f9a59
d985a28
b2f9a59
feabc9a
 
0e361a4
b2f9a59
 
 
 
d402a1c
b2f9a59
 
 
 
feabc9a
37c7828
b2f9a59
5120226
b2f9a59
37c7828
 
b2f9a59
 
37c7828
 
5120226
b2f9a59
 
 
 
feabc9a
 
37c7828
b2f9a59
 
 
 
 
 
 
 
 
 
 
 
 
37c7828
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login
import os
import gradio as gr

# Authenticate with the Hugging Face Hub before the model is fetched below.
# Fail fast (rather than at download time) when no token is configured.
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found. Please set it as a repository secret.")
login(token=token)

# Load Stable Diffusion 3.5 Large from the Hub in float16.
# NOTE(review): device_map="balanced" is intended to split the pipeline's
# components across the available devices; confirm on the target hardware.
model_id = "stabilityai/stable-diffusion-3.5-large"
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    device_map="balanced",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# Attention slicing is requested to lower peak memory during generation.
# NOTE(review): SD3 is transformer-based rather than UNet-based; verify this
# call actually changes anything for this pipeline.
pipe.enable_attention_slicing()

# Local checkpoint containing the LoRA weights that are applied to `pipe`.
lora_model_path = "./lora_model.pth"

def load_lora_model(pipe, lora_model_path):
    """Load LoRA weights from a local checkpoint and apply them to *pipe*.

    Args:
        pipe: A diffusers pipeline exposing ``load_lora_weights`` (such as
            ``StableDiffusion3Pipeline``).
        lora_model_path: Path to a ``torch.save``-d dict of LoRA tensors.

    Returns:
        The same pipeline object, with the LoRA weights loaded into it.
    """
    # Read the checkpoint onto the CPU so loading works even on hosts without
    # CUDA; the LoRA loader copies the values into the model's own parameters.
    # weights_only=True makes torch.load refuse arbitrary pickled objects --
    # plain unpickling can execute code embedded in an untrusted file.
    lora_weights = torch.load(lora_model_path, map_location="cpu", weights_only=True)
    # StableDiffusion3Pipeline is built around a `transformer`, not a `unet`,
    # so `pipe.unet.load_attn_procs(...)` raised AttributeError. The
    # pipeline-level `load_lora_weights` accepts a state dict directly.
    # NOTE(review): assumes lora_model.pth uses a key layout diffusers' SD3
    # LoRA loader recognizes -- confirm against how the checkpoint was saved.
    pipe.load_lora_weights(lora_weights)
    return pipe

pipe = load_lora_model(pipe=pipe, lora_model_path=lora_model_path)  # rebind the module-level pipe used by generate_image

def generate_image(prompt, steps, scale):
    """Run the module-level pipeline on *prompt* and return the first image.

    Args:
        prompt: Text prompt to generate from.
        steps: Number of inference steps (from the UI slider).
        scale: Guidance scale passed to the pipeline (from the UI slider).

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # UI sliders may deliver numbers as floats; the step count is an iteration
    # count, so normalise both values to the types the pipeline expects.
    num_steps = int(steps)
    guidance = float(scale)
    # inference_mode turns off autograd tracking for the forward passes.
    with torch.inference_mode():
        result = pipe(prompt, num_inference_steps=num_steps, guidance_scale=guidance)
    return result.images[0]

# Web UI: a prompt box and two sliders feed generate_image, and the result
# is displayed as an image.
iface = gr.Interface(
    generate_image,
    [
        gr.Textbox(label="Enter your prompt"),
        gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Number of Inference Steps"),
        gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.5, label="Guidance Scale"),
    ],
    "image",
    title="Optimized Stable Diffusion with LoRA",
    description="Generate images using Stable Diffusion 3.5 with optimized memory usage.",
)

# Launch the Gradio app.
iface.launch()