gokaygokay committed on
Commit
9e88c26
·
verified ·
1 Parent(s): 92211ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -116
app.py CHANGED
@@ -2,7 +2,7 @@ import spaces
2
  import gradio as gr
3
  import torch
4
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
5
- from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
6
  import re
7
  import random
8
  import numpy as np
@@ -32,11 +32,8 @@ vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captione
32
  enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
33
  enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
34
 
35
- def load_pipeline(pipeline_type):
36
- if pipeline_type == "text2img":
37
- return StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
38
- elif pipeline_type == "img2img":
39
- return StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
40
 
41
  MAX_SEED = np.iinfo(np.int32).max
42
  MAX_IMAGE_SIZE = 1344
@@ -93,10 +90,8 @@ def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height,
93
  seed = random.randint(0, MAX_SEED)
94
 
95
  generator = torch.Generator().manual_seed(seed)
96
-
97
- pipe = load_pipeline("text2img")
98
 
99
- image = pipe(
100
  prompt=prompt,
101
  negative_prompt=negative_prompt,
102
  guidance_scale=guidance_scale,
@@ -123,46 +118,6 @@ def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, ne
123
 
124
  return generated_image, prompt, used_seed
125
 
126
- @spaces.GPU
127
- def img2img_generate(
128
- prompt: str,
129
- init_image: gr.Image,
130
- use_vlm: bool,
131
- use_enhancer: bool,
132
- model_choice: str,
133
- negative_prompt: str = "",
134
- seed: int = 0,
135
- randomize_seed: bool = False,
136
- guidance_scale: float = 7,
137
- num_inference_steps: int = 30,
138
- strength: float = 0.8,
139
- ):
140
- if use_vlm and init_image is not None:
141
- prompt = create_captions_rich(init_image)
142
-
143
- if use_enhancer:
144
- prompt = enhance_prompt(prompt, model_choice)
145
-
146
- if randomize_seed:
147
- seed = random.randint(0, MAX_SEED)
148
-
149
- generator = torch.Generator().manual_seed(seed)
150
-
151
- img2img_pipe = load_pipeline("img2img")
152
-
153
- init_image = init_image.resize((768, 768))
154
-
155
- image = img2img_pipe(
156
- prompt=prompt,
157
- image=init_image,
158
- negative_prompt=negative_prompt,
159
- guidance_scale=guidance_scale,
160
- num_inference_steps=num_inference_steps,
161
- generator=generator,
162
- strength=strength,
163
- ).images[0]
164
-
165
- return image, prompt, seed
166
 
167
  custom_css = """
168
  .input-group, .output-group {
@@ -184,35 +139,35 @@ custom_css = """
184
  # Gradio Interface
185
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
186
  gr.Markdown("# VLM Captioner + Prompt Enhancer + SD3 Image Generator")
187
- with gr.Tab(label="Text to Image"):
188
- with gr.Row():
189
- with gr.Column(scale=1):
190
- with gr.Group(elem_classes="input-group"):
191
- input_image = gr.Image(label="Input Image for VLM")
192
- use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
193
-
194
- with gr.Group(elem_classes="input-group"):
195
- text_prompt = gr.Textbox(label="Text Prompt")
196
- use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
197
- model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
198
-
199
- with gr.Accordion("Advanced Settings", open=False):
200
- negative_prompt = gr.Textbox(label="Negative Prompt")
201
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
202
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
203
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
204
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
205
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
206
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
207
-
208
- generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
209
 
210
- with gr.Column(scale=1):
211
- with gr.Group(elem_classes="output-group"):
212
- output_image = gr.Image(label="Generated Image")
213
- final_prompt = gr.Textbox(label="Final Prompt Used")
214
- used_seed = gr.Number(label="Seed Used")
215
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  generate_btn.click(
217
  fn=process_workflow,
218
  inputs=[
@@ -222,43 +177,4 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondar
222
  outputs=[output_image, final_prompt, used_seed]
223
  )
224
 
225
- with gr.Tab(label="Image to Image"):
226
- with gr.Row():
227
- with gr.Column(scale=1):
228
- with gr.Group(elem_classes="input-group"):
229
- init_image = gr.Image(label="Input Image", type="pil")
230
- use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
231
-
232
- with gr.Group(elem_classes="input-group"):
233
- img2img_prompt = gr.Textbox(label="Text Prompt")
234
- use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
235
- model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
236
-
237
- with gr.Accordion("Advanced Settings", open=False):
238
- negative_prompt = gr.Textbox(label="Negative Prompt")
239
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
240
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
241
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=10.0, step=0.1, value=5)
242
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
243
- strength = gr.Slider(label="Img2Img Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.5)
244
-
245
- img2img_generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
246
-
247
- with gr.Column(scale=1):
248
- with gr.Group(elem_classes="output-group"):
249
- img2img_output = gr.Image(label="Generated Image")
250
- img2img_final_prompt = gr.Textbox(label="Final Prompt Used")
251
- img2img_used_seed = gr.Number(label="Seed Used")
252
-
253
- img2img_generate_btn.click(
254
- fn=img2img_generate,
255
- inputs=[
256
- img2img_prompt, init_image, use_vlm, use_enhancer, model_choice,
257
- negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, strength
258
- ],
259
- outputs=[img2img_output, img2img_final_prompt, img2img_used_seed]
260
- )
261
-
262
-
263
-
264
  demo.launch(debug=True)
 
2
  import gradio as gr
3
  import torch
4
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
5
+ from diffusers import StableDiffusion3Pipeline
6
  import re
7
  import random
8
  import numpy as np
 
32
  enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
33
  enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
34
 
35
+ # SD3
36
+ sd3_pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
 
 
 
37
 
38
  MAX_SEED = np.iinfo(np.int32).max
39
  MAX_IMAGE_SIZE = 1344
 
90
  seed = random.randint(0, MAX_SEED)
91
 
92
  generator = torch.Generator().manual_seed(seed)
 
 
93
 
94
+ image = sd3_pipe(
95
  prompt=prompt,
96
  negative_prompt=negative_prompt,
97
  guidance_scale=guidance_scale,
 
118
 
119
  return generated_image, prompt, used_seed
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  custom_css = """
123
  .input-group, .output-group {
 
139
  # Gradio Interface
140
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
141
  gr.Markdown("# VLM Captioner + Prompt Enhancer + SD3 Image Generator")
142
+
143
+ with gr.Row():
144
+ with gr.Column(scale=1):
145
+ with gr.Group(elem_classes="input-group"):
146
+ input_image = gr.Image(label="Input Image for VLM")
147
+ use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ with gr.Group(elem_classes="input-group"):
150
+ text_prompt = gr.Textbox(label="Text Prompt")
151
+ use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
152
+ model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
153
+
154
+ with gr.Accordion("Advanced Settings", open=False):
155
+ negative_prompt = gr.Textbox(label="Negative Prompt")
156
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
157
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
158
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
159
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
160
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
161
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
162
+
163
+ generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
164
+
165
+ with gr.Column(scale=1):
166
+ with gr.Group(elem_classes="output-group"):
167
+ output_image = gr.Image(label="Generated Image")
168
+ final_prompt = gr.Textbox(label="Final Prompt Used")
169
+ used_seed = gr.Number(label="Seed Used")
170
+
171
  generate_btn.click(
172
  fn=process_workflow,
173
  inputs=[
 
177
  outputs=[output_image, final_prompt, used_seed]
178
  )
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  demo.launch(debug=True)