Sergidev commited on
Commit
bc0df63
·
verified ·
1 Parent(s): d38aad6
Files changed (1) hide show
  1. app.py +54 -199
app.py CHANGED
@@ -14,214 +14,69 @@ from datetime import datetime
14
  from diffusers.models import AutoencoderKL
15
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
16
 
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
19
 
20
- DESCRIPTION = "PonyDiffusion V6 XL"
21
- if not torch.cuda.is_available():
22
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
23
- IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
24
- HF_TOKEN = os.getenv("HF_TOKEN")
25
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
26
- MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
27
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
28
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
29
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
30
- OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
31
 
32
- MODEL = os.getenv(
33
- "MODEL",
34
- "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
35
- )
36
-
37
- torch.backends.cudnn.deterministic = True
38
- torch.backends.cudnn.benchmark = False
39
-
40
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41
-
42
- def load_pipeline(model_name):
43
- vae = AutoencoderKL.from_pretrained(
44
- "madebyollin/sdxl-vae-fp16-fix",
45
- torch_dtype=torch.float16,
46
- )
47
- pipeline = (
48
- StableDiffusionXLPipeline.from_single_file
49
- if MODEL.endswith(".safetensors")
50
- else StableDiffusionXLPipeline.from_pretrained
51
- )
52
-
53
- pipe = pipeline(
54
- model_name,
55
- vae=vae,
56
- torch_dtype=torch.float16,
57
- custom_pipeline="lpw_stable_diffusion_xl",
58
- use_safetensors=True,
59
- add_watermarker=False,
60
- use_auth_token=HF_TOKEN,
61
- variant="fp16",
62
- )
63
-
64
- pipe.to(device)
65
- return pipe
66
-
67
- def parse_json_parameters(json_str):
68
- try:
69
- params = json.loads(json_str)
70
- return params
71
- except json.JSONDecodeError:
72
- return None
73
-
74
- def apply_json_parameters(json_str):
75
- params = parse_json_parameters(json_str)
76
- if params:
77
- return (
78
- params.get("prompt", ""),
79
- params.get("negative_prompt", ""),
80
- params.get("seed", 0),
81
- params.get("width", 1024),
82
- params.get("height", 1024),
83
- params.get("guidance_scale", 7.0),
84
- params.get("num_inference_steps", 30),
85
- params.get("sampler", "DPM++ 2M SDE Karras"),
86
- params.get("aspect_ratio", "1024 x 1024"),
87
- params.get("use_upscaler", False),
88
- params.get("upscaler_strength", 0.55),
89
- params.get("upscale_by", 1.5),
90
- )
91
- return [gr.update()] * 12
92
-
93
- @spaces.GPU
94
- def generate(
95
- prompt: str,
96
- negative_prompt: str = "",
97
- seed: int = 0,
98
- custom_width: int = 1024,
99
- custom_height: int = 1024,
100
- guidance_scale: float = 7.0,
101
- num_inference_steps: int = 30,
102
- sampler: str = "DPM++ 2M SDE Karras",
103
- aspect_ratio_selector: str = "1024 x 1024",
104
- use_upscaler: bool = False,
105
- upscaler_strength: float = 0.55,
106
- upscale_by: float = 1.5,
107
- progress=gr.Progress(track_tqdm=True),
108
- ) -> Image:
109
- generator = utils.seed_everything(seed)
110
-
111
- width, height = utils.aspect_ratio_handler(
112
- aspect_ratio_selector,
113
- custom_width,
114
- custom_height,
115
- )
116
-
117
- width, height = utils.preprocess_image_dimensions(width, height)
118
-
119
- backup_scheduler = pipe.scheduler
120
- pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
121
-
122
- if use_upscaler:
123
- upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
124
- metadata = {
125
- "prompt": prompt,
126
- "negative_prompt": negative_prompt,
127
- "resolution": f"{width} x {height}",
128
- "guidance_scale": guidance_scale,
129
- "num_inference_steps": num_inference_steps,
130
- "seed": seed,
131
- "sampler": sampler,
132
- }
133
 
134
- if use_upscaler:
135
- new_width = int(width * upscale_by)
136
- new_height = int(height * upscale_by)
137
- metadata["use_upscaler"] = {
138
- "upscale_method": "nearest-exact",
139
- "upscaler_strength": upscaler_strength,
140
- "upscale_by": upscale_by,
141
- "new_resolution": f"{new_width} x {new_height}",
 
 
 
 
 
 
 
142
  }
143
- else:
144
- metadata["use_upscaler"] = None
145
- logger.info(json.dumps(metadata, indent=4))
146
 
147
- try:
148
- if use_upscaler:
149
- latents = pipe(
150
- prompt=prompt,
151
- negative_prompt=negative_prompt,
152
- width=width,
153
- height=height,
154
- guidance_scale=guidance_scale,
155
- num_inference_steps=num_inference_steps,
156
- generator=generator,
157
- output_type="latent",
158
- ).images
159
- upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
160
- images = upscaler_pipe(
161
- prompt=prompt,
162
- negative_prompt=negative_prompt,
163
- image=upscaled_latents,
164
- guidance_scale=guidance_scale,
165
- num_inference_steps=num_inference_steps,
166
- strength=upscaler_strength,
167
- generator=generator,
168
- output_type="pil",
169
- ).images
170
- else:
171
- images = pipe(
172
- prompt=prompt,
173
- negative_prompt=negative_prompt,
174
- width=width,
175
- height=height,
176
- guidance_scale=guidance_scale,
177
- num_inference_steps=num_inference_steps,
178
- generator=generator,
179
- output_type="pil",
180
- ).images
181
 
