gokaygokay committed on
Commit 25e3823 · verified · 1 Parent(s): 8a3fbac

Update app.py

Files changed (1)
  1. app.py +108 -29
app.py CHANGED
@@ -2,7 +2,7 @@ import spaces
  import gradio as gr
  import torch
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
- from diffusers import StableDiffusion3Pipeline
+ from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
  import re
  import random
  import numpy as np
@@ -118,6 +118,46 @@ def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, ne

      return generated_image, prompt, used_seed

+ @spaces.GPU
+ def img2img_generate(
+     prompt: str,
+     init_image: gr.Image,
+     use_vlm: bool,
+     use_enhancer: bool,
+     model_choice: str,
+     negative_prompt: str = "",
+     seed: int = 0,
+     randomize_seed: bool = False,
+     guidance_scale: float = 7,
+     num_inference_steps: int = 30,
+     strength: float = 0.8,
+ ):
+     if use_vlm and init_image is not None:
+         prompt = create_captions_rich(init_image)
+
+     if use_enhancer:
+         prompt = enhance_prompt(prompt, model_choice)
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     img2img_pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
+
+     init_image = init_image.resize((768, 768))
+
+     image = img2img_pipe(
+         prompt=prompt,
+         image=init_image,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         generator=generator,
+         strength=strength,
+     ).images[0]
+
+     return image, prompt, seed

  custom_css = """
  .input-group, .output-group {
@@ -139,35 +179,35 @@ custom_css = """
  # Gradio Interface
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
      gr.Markdown("# VLM Captioner + Prompt Enhancer + SD3 Image Generator")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             with gr.Group(elem_classes="input-group"):
-                 input_image = gr.Image(label="Input Image for VLM")
-                 use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
-
-             with gr.Group(elem_classes="input-group"):
-                 text_prompt = gr.Textbox(label="Text Prompt")
-                 use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
-                 model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
+     with gr.Tab(label="Text to Image"):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 with gr.Group(elem_classes="input-group"):
+                     input_image = gr.Image(label="Input Image for VLM")
+                     use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
+
+                 with gr.Group(elem_classes="input-group"):
+                     text_prompt = gr.Textbox(label="Text Prompt")
+                     use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
+                     model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     negative_prompt = gr.Textbox(label="Negative Prompt")
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                     randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                     width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
+                     height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
+                     guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
+                     num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
+
+                 generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")

-             with gr.Accordion("Advanced Settings", open=False):
-                 negative_prompt = gr.Textbox(label="Negative Prompt")
-                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                 width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
-                 height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
-                 guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
-                 num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
-
-             generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
-
-         with gr.Column(scale=1):
-             with gr.Group(elem_classes="output-group"):
-                 output_image = gr.Image(label="Generated Image")
-                 final_prompt = gr.Textbox(label="Final Prompt Used")
-                 used_seed = gr.Number(label="Seed Used")
-
+             with gr.Column(scale=1):
+                 with gr.Group(elem_classes="output-group"):
+                     output_image = gr.Image(label="Generated Image")
+                     final_prompt = gr.Textbox(label="Final Prompt Used")
+                     used_seed = gr.Number(label="Seed Used")
+
      generate_btn.click(
          fn=process_workflow,
          inputs=[
@@ -177,4 +217,43 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondar
          outputs=[output_image, final_prompt, used_seed]
      )

+     with gr.Tab(label="Image to Image"):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 with gr.Group(elem_classes="input-group"):
+                     init_image = gr.Image(label="Input Image", type="pil")
+                     use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
+
+                 with gr.Group(elem_classes="input-group"):
+                     img2img_prompt = gr.Textbox(label="Text Prompt")
+                     use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
+                     model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     negative_prompt = gr.Textbox(label="Negative Prompt")
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                     randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                     guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=10.0, step=0.1, value=5)
+                     num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
+                     strength = gr.Slider(label="Img2Img Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.5)
+
+                 img2img_generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
+
+             with gr.Column(scale=1):
+                 with gr.Group(elem_classes="output-group"):
+                     img2img_output = gr.Image(label="Generated Image")
+                     img2img_final_prompt = gr.Textbox(label="Final Prompt Used")
+                     img2img_used_seed = gr.Number(label="Seed Used")
+
+     img2img_generate_btn.click(
+         fn=img2img_generate,
+         inputs=[
+             img2img_prompt, init_image, use_vlm, use_enhancer, model_choice,
+             negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, strength
+         ],
+         outputs=[img2img_output, img2img_final_prompt, img2img_used_seed]
+     )
+
+
+
  demo.launch(debug=True)
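
For context, a minimal standalone sketch of the diffusers img2img call pattern that the new img2img_generate handler wraps. The checkpoint id, dtype, device, and file paths below are illustrative assumptions, not values taken from app.py, which reuses its own module-level model_path, dtype, and device.

    # Sketch only: assumed checkpoint id, dtype, device, and paths.
    import torch
    from PIL import Image
    from diffusers import StableDiffusion3Img2ImgPipeline

    # Load the SD3 img2img pipeline (the commit loads it inside the Gradio handler).
    pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed model id
        torch_dtype=torch.float16,
    ).to("cuda")

    # Prepare the init image the same way the handler does (resize to 768x768).
    init_image = Image.open("input.jpg").convert("RGB").resize((768, 768))
    generator = torch.Generator().manual_seed(0)

    # Same keyword arguments the handler passes to the pipeline.
    result = pipe(
        prompt="a watercolor painting of a lighthouse at dusk",
        image=init_image,
        negative_prompt="",
        guidance_scale=7.0,
        num_inference_steps=30,
        strength=0.8,
        generator=generator,
    ).images[0]
    result.save("img2img_result.png")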