gokaygokay committed
Commit 27de533 · verified · 1 Parent(s): 8ca0852

Create app.py

Files changed (1)
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
+ import spaces
+ import gradio as gr
+ import torch
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
+ from diffusers import StableDiffusion3Pipeline
+ import re
+ import random
+ import numpy as np
+ import os
+ from huggingface_hub import snapshot_download
+
+ # Initialize models
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16
+
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+
+ model_path = snapshot_download(
+     repo_id="stabilityai/stable-diffusion-3-medium",
+     repo_type="model",
+     ignore_patterns=["*.md", "*.gitattributes"],
+     local_dir="SD3",
+     token=huggingface_token,
+ )
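+ # NOTE: stable-diffusion-3-medium is a gated repo; HUGGINGFACE_TOKEN must belong to an
+ # account that has accepted the model license, otherwise the download will fail.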
+
+ # VLM Captioner (PaliGemma fine-tuned for long SD3-style captions)
+ vlm_model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to(device).eval()
+ vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
+
+ # Prompt Enhancer (seq2seq models served through the summarization pipeline)
+ enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
+ enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
+
+ # SD3
+ sd3_pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1344
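+ # MAX_SEED == 2**31 - 1 (int32 range); MAX_IMAGE_SIZE caps the width/height sliders below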
+
+ # VLM Captioner function
+ def create_captions_rich(image):
+     prompt = "caption en"
+     model_inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(device)
+     input_len = model_inputs["input_ids"].shape[-1]
+
+     with torch.inference_mode():
+         generation = vlm_model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
+         generation = generation[0][input_len:]
+         decoded = vlm_processor.decode(generation, skip_special_tokens=True)
+
+     return modify_caption(decoded)
+
+ # Helper function for caption modification
+ def modify_caption(caption: str) -> str:
+     prefix_substrings = [
+         ('captured from ', ''),
+         ('captured at ', '')
+     ]
+     pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
+     replacers = {opening: replacer for opening, replacer in prefix_substrings}
+
+     def replace_fn(match):
+         # lower-case the match so case-insensitive hits still find their dict key
+         return replacers[match.group(0).lower()]
+
+     return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
+
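+ # e.g. modify_caption("Captured from a low angle, a cat sits on a wall.")
+ # -> "a low angle, a cat sits on a wall."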
+
+ # Prompt Enhancer function
+ def enhance_prompt(input_prompt, model_choice):
+     if model_choice == "Medium":
+         result = enhancer_medium("Enhance the description: " + input_prompt)
+         enhanced_text = result[0]['summary_text']
+
+         # strip a leading "... of" preamble from the first sentence
+         pattern = r'^.*?of\s+(.*?(?:\.|$))'
+         match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
+
+         if match:
+             remaining_text = enhanced_text[match.end():].strip()
+             modified_sentence = match.group(1).capitalize()
+             enhanced_text = modified_sentence + ' ' + remaining_text
+     else:  # Long
+         result = enhancer_long("Enhance the description: " + input_prompt)
+         enhanced_text = result[0]['summary_text']
+
+     return enhanced_text
+
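+ # e.g. "A detailed description of a misty forest. Tall pines rise." -> "A misty forest. Tall pines rise."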
+
+ # SD3 Generation function
+ def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     image = sd3_pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator
+     ).images[0]
+
+     return image, seed
+
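+ # the (possibly randomized) seed is returned so the UI can show which seed was actually used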
+
+ # End-to-end workflow: optional VLM caption -> optional prompt enhancement -> SD3 generation
+ @spaces.GPU  # request GPU time on ZeroGPU Spaces for the duration of this call
+ def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+     if use_vlm and image is not None:
+         prompt = create_captions_rich(image)
+     else:
+         prompt = text_prompt
+
+     if use_enhancer:
+         prompt = enhance_prompt(prompt, model_choice)
+
+     generated_image, used_seed = generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps)
+
+     return generated_image, prompt, used_seed
+
+
+ css = """
+ body {
+     font-family: 'Arial', sans-serif;
+     background-color: #f0f4f8;
+ }
+ .container {
+     max-width: 800px;
+     margin: 0 auto;
+     padding: 20px;
+     background-color: white;
+     border-radius: 10px;
+     box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+ }
+ h1 {
+     color: #2c3e50;
+     text-align: center;
+     margin-bottom: 20px;
+ }
+ .input-box, .output-box {
+     border: 1px solid #bdc3c7;
+     border-radius: 5px;
+     padding: 10px;
+ }
+ .input-box:focus, .output-box:focus {
+     border-color: #3498db;
+     box-shadow: 0 0 5px rgba(52, 152, 219, 0.5);
+ }
+ .submit-btn {
+     background-color: #2980b9;
+     color: white;
+     border: none;
+     padding: 10px 20px;
+     border-radius: 5px;
+     cursor: pointer;
+     transition: background-color 0.3s;
+ }
+ .submit-btn:hover {
+     background-color: #3498db;
+ }
+ """
+
+
+ # Gradio Interface
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown("# SD3 Image Generator + VLM Captioner + Prompt Enhancer")
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label="Input Image for VLM")
+             text_prompt = gr.Textbox(label="Text Prompt")
+             use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
+             use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
+             model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt = gr.Textbox(label="Negative Prompt")
+                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                 width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
+                 height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
+                 guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
+                 num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
+
+             generate_btn = gr.Button("Generate Image")
+
+         with gr.Column():
+             output_image = gr.Image(label="Generated Image")
+             final_prompt = gr.Textbox(label="Final Prompt Used")
+             used_seed = gr.Number(label="Seed Used")
+
+     generate_btn.click(
+         fn=process_workflow,
+         inputs=[
+             input_image, text_prompt, use_vlm, use_enhancer, model_choice,
+             negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
+         ],
+         outputs=[output_image, final_prompt, used_seed]
+     )
+
+ demo.launch()