182
- if images and IS_COLAB:
183
- for image in images:
184
- filepath = utils.save_image(image, metadata, OUTPUT_DIR)
185
- logger.info(f"Image saved as {filepath} with metadata")
 
 
186
 
187
- # Update history after generation
188
- history = gr.get_state("history") or []
189
- history.insert(0, {"prompt": prompt, "image": images[0], "metadata": metadata})
190
- gr.set_state("history", history[:10]) # Keep only the last 10 entries
191
 
192
- return images, metadata, gr.update(choices=[h["prompt"] for h in history])
193
- except Exception as e:
194
- logger.exception(f"An error occurred: {e}")
195
- raise
196
- finally:
197
- if use_upscaler:
198
- del upscaler_pipe
199
- pipe.scheduler = backup_scheduler
200
- utils.free_memory()
201
-
202
- def get_random_prompt():
203
- anime_characters = [
204
- "Naruto Uzumaki", "Monkey D. Luffy", "Goku", "Eren Yeager", "Light Yagami",
205
- "Lelouch Lamperouge", "Edward Elric", "Levi Ackerman", "Spike Spiegel",
206
- "Sakura Haruno", "Mikasa Ackerman", "Asuka Langley Soryu", "Rem", "Megumin",
207
- "Violet Evergarden"
208
- ]
209
- styles = ["pixel art", "stylized anime", "digital art", "watercolor", "sketch"]
210
- scores = ["score_9", "score_8_up", "score_7_up"]
211
-
212
- character = random.choice(anime_characters)
213
- style = random.choice(styles)
214
- score = ", ".join(random.sample(scores, k=3))
215
-
216
- return f"{score}, {character}, {style}, show accurate"
217
 
218
- if torch.cuda.is_available():
219
- pipe = load_pipeline(MODEL)
220
- logger.info("Loaded on Device!")
221
- else:
222
- pipe = None
 
 
 
 
 
 
 
 
 
223
 
224
  with gr.Blocks(css="style.css") as demo:
 
 
225
  title = gr.HTML(
226
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
227
  elem_id="title",
@@ -338,7 +193,7 @@ with gr.Blocks(css="style.css") as demo:
338
  clear_button = gr.Button("Clear All")
339
  random_prompt_button = gr.Button("Random Prompt")
340
 
341
- history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True)
342
 
343
  with gr.Accordion(label="Generation Parameters", open=False):
344
  gr_metadata = gr.JSON(label="Metadata", show_label=False)
@@ -451,4 +306,4 @@ with gr.Blocks(css="style.css") as demo:
451
  outputs=prompt
452
  )
453
 
454
- demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
 
14
  from diffusers.models import AutoencoderKL
15
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
16
 
17
+ # ... (keep all the imports and initial setup)
 
18
 
19
+ # ... (keep all the functions like load_pipeline, parse_json_parameters, apply_json_parameters, generate, get_random_prompt)
 
 
 
 
 
 
 
 
 
 
20
 
21
+ if torch.cuda.is_available():
22
+ pipe = load_pipeline(MODEL)
23
+ logger.info("Loaded on Device!")
24
+ else:
25
+ pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # Define the JavaScript code as a string
28
+ js_code = """
29
+ <script>
30
+ document.addEventListener('DOMContentLoaded', (event) => {
31
+ const historyDropdown = document.getElementById('history-dropdown');
32
+ const resultGallery = document.querySelector('.gallery');
33
+
34
+ if (historyDropdown && resultGallery) {
35
+ const observer = new MutationObserver((mutations) => {
36
+ mutations.forEach((mutation) => {
37
+ if (mutation.type === 'childList' && mutation.addedNodes.length > 0) {
38
+ const newImage = mutation.addedNodes[0];
39
+ if (newImage.tagName === 'IMG') {
40
+ updateHistory(newImage.src);
41
+ }
42
  }
43
+ });
44
+ });
 
45
 
46
+ observer.observe(resultGallery, { childList: true });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ function updateHistory(imageSrc) {
49
+ const prompt = document.querySelector('#prompt textarea').value;
50
+ const option = document.createElement('option');
51
+ option.value = prompt;
52
+ option.textContent = prompt;
53
+ option.setAttribute('data-image', imageSrc);
54
 
55
+ historyDropdown.insertBefore(option, historyDropdown.firstChild);
 
 
 
56
 
57
+ if (historyDropdown.children.length > 10) {
58
+ historyDropdown.removeChild(historyDropdown.lastChild);
59
+ }
60
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ historyDropdown.addEventListener('change', (event) => {
63
+ const selectedOption = event.target.selectedOptions[0];
64
+ const imageSrc = selectedOption.getAttribute('data-image');
65
+ if (imageSrc) {
66
+ const img = document.createElement('img');
67
+ img.src = imageSrc;
68
+ resultGallery.innerHTML = '';
69
+ resultGallery.appendChild(img);
70
+ }
71
+ });
72
+ }
73
+ });
74
+ </script>
75
+ """
76
 
77
  with gr.Blocks(css="style.css") as demo:
78
+ gr.HTML(js_code) # Add the JavaScript code to the interface
79
+
80
  title = gr.HTML(
81
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
82
  elem_id="title",
 
193
  clear_button = gr.Button("Clear All")
194
  random_prompt_button = gr.Button("Random Prompt")
195
 
196
+ history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True, elem_id="history-dropdown")
197
 
198
  with gr.Accordion(label="Generation Parameters", open=False):
199
  gr_metadata = gr.JSON(label="Metadata", show_label=False)
 
306
  outputs=prompt
307
  )
308
 
309
+ demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)