Tech-Meld committed on
Commit 0dd3003 · verified · 1 Parent(s): 3df52b6

Update app.py

Files changed (1): app.py (+35, -169)
app.py CHANGED
@@ -1,171 +1,37 @@
- import gradio as gr
- import numpy as np
- import random
  import torch
- from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, AutoencoderKL, UNet2DConditionModel
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
- import spaces
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.float16
-
- # Use the correct repo for SDXL
- repo = "stabilityai/sdxl-turbo" # This is the correct repo for SDXL
-
- # Load the model components separately
- vae = AutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float16).to(device)
- text_encoder = SD3Transformer2DModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
- unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", torch_dtype=torch.float16).to(device)
- scheduler = EulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler", torch_dtype=torch.float16)
-
- # Construct the pipeline (this is how you work with SDXL)
- pipe = StableDiffusionPipeline(
-     vae=vae,
-     text_encoder=text_encoder,
-     unet=unet,
-     scheduler=scheduler
- ).to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1344
-
- def infer(prompts, negative_prompts, seeds, randomize_seeds, widths, heights, guidance_scales, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
-     images = []
-     for i, prompt in enumerate(prompts):
-         if randomize_seeds[i]:
-             seeds[i] = random.randint(0, MAX_SEED)
-
-         generator = torch.Generator().manual_seed(seeds[i])
-
-         # SDXL requires a slightly different call format:
-         image = pipe(
-             prompt=prompt,
-             negative_prompt=negative_prompts[i],
-             guidance_scale=guidance_scales[i],
-             num_inference_steps=num_inference_steps[i],
-             width=widths[i],
-             height=heights[i],
-             generator=generator
-         ).images[0]
-
-         images.append(image)
-
-     return images, seeds
-
- examples = [
-     ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "A blurry astronaut", 0, True, 512, 512, 7.5, 28],
-     ["An astronaut riding a green horse", "Astronaut on a regular horse", 0, True, 512, 512, 7.5, 28],
-     ["A delicious ceviche cheesecake slice", "A cheesecake that looks boring", 0, True, 512, 512, 7.5, 28],
- ]
-
- css="""
- #col-container {
-     margin: 0 auto;
-     max-width: 580px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""
-         # Demo [Automated Stable Diffusion XL](https://huggingface.co/stabilityai/stablediffusion-xl)
-         """)
-
-         with gr.Row():
-             prompt_group = gr.Group(elem_id="prompt_group")
-             with prompt_group:
-                 prompt_input = gr.Text(
-                     label="Prompt",
-                     show_label=False,
-                     max_lines=1,
-                     placeholder="Enter your prompt",
-                     container=False,
-                 )
-                 negative_prompt_input = gr.Text(
-                     label="Negative prompt",
-                     max_lines=1,
-                     placeholder="Enter a negative prompt",
-                 )
-                 seed_input = gr.Slider(
-                     label="Seed",
-                     minimum=0,
-                     maximum=MAX_SEED,
-                     step=1,
-                     value=0,
-                 )
-                 randomize_seed_input = gr.Checkbox(label="Randomize seed", value=True)
-                 width_input = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=64,
-                     value=512,
-                 )
-                 height_input = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=64,
-                     value=512,
-                 )
-                 guidance_scale_input = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=7.5,
-                 )
-                 num_inference_steps_input = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=28,
-                 )
-                 run_button = gr.Button("Run", scale=0)
-
-         result = gr.Gallery(label="Results", show_label=False, columns=4, rows=1)
-         add_button = gr.Button("Add Prompt")
-
-         with gr.Accordion("Advanced Settings", open=False):
-             pass
-
-         gr.Examples(
-             examples = examples,
-             inputs = [
-                 prompt_input,
-                 negative_prompt_input,
-                 seed_input,
-                 randomize_seed_input,
-                 width_input,
-                 height_input,
-                 guidance_scale_input,
-                 num_inference_steps_input
-             ]
-         )
-
-     def add_prompt():
-         prompt_group.duplicate()
-
-     def clear_prompts():
-         prompt_group.clear()
 
-     add_button.click(add_prompt)
-     gr.on(
-         triggers=[run_button.click, prompt_input.submit, negative_prompt_input.submit],
-         fn=infer,
-         inputs=[
-             prompt_input,
-             negative_prompt_input,
-             seed_input,
-             randomize_seed_input,
-             width_input,
-             height_input,
-             guidance_scale_input,
-             num_inference_steps_input
-         ],
-         outputs=[result, seed_input],
-         api_name="infer"
-     )
- demo.launch()
  import torch
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import gradio as gr
 
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ ckpt = "sdxl_lightning_4step_unet.safetensors"
+
+ # Load model: build the SDXL base UNet from its config, then load the SDXL-Lightning 4-step weights into it.
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cpu")
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cpu"))
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float32).to("cpu")
+
+ # Ensure sampler uses "trailing" timesteps.
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+ def generate_images(prompt, num_inference_steps, guidance_scale, batch_size):
+     images = pipe(prompt, num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale, num_images_per_prompt=int(batch_size)).images
+     return images
+
+ # Define Gradio interface
+ iface = gr.Interface(
+     fn=generate_images,
+     inputs=[
+         gr.Textbox(label="Prompt"),
+         gr.Slider(label="Num Inference Steps", minimum=1, maximum=50, step=1, value=4),
+         gr.Slider(label="Guidance Scale", minimum=0, maximum=20, step=0.1, value=0),
+         gr.Slider(label="Batch Size", minimum=1, maximum=8, step=1, value=2),
+     ],
+     outputs=gr.Gallery(label="Generated Images"),
+     title="SDXL Lightning 4-Step Inference (CPU)",
+     description="Generate images with the Stable Diffusion XL Lightning 4-step model on CPU.",
+ )
+
+ iface.launch()
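
For reference, a minimal seeded smoke test of the 4-step settings above might look like the sketch below. This is only an illustrative sketch, not part of app.py: it assumes the `pipe` object constructed above is available (e.g. in a REPL before `iface.launch()` runs), the prompt is taken from the old example list, and the seed is arbitrary; `generator`, `num_inference_steps`, and `guidance_scale` are standard diffusers pipeline arguments.

import torch

# Hypothetical smoke test: one seeded, guidance-free 4-step sample on CPU.
generator = torch.Generator(device="cpu").manual_seed(0)  # arbitrary seed for reproducibility
image = pipe(
    "An astronaut riding a green horse",  # test prompt reused from the old examples
    num_inference_steps=4,                # the 4-step Lightning checkpoint expects 4 steps
    guidance_scale=0.0,                   # Lightning checkpoints are distilled for guidance-free sampling
    generator=generator,
).images[0]
image.save("lightning_cpu_test.png")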