File size: 5,595 Bytes
7b1a432
 
 
c93b55a
 
 
7b1a432
02a3a52
7b1a432
c93b55a
 
 
 
 
7b1a432
dc25832
c93b55a
 
7b1a432
dfab3a9
4a6aac0
7b1a432
dc25832
 
c93b55a
dc25832
dfab3a9
dc25832
dfab3a9
 
dc25832
c93b55a
dc25832
 
dfab3a9
c93b55a
 
 
50c10d2
c93b55a
 
 
 
 
 
 
 
 
 
 
 
 
 
a51aea0
d2d8e3a
dfee4d1
d2d8e3a
 
c93b55a
 
 
 
 
 
 
 
 
 
dfee4d1
c93b55a
dfee4d1
 
c93b55a
81b1867
 
 
79d9e62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b1a432
 
79d9e62
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import gradio as gr
import torch
import os
import random
import numpy as np
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from spaces import GPU  # Remove if not in HF Space

# 1. Model and LoRA Loading (Before Gradio)
# Pick the execution device once and derive the matching tensor dtype from it:
# half precision on GPU, full precision on CPU.
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
token = os.getenv("HF_TOKEN")  # None when the HF_TOKEN env var is unset
model_repo_id = "stabilityai/stable-diffusion-3.5-large"

# Load the base pipeline and apply the local LoRA at import time, before the
# Gradio UI is built. Any failure stops the process with a non-zero exit code.
try:
    # `token=` is the current name of the deprecated `use_auth_token=` argument.
    pipe = DiffusionPipeline.from_pretrained(
        model_repo_id, torch_dtype=torch_dtype, token=token
    )
    pipe = pipe.to(device)

    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)

    if not os.path.exists(lora_path):
        print(f"Error: LoRA file not found at: {lora_path}")
        raise SystemExit(1)  # Stop (with a failure status) if LoRA is not found

    # Let the pipeline map the LoRA keys onto its own modules. The previous
    # `text_encoder.load_state_dict(..., strict=False)` silently ignored every
    # key it could not match, so a LoRA could be "loaded" without any effect.
    # An incompatible checkpoint now raises and is reported below.
    pipe.load_lora_weights(load_file(lora_path))
    print(f"LoRA loaded successfully from: {lora_path}")

    print("Stable Diffusion model and LoRA loaded successfully!")

except Exception as e:
    # SystemExit is not an Exception subclass, so the explicit exit above is
    # not caught here.
    print(f"Error loading model or LoRA: {e}")
    raise SystemExit(1) from e


# Upper bound (inclusive) for the seed slider and for randomly drawn seeds.
MAX_SEED = 99_999_999_999
# Largest width/height, in pixels, that the size sliders allow.
MAX_IMAGE_SIZE = 1024

@GPU(duration=65)  # Only if in HF Space
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from *prompt* using the module-level ``pipe``.

    Args:
        prompt: Text prompt for the image.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; replaced when *randomize_seed* is true.
        randomize_seed: When true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Forwarded to the pipeline's ``guidance_scale``.
        num_inference_steps: Forwarded to the pipeline's ``num_inference_steps``.
        progress: Gradio progress tracker (``track_tqdm=True`` mirrors tqdm bars).

    Returns:
        ``(image, seed)``: the first generated image and the seed actually used,
        so the UI's seed slider can show the value that produced the image.

    Raises:
        gr.Error: if the pipeline call fails; Gradio surfaces the message in the UI.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch's manual_seed and the pipeline's
    # size/step arguments need ints.
    seed = int(seed)

    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")  # keep a server-side log
        # Returning a string to a gr.Image output cannot be rendered; raising
        # gr.Error shows the message to the user instead.
        raise gr.Error(f"Error during image generation: {e}") from e
    return image, seed

# Sample prompts offered in the UI's examples section.
examples = ["A capybara wearing a suit holding a sign that reads Hello World"]

# Centre the main column and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 640px;\n"
    "}\n"
)

# Shared settings for the width and height sliders.
_SIZE_SLIDER = dict(minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header: title and links.
        gr.Markdown(" # [Stable Diffusion 3.5 Large (8B)](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)")
        gr.Markdown("[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) about the Stable Diffusion 3.5 series. Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), or [download model](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) to run locally with ComfyUI or diffusers.")

        # Prompt box and Run button share one row.
        with gr.Row():
            prompt = gr.Text(
                placeholder="Enter your prompt",
                label="Prompt",
                container=False,
                max_lines=1,
                show_label=False,
            )
            run_button = gr.Button("Run", variant="primary", scale=0)

        result = gr.Image(show_label=False, label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            # Hidden in the UI, but still passed to `infer` as its second input.
            negative_prompt = gr.Text(
                visible=False,
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                max_lines=1,
            )
            seed = gr.Slider(label="Seed", value=0, minimum=0, maximum=MAX_SEED, step=1)
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")

            with gr.Row():
                width = gr.Slider(label="Width", **_SIZE_SLIDER)
                height = gr.Slider(label="Height", **_SIZE_SLIDER)

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale", value=4.5, minimum=0.0, maximum=7.5, step=0.1
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", value=40, minimum=1, maximum=50, step=1
                )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy",
        )

    # Listed in the same order as infer()'s positional parameters.
    _generate_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    ]
    # Both the button click and pressing Enter in the prompt box run `infer`;
    # the returned seed is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=_generate_inputs,
        outputs=[result, seed],
    )

def _main():
    """Start the Gradio app defined above."""
    demo.launch()


if __name__ == "__main__":
    _main()