aiqcamp commited on
Commit
ce24b84
·
verified ·
1 Parent(s): 0ec65eb

Delete inference_webui.py

Browse files
Files changed (1) hide show
  1. inference_webui.py +0 -225
inference_webui.py DELETED
@@ -1,225 +0,0 @@
1
- import random
2
- import os
3
- import uuid
4
- from datetime import datetime
5
- import gradio as gr
6
- import numpy as np
7
- import spaces
8
- import torch
9
- from diffusers import DiffusionPipeline
10
- from PIL import Image
11
-
12
- # Create permanent storage directory
13
- SAVE_DIR = "saved_images" # Gradio will handle the persistence
14
- if not os.path.exists(SAVE_DIR):
15
- os.makedirs(SAVE_DIR, exist_ok=True)
16
-
17
- # Load the default image
18
- DEFAULT_IMAGE_PATH = "cover1.webp"
19
-
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
- repo_id = "black-forest-labs/FLUX.1-dev"
22
- adapter_id = "alvdansen/pola-photo-flux"
23
-
24
- pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
25
- pipeline.load_lora_weights(adapter_id)
26
- pipeline = pipeline.to(device)
27
-
28
- MAX_SEED = np.iinfo(np.int32).max
29
- MAX_IMAGE_SIZE = 1024
30
-
31
- def save_generated_image(image, prompt):
32
- # Generate unique filename with timestamp
33
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
34
- unique_id = str(uuid.uuid4())[:8]
35
- filename = f"{timestamp}_{unique_id}.png"
36
- filepath = os.path.join(SAVE_DIR, filename)
37
-
38
- # Save the image
39
- image.save(filepath)
40
-
41
- # Save metadata
42
- metadata_file = os.path.join(SAVE_DIR, "metadata.txt")
43
- with open(metadata_file, "a", encoding="utf-8") as f:
44
- f.write(f"{filename}|{prompt}|{timestamp}\n")
45
-
46
- return filepath
47
-
48
- def load_generated_images():
49
- if not os.path.exists(SAVE_DIR):
50
- return []
51
-
52
- # Load all images from the directory
53
- image_files = [os.path.join(SAVE_DIR, f) for f in os.listdir(SAVE_DIR)
54
- if f.endswith(('.png', '.jpg', '.jpeg', '.webp'))]
55
- # Sort by creation time (newest first)
56
- image_files.sort(key=lambda x: os.path.getctime(x), reverse=True)
57
- return image_files
58
-
59
- def load_predefined_images():
60
- # Return empty list since we're not using predefined images
61
- return []
62
-
63
- @spaces.GPU(duration=120)
64
- def inference(
65
- prompt: str,
66
- seed: int,
67
- randomize_seed: bool,
68
- width: int,
69
- height: int,
70
- guidance_scale: float,
71
- num_inference_steps: int,
72
- lora_scale: float,
73
- progress: gr.Progress = gr.Progress(track_tqdm=True),
74
- ):
75
- if randomize_seed:
76
- seed = random.randint(0, MAX_SEED)
77
- generator = torch.Generator(device=device).manual_seed(seed)
78
-
79
- image = pipeline(
80
- prompt=prompt,
81
- guidance_scale=guidance_scale,
82
- num_inference_steps=num_inference_steps,
83
- width=width,
84
- height=height,
85
- generator=generator,
86
- joint_attention_kwargs={"scale": lora_scale},
87
- ).images[0]
88
-
89
- # Save the generated image
90
- filepath = save_generated_image(image, prompt)
91
-
92
- # Return the image, seed, and updated gallery
93
- return image, seed, load_generated_images()
94
-
95
-
96
- examples = [
97
- "polaroid style, a woman with long blonde hair wearing big round hippie sunglasses with a slight smile, white oversized fur coat, black dress, early evening in the city, polaroid style [trigger]"
98
- ]
99
-
100
- css = """
101
- footer {
102
- visibility: hidden;
103
- }
104
- """
105
-
106
- with gr.Blocks(theme=gr.themes.Soft(), css=css, analytics_enabled=False) as demo:
107
- gr.HTML('<div class="title"> Polaroid style Image Generation </div>')
108
-
109
-
110
- with gr.Tabs() as tabs:
111
- with gr.Tab("Generation"):
112
- with gr.Column(elem_id="col-container"):
113
- with gr.Row():
114
- prompt = gr.Text(
115
- label="Prompt",
116
- show_label=False,
117
- max_lines=1,
118
- placeholder="Enter your prompt",
119
- container=False,
120
- )
121
- run_button = gr.Button("Run", scale=0)
122
-
123
- # Modified to include the default image
124
- result = gr.Image(
125
- label="Result",
126
- show_label=False,
127
- value=DEFAULT_IMAGE_PATH # Set the default image
128
- )
129
-
130
- with gr.Accordion("Advanced Settings", open=False):
131
- seed = gr.Slider(
132
- label="Seed",
133
- minimum=0,
134
- maximum=MAX_SEED,
135
- step=1,
136
- value=42,
137
- )
138
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
139
-
140
- with gr.Row():
141
- width = gr.Slider(
142
- label="Width",
143
- minimum=256,
144
- maximum=MAX_IMAGE_SIZE,
145
- step=32,
146
- value=1024,
147
- )
148
- height = gr.Slider(
149
- label="Height",
150
- minimum=256,
151
- maximum=MAX_IMAGE_SIZE,
152
- step=32,
153
- value=768,
154
- )
155
-
156
- with gr.Row():
157
- guidance_scale = gr.Slider(
158
- label="Guidance scale",
159
- minimum=0.0,
160
- maximum=10.0,
161
- step=0.1,
162
- value=3.5,
163
- )
164
- num_inference_steps = gr.Slider(
165
- label="Number of inference steps",
166
- minimum=1,
167
- maximum=50,
168
- step=1,
169
- value=30,
170
- )
171
- lora_scale = gr.Slider(
172
- label="LoRA scale",
173
- minimum=0.0,
174
- maximum=1.0,
175
- step=0.1,
176
- value=1.0,
177
- )
178
-
179
- gr.Examples(
180
- examples=examples,
181
- inputs=[prompt],
182
- outputs=[result, seed],
183
- )
184
-
185
- with gr.Tab("Gallery"):
186
- gallery_header = gr.Markdown("### Generated Images Gallery")
187
- generated_gallery = gr.Gallery(
188
- label="Generated Images",
189
- columns=6,
190
- show_label=False,
191
- value=load_generated_images(),
192
- elem_id="generated_gallery",
193
- height="auto"
194
- )
195
- refresh_btn = gr.Button("🔄 Refresh Gallery")
196
-
197
-
198
- # Event handlers
199
- def refresh_gallery():
200
- return load_generated_images()
201
-
202
- refresh_btn.click(
203
- fn=refresh_gallery,
204
- inputs=None,
205
- outputs=generated_gallery,
206
- )
207
-
208
- gr.on(
209
- triggers=[run_button.click, prompt.submit],
210
- fn=inference,
211
- inputs=[
212
- prompt,
213
- seed,
214
- randomize_seed,
215
- width,
216
- height,
217
- guidance_scale,
218
- num_inference_steps,
219
- lora_scale,
220
- ],
221
- outputs=[result, seed, generated_gallery],
222
- )
223
-
224
- demo.queue()
225
- demo.launch()