trashchenkov commited on
Commit
113271f
·
verified ·
1 Parent(s): ba1587e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -390
app.py CHANGED
@@ -1,289 +1,185 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
- import re
5
-
6
- from diffusers import (
7
- StableDiffusionPipeline,
8
- ControlNetModel,
9
- StableDiffusionControlNetPipeline,
10
- DDIMScheduler,
11
- )
12
  from peft import PeftModel
 
13
  from PIL import Image
14
 
15
- # ------------------------------------------------------------------
16
- # Пример «заготовки» для IP-Adapter:
17
- # Предполагается, что у вас есть некий класс, умеющий:
18
- # 1) Загружать веса IP-Adapter
19
- # 2) Преобразовывать дополнительное «референс-изображение» в эмбеддинг
20
- # 3) Подмешивать этот эмбеддинг в процесс диффузии или текстовые эмбеддинги
21
- # ------------------------------------------------------------------
22
- class IPAdapterModel:
23
- def __init__(self, path_to_weights: str, device="cpu"):
24
- """
25
- Инициализация и загрузка весов IP-Adapter.
26
- path_to_weights - путь к файлам модели
27
- """
28
- # Здесь должен быть код инициализации вашей модели.
29
- # Например, что-то вроде:
30
- # self.model = torch.load(path_to_weights, map_location=device)
31
- # self.model.eval()
32
- # ...
33
- self.device = device
34
- self.dummy_weights_loaded = True # признак, что "что-то" загрузили
35
-
36
- def encode_reference_image(self, image: Image.Image):
37
- """
38
- Преобразовать референс-изображение в некий вектор (embedding),
39
- который затем можно использовать для модификации генерации.
40
- """
41
- # В реальном коде будет извлечение фич.
42
- # Для примера вернём фиктивный тензор.
43
- dummy_embedding = torch.zeros((1, 768)).to(self.device)
44
- return dummy_embedding
45
-
46
- def blend_latents_with_adapter(self, latents: torch.Tensor, adapter_embedding: torch.Tensor, scale: float):
47
- """
48
- Примерная функция, которая «подмешивает» признаки из адаптера
49
- в латенты перед декодированием.
50
- latents: (batch, channels, height, width)
51
- adapter_embedding: (1, embedding_dim)
52
- scale: сила влияния адаптера
53
- """
54
- # Для демонстрации просто прибавим (scale * mean(adapter_embedding))
55
- # В реальном IP-Adapter это гораздо сложнее.
56
- if adapter_embedding is not None:
57
- # Возьмём скаляр (к примеру)
58
- mean_val = adapter_embedding.mean()
59
- latents = latents + scale * mean_val
60
- return latents
61
-
62
-
63
- # ------------------------------------------------------------------
64
  # Регулярное выражение для проверки корректности модели
65
- # ------------------------------------------------------------------
66
  VALID_REPO_ID_REGEX = re.compile(r"^[a-zA-Z0-9._\-]+/[a-zA-Z0-9._\-]+$")
67
- def is_valid_repo_id(repo_id):
68
- return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(('-', '.'))
69
 
70
- # ------------------------------------------------------------------
71
- # Аппаратные настройки
72
- # ------------------------------------------------------------------
73
- device = "cuda" if torch.cuda.is_available() else "cpu"
74
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
75
 
76
- # ------------------------------------------------------------------
77
- # Константы
78
- # ------------------------------------------------------------------
 
79
  MAX_SEED = np.iinfo(np.int32).max
80
  MAX_IMAGE_SIZE = 1024
81
 
82
- # ------------------------------------------------------------------
83
- # Базовая модель (Stable Diffusion) по умолчанию
84
- # ------------------------------------------------------------------
85
  model_repo_id = "CompVis/stable-diffusion-v1-4"
 
86
 
87
- # Загрузка базового пайплайна (без ControlNet)
88
- pipe = StableDiffusionPipeline.from_pretrained(
89
- model_repo_id, torch_dtype=torch_dtype, safety_checker=None
90
- ).to(device)
91
-
92
- # ��рименим DDIM-схему как пример
93
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
94
-
95
- # Пробуем подгрузить LoRA (unet + text_encoder)
96
  try:
97
  pipe.unet = PeftModel.from_pretrained(pipe.unet, "./unet")
98
  pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "./text_encoder")
99
  except Exception as e:
100
  print(f"Не удалось подгрузить LoRA по умолчанию: {e}")
101
 
102
- # ------------------------------------------------------------------
103
- # Инициализация «IP-Adapter» (для примера укажем вымышленный путь).
104
- # Предположим, что IP-Adapter мы храним в ./ip_adapter_weights
105
- # ------------------------------------------------------------------
106
- ip_adapter_model = None
107
- try:
108
- ip_adapter_model = IPAdapterModel("./ip_adapter_weights", device=device)
109
- except Exception as e:
110
- print(f"Не удалось загрузить IP-Adapter: {e}")
111
 
112
- # ------------------------------------------------------------------
113
- # Функция генерации
114
- # ------------------------------------------------------------------
115
  def infer(
116
- model, # Текстовое поле: модель (repo) напр. "CompVis/stable-diffusion-v1-4"
117
- prompt, # Текст: позитивный промпт
118
- negative_prompt, # Текст: негативный промпт
119
- seed, # Сид генератора
120
- width, # Ширина
121
- height, # Высота
122
- guidance_scale, # guidance scale
123
- num_inference_steps, # Количество шагов диффузии
124
- use_controlnet, # Чекбокс: включать ли ControlNet
125
- control_strength, # Слайдер: сила влияния ControlNet
126
- controlnet_mode, # Выпадающий список: edge_detection, pose_estimation, depth_estimation
127
- controlnet_image, # Изображение для ControlNet
128
- use_ip_adapter, # Чекбокс: включать ли IP-adapter
129
- ip_adapter_scale, # Слайдер: сила влияния IP-adapter
130
- ip_adapter_image, # Изображение для IP-adapter
131
  progress=gr.Progress(track_tqdm=True),
132
  ):
133
- global model_repo_id, pipe, ip_adapter_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- # ---------------------------
136
- # 1) Проверяем, не сменил ли пользователь модель
137
- # ---------------------------
138
- if model != model_repo_id:
139
- if not is_valid_repo_id(model):
140
- raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  try:
143
- # Подгружаем модель (без ControlNet)
144
- new_pipe = StableDiffusionPipeline.from_pretrained(
145
- model, torch_dtype=torch_dtype, safety_checker=None
146
- ).to(device)
147
- new_pipe.scheduler = DDIMScheduler.from_config(new_pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- # Повторно загружаем LoRA
 
 
 
 
150
  try:
151
- new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet")
152
- new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder")
 
 
 
 
 
 
153
  except Exception as e:
154
- print(f"��е удалось подгрузить LoRA для новой модели: {e}")
155
-
156
- pipe = new_pipe
157
- model_repo_id = model
158
-
159
- except Exception as e:
160
- raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}")
161
-
162
- # ---------------------------
163
- # 2) Если включён ControlNet — создаём ControlNetPipeline
164
- # ---------------------------
165
- local_pipe = pipe # по умолчанию используем базовый pipe
166
-
167
- if use_controlnet:
168
- # Выбираем репозиторий ControlNet в зависимости от режима
169
- if controlnet_mode == "edge_detection":
170
- controlnet_repo = "lllyasviel/sd-controlnet-canny"
171
- elif controlnet_mode == "pose_estimation":
172
- controlnet_repo = "lllyasviel/sd-controlnet-openpose"
173
- elif controlnet_mode == "depth_estimation":
174
- controlnet_repo = "lllyasviel/sd-controlnet-depth"
175
- else:
176
- raise gr.Error(f"Неизвестный режим ControlNet: {controlnet_mode}")
177
 
178
  try:
179
- controlnet_model = ControlNetModel.from_pretrained(
180
- controlnet_repo,
181
- torch_dtype=torch_dtype
182
- ).to(device)
183
-
184
- # Создаём новый pipeline, указывая ControlNet
185
- local_pipe = StableDiffusionControlNetPipeline(
186
- vae=pipe.vae,
187
- text_encoder=pipe.text_encoder,
188
- tokenizer=pipe.tokenizer,
189
- unet=pipe.unet,
190
- controlnet=controlnet_model,
191
- scheduler=pipe.scheduler,
192
- safety_checker=None,
193
- feature_extractor=pipe.feature_extractor,
194
- requires_safety_checker=False,
195
- ).to(device)
196
-
197
  except Exception as e:
198
- raise gr.Error(f"Ошибка загрузки ControlNet ({controlnet_mode}): {e}")
199
-
200
- # ---------------------------
201
- # 3) Генератор случайных чисел для детерминированности
202
- # ---------------------------
203
- generator = torch.Generator(device=device).manual_seed(seed)
204
-
205
- # ---------------------------
206
- # 4) Если есть IP-Adapter, подгружаем фичи из референс-изображения
207
- # ---------------------------
208
- ip_adapter_embedding = None
209
- if use_ip_adapter and ip_adapter_model is not None and ip_adapter_model.dummy_weights_loaded:
210
- if ip_adapter_image is not None:
211
- ip_adapter_embedding = ip_adapter_model.encode_reference_image(ip_adapter_image)
212
- else:
213
- print("IP-Adapter включён, но не загружено референс-изображение.")
214
- elif use_ip_adapter:
215
- print("IP-Adapter включён, но модель не загружена или не инициализирована.")
216
-
217
- # ---------------------------
218
- # 5) Выполняем диффузию
219
- # (с учётом ControlNet, если включён)
220
- # ---------------------------
221
-
222
- # Параметры для ControlNetPipeline
223
- # - Для edge/pose/depth обычно передают control_image через параметр "image"
224
- # - Дополнительно можно задать "controlnet_conditioning_scale" (aka strength)
225
- # чтобы указать вес ControlNet.
226
- # - Документация: https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
227
- extra_kwargs = {}
228
- if use_controlnet and controlnet_image is not None:
229
- extra_kwargs["image"] = controlnet_image
230
- extra_kwargs["controlnet_conditioning_scale"] = control_strength
231
- elif use_controlnet:
232
- print("ControlNet включён, но не загружено изображение для ControlNet.")
233
-
234
- # Запуск генерации
235
- try:
236
- output = local_pipe(
237
- prompt=prompt,
238
- negative_prompt=negative_prompt,
239
- num_inference_steps=num_inference_steps,
240
- guidance_scale=guidance_scale,
241
- width=width,
242
- height=height,
243
- generator=generator,
244
- **extra_kwargs
245
- )
246
- image = output.images[0]
247
- latents = getattr(output, "latents", None) # не во всех версиях diffusers есть latents
248
- except Exception as e:
249
- raise gr.Error(f"Ошибка п��и генерации изображения: {e}")
250
-
251
- # ---------------------------
252
- # 6) Применяем IP-Adapter к результату (если нужно).
253
- # В реальных библиотеках IP-Adapter может вмешиваться раньше (до/во время диффузии).
254
- # Для примера демонстрируем "пост-обработку latents" (если latents сохранились).
255
- # ---------------------------
256
- if use_ip_adapter and ip_adapter_embedding is not None and latents is not None:
257
- try:
258
- # Простейший «пример» подмешивания в латенты
259
- new_latents = ip_adapter_model.blend_latents_with_adapter(latents, ip_adapter_embedding, ip_adapter_scale)
260
-
261
- # Теперь нужно декодировать latents в картинку заново
262
- # (подразумеваем, что local_pipe поддерживает .vae.decode())
263
- new_latents = new_latents.to(dtype=pipe.vae.dtype)
264
- image = pipe.vae.decode(new_latents / 0.18215)
265
- image = (image / 2 + 0.5).clamp(0, 1)
266
- image = image.detach().cpu().permute(0, 2, 3, 1).numpy()[0]
267
- image = (image * 255).astype(np.uint8)
268
- image = Image.fromarray(image)
269
 
270
- except Exception as e:
271
- raise gr.Error(f"Ошибка при применении IP-Adapter: {e}")
272
 
273
- return image, seed
274
 
275
- # ------------------------------------------------------------------
276
  # Примеры для удобного тестирования
277
- # ------------------------------------------------------------------
278
  examples = [
279
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
280
  "An astronaut riding a green horse",
281
  "A delicious ceviche cheesecake slice",
282
  ]
283
 
284
- # ------------------------------------------------------------------
285
- # CSS (дополнительно, опционально)
286
- # ------------------------------------------------------------------
287
  css = """
288
  #col-container {
289
  margin: 0 auto;
@@ -291,167 +187,151 @@ css = """
291
  }
292
  """
293
 
294
- # ------------------------------------------------------------------
295
  # Создаём Gradio-приложение
296
- # ------------------------------------------------------------------
297
- import sys
298
-
299
- def run_app():
300
- with gr.Blocks(css=css) as demo:
301
- with gr.Column(elem_id="col-container"):
302
- gr.Markdown("# Text-to-Image App (ControlNet + IP-Adapter)")
303
-
304
- # Поле для ввода/смены модели
305
- model = gr.Textbox(
306
- label="Model (HuggingFace repo)",
307
- value="CompVis/stable-diffusion-v1-4",
308
- interactive=True
309
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
- # Основные поля для Prompt и Negative Prompt
312
- prompt = gr.Text(
313
- label="Prompt",
314
- show_label=False,
315
- max_lines=1,
316
- placeholder="Enter your prompt",
317
- container=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  )
319
- negative_prompt = gr.Text(
320
- label="Negative prompt",
321
- max_lines=1,
322
- placeholder="Enter a negative prompt",
323
- visible=True,
324
  )
325
-
326
- # Слайдер для выбора seed
327
- seed = gr.Slider(
328
- label="Seed",
329
- minimum=0,
330
- maximum=MAX_SEED,
331
- step=1,
332
- value=42,
333
  )
334
 
335
- # Слайдеры
336
- guidance_scale = gr.Slider(
337
- label="Guidance scale",
 
 
 
338
  minimum=0.0,
339
- maximum=15.0,
340
  step=0.1,
341
- value=7.0,
342
  )
343
- num_inference_steps = gr.Slider(
344
- label="Number of inference steps",
345
- minimum=1,
346
- maximum=100,
347
- step=1,
348
- value=20,
349
  )
350
 
351
- # Кнопка запуска
352
- run_button = gr.Button("Run", variant="primary")
353
-
354
- # Поле для отображения результата
355
- result = gr.Image(label="Result", show_label=False)
356
-
357
- # Продвинутые настройки
358
- with gr.Accordion("Advanced Settings", open=False):
359
- with gr.Row():
360
- width = gr.Slider(
361
- label="Width",
362
- minimum=256,
363
- maximum=MAX_IMAGE_SIZE,
364
- step=64,
365
- value=512,
366
- )
367
- height = gr.Slider(
368
- label="Height",
369
- minimum=256,
370
- maximum=MAX_IMAGE_SIZE,
371
- step=64,
372
- value=512,
373
- )
374
-
375
- # Блоки ControlNet
376
- use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
377
- with gr.Group(visible=False) as controlnet_group:
378
- control_strength = gr.Slider(
379
- label="ControlNet Strength (Conditioning Scale)",
380
- minimum=0.0,
381
- maximum=2.0,
382
- step=0.1,
383
- value=1.0,
384
- )
385
- controlnet_mode = gr.Dropdown(
386
- label="ControlNet Mode",
387
- choices=["edge_detection", "pose_estimation", "depth_estimation"],
388
- value="edge_detection",
389
- )
390
- controlnet_image = gr.Image(
391
- label="ControlNet Image (map / pose / edges)",
392
- type="pil"
393
- )
394
-
395
- def update_controlnet_group(use_controlnet):
396
- return {"visible": use_controlnet}
397
-
398
- use_controlnet.change(
399
- update_controlnet_group,
400
- inputs=[use_controlnet],
401
- outputs=[controlnet_group]
402
  )
403
-
404
- # Блоки IP-adapter
405
- use_ip_adapter = gr.Checkbox(label="Use IP-adapter", value=False)
406
- with gr.Group(visible=False) as ip_adapter_group:
407
- ip_adapter_scale = gr.Slider(
408
- label="IP-adapter Scale",
409
- minimum=0.0,
410
- maximum=2.0,
411
- step=0.1,
412
- value=1.0,
413
- )
414
- ip_adapter_image = gr.Image(
415
- label="IP-adapter Image (reference)",
416
- type="pil"
417
- )
418
-
419
- def update_ip_adapter_group(use_ip_adapter):
420
- return {"visible": use_ip_adapter}
421
-
422
- use_ip_adapter.change(
423
- update_ip_adapter_group,
424
- inputs=[use_ip_adapter],
425
- outputs=[ip_adapter_group]
426
  )
427
 
428
- # Примеры
429
- gr.Examples(examples=examples, inputs=[prompt])
430
-
431
- # Связка кнопки "Run" с функцией "infer"
432
- run_button.click(
433
- infer,
434
- inputs=[
435
- model,
436
- prompt,
437
- negative_prompt,
438
- seed,
439
- width,
440
- height,
441
- guidance_scale,
442
- num_inference_steps,
443
- use_controlnet,
444
- control_strength,
445
- controlnet_mode,
446
- controlnet_image,
447
- use_ip_adapter,
448
- ip_adapter_scale,
449
- ip_adapter_image
450
- ],
451
- outputs=[result, seed],
452
- )
453
-
454
- demo.launch()
455
 
 
456
  if __name__ == "__main__":
457
- run_app()
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
+ from diffusers import DiffusionPipeline
 
 
 
 
 
 
 
5
  from peft import PeftModel
6
+ import re
7
  from PIL import Image
8
 
9
+ # Устройство и тип данных
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
12
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Регулярное выражение для проверки корректности модели
 
14
  VALID_REPO_ID_REGEX = re.compile(r"^[a-zA-Z0-9._\-]+/[a-zA-Z0-9._\-]+$")
 
 
15
 
 
 
 
 
 
16
 
17
+ def is_valid_repo_id(repo_id):
18
+ return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(("-", "."))
19
+
20
+ # Базовые константы
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
23
 
24
+ # Изначально загружаем модель по умолчанию (без ControlNet/IP-adapter)
 
 
25
  model_repo_id = "CompVis/stable-diffusion-v1-4"
26
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype, safety_checker=None).to(device)
27
 
28
+ # Попробуем подгрузить LoRA-модификации (unet + text_encoder)
 
 
 
 
 
 
 
 
29
  try:
30
  pipe.unet = PeftModel.from_pretrained(pipe.unet, "./unet")
31
  pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "./text_encoder")
32
  except Exception as e:
33
  print(f"Не удалось подгрузить LoRA по умолчанию: {e}")
34
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
36
  def infer(
37
+ model,
38
+ prompt,
39
+ negative_prompt,
40
+ seed,
41
+ width,
42
+ height,
43
+ guidance_scale,
44
+ num_inference_steps,
45
+ use_controlnet,
46
+ control_strength,
47
+ controlnet_mode,
48
+ controlnet_image,
49
+ use_ip_adapter,
50
+ ip_adapter_scale,
51
+ ip_adapter_image,
52
  progress=gr.Progress(track_tqdm=True),
53
  ):
54
+ """
55
+ Функция генерации изображения с учётом дополнительных опций:
56
+ - Если включён ControlNet или IP‑adapter, используется пайплайн StableDiffusionControlNetPipeline.
57
+ - При включённом IP‑adapter без ControlNet создаётся пустое (заглушка) изображение для параметра controlnet.
58
+ - В остальных случаях используется стандартный пайплайн.
59
+ """
60
+ global model_repo_id, pipe
61
+
62
+ # Если хотя бы один из режимов (ControlNet или IP‑adapter) включён, переключаемся на ControlNet‑пайплайн
63
+ if use_controlnet or use_ip_adapter:
64
+ # Если модель изменилась или текущий pipe не поддерживает IP‑adapter (нет метода load_ip_adapter),
65
+ # загружаем новый пайплайн.
66
+ if model != model_repo_id or not hasattr(pipe, "load_ip_adapter"):
67
+ try:
68
+ # Импорт необходимых классов внутри функции (если они не нужны при базовой генерации)
69
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
70
+ except ImportError as e:
71
+ raise gr.Error(f"Не удалось импортировать необходимые модули для ControlNet: {e}")
72
+
73
+ # Определяем, какую модель ControlNet использовать.
74
+ if use_controlnet:
75
+ if controlnet_mode == "edge_detection":
76
+ cn_model_id = "lllyasviel/sd-controlnet-canny"
77
+ elif controlnet_mode == "pose_estimation":
78
+ cn_model_id = "lllyasviel/sd-controlnet-openpose"
79
+ else:
80
+ cn_model_id = "lllyasviel/sd-controlnet-canny"
81
+ else:
82
+ # Если включён только IP‑adapter, используем модель по умолчанию (например, canny)
83
+ cn_model_id = "lllyasviel/sd-controlnet-canny"
84
 
85
+ try:
86
+ controlnet = ControlNetModel.from_pretrained(cn_model_id, torch_dtype=torch_dtype)
87
+ new_pipe = StableDiffusionControlNetPipeline.from_pretrained(
88
+ model, torch_dtype=torch_dtype, controlnet=controlnet
89
+ ).to(device)
90
+ new_pipe.safety_checker = None
91
+
92
+ # Подгружаем LoRA-модификации (если они есть)
93
+ try:
94
+ new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet")
95
+ new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder")
96
+ except Exception as e:
97
+ print(f"Не удалось подгрузить LoRA: {e}")
98
+
99
+ # Если включён IP‑adapter, загружаем его и устанавливаем масштаб.
100
+ if use_ip_adapter:
101
+ new_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
102
+ new_pipe.set_ip_adapter_scale(ip_adapter_scale)
103
+
104
+ pipe = new_pipe
105
+ model_repo_id = model
106
+ except Exception as e:
107
+ raise gr.Error(f"Не удалось загрузить модель с ControlNet/IP-adapter '{model}'.\nОшибка: {e}")
108
+
109
+ # Подготавливаем изображение для передачи в ControlNet.
110
+ # Если включён ControlNet, пользователь должен загрузить изображение.
111
+ # Если нет, но включён IP‑adapter, создаём пустое изображение-заглушку.
112
+ if use_controlnet:
113
+ if controlnet_image is None:
114
+ raise gr.Error("ControlNet включён, но изображение для него не загружено.")
115
+ cn_image = controlnet_image
116
+ else:
117
+ cn_image = Image.new("RGB", (width, height), (255, 255, 255))
118
 
119
  try:
120
+ generator = torch.Generator(device=device).manual_seed(seed)
121
+ # Вызываем пайплайн StableDiffusionControlNetPipeline.
122
+ # Первый позиционный аргумент — prompt, второй — изображение для управления (control image).
123
+ output = pipe(
124
+ prompt=prompt,
125
+ image=cn_image,
126
+ negative_prompt=negative_prompt,
127
+ guidance_scale=guidance_scale,
128
+ num_inference_steps=num_inference_steps,
129
+ width=width,
130
+ height=height,
131
+ generator=generator,
132
+ controlnet_conditioning_scale=control_strength if use_controlnet else 1.0,
133
+ ip_adapter_image=ip_adapter_image if use_ip_adapter else None,
134
+ )
135
+ image = output.images[0]
136
+ except Exception as e:
137
+ raise gr.Error(f"Ошибка при генерации изображения с ControlNet/IP-adapter: {e}")
138
+
139
+ return image, seed
140
 
141
+ else:
142
+ # Если ни один из дополнительных режимов не включён, используем стандартный пайплайн.
143
+ if model != model_repo_id:
144
+ if not is_valid_repo_id(model):
145
+ raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.")
146
  try:
147
+ new_pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype).to(device)
148
+ try:
149
+ new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet")
150
+ new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder")
151
+ except Exception as e:
152
+ print(f"Не удалось подгрузить LoRA: {e}")
153
+ pipe = new_pipe
154
+ model_repo_id = model
155
  except Exception as e:
156
+ raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  try:
159
+ generator = torch.Generator(device=device).manual_seed(seed)
160
+ image = pipe(
161
+ prompt=prompt,
162
+ negative_prompt=negative_prompt,
163
+ guidance_scale=guidance_scale,
164
+ num_inference_steps=num_inference_steps,
165
+ width=width,
166
+ height=height,
167
+ generator=generator,
168
+ ).images[0]
 
 
 
 
 
 
 
 
169
  except Exception as e:
170
+ raise gr.Error(f"Ошибка при генерации изображения: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ return image, seed
 
173
 
 
174
 
 
175
  # Примеры для удобного тестирования
 
176
  examples = [
177
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
178
  "An astronaut riding a green horse",
179
  "A delicious ceviche cheesecake slice",
180
  ]
181
 
182
+ # Дополнительный CSS для оформления
 
 
183
  css = """
184
  #col-container {
185
  margin: 0 auto;
 
187
  }
188
  """
189
 
 
190
  # Создаём Gradio-приложение
191
+ with gr.Blocks(css=css) as demo:
192
+ with gr.Column(elem_id="col-container"):
193
+ gr.Markdown("# Text-to-Image App")
194
+
195
+ # Поле для ввода/смены модели
196
+ model = gr.Textbox(
197
+ label="Model",
198
+ value="CompVis/stable-diffusion-v1-4", # Значение по умолчанию
199
+ interactive=True,
200
+ )
201
+
202
+ # Основные поля для Prompt и Negative Prompt
203
+ prompt = gr.Text(
204
+ label="Prompt",
205
+ show_label=False,
206
+ max_lines=1,
207
+ placeholder="Enter your prompt",
208
+ container=False,
209
+ )
210
+ negative_prompt = gr.Text(
211
+ label="Negative prompt",
212
+ max_lines=1,
213
+ placeholder="Enter a negative prompt",
214
+ visible=True,
215
+ )
216
+
217
+ # Слайдер для выбора seed
218
+ seed = gr.Slider(
219
+ label="Seed",
220
+ minimum=0,
221
+ maximum=MAX_SEED,
222
+ step=1,
223
+ value=42,
224
+ )
225
 
226
+ # Слайдеры для guidance_scale и num_inference_steps
227
+ guidance_scale = gr.Slider(
228
+ label="Guidance scale",
229
+ minimum=0.0,
230
+ maximum=10.0,
231
+ step=0.1,
232
+ value=7.0,
233
+ )
234
+ num_inference_steps = gr.Slider(
235
+ label="Number of inference steps",
236
+ minimum=1,
237
+ maximum=50,
238
+ step=1,
239
+ value=20,
240
+ )
241
+
242
+ # Чекбокс для включения ControlNet
243
+ use_controlnet = gr.Checkbox(label="Использовать ControlNet", value=False)
244
+ # Группа дополнительных настроек для ControlNet (будет показана только при включённом чекбоксе)
245
+ with gr.Group(visible=False) as controlnet_group:
246
+ control_strength = gr.Slider(
247
+ label="ControlNet conditioning scale",
248
+ minimum=0.0,
249
+ maximum=2.0,
250
+ step=0.1,
251
+ value=0.7,
252
  )
253
+ controlnet_mode = gr.Dropdown(
254
+ label="Режим работы ControlNet",
255
+ choices=["edge_detection", "pose_estimation"],
256
+ value="edge_detection",
 
257
  )
258
+ controlnet_image = gr.Image(
259
+ label="Изображение для ControlNet",
260
+ source="upload",
261
+ type="pil",
 
 
 
 
262
  )
263
 
264
+ # Чекбокс для включения IP‑adapter
265
+ use_ip_adapter = gr.Checkbox(label="Использовать IP-adapter", value=False)
266
+ # Группа дополнительных настроек для IP‑adapter
267
+ with gr.Group(visible=False) as ip_adapter_group:
268
+ ip_adapter_scale = gr.Slider(
269
+ label="IP-adapter Scale",
270
  minimum=0.0,
271
+ maximum=2.0,
272
  step=0.1,
273
+ value=0.6,
274
  )
275
+ ip_adapter_image = gr.Image(
276
+ label="Изображение для IP-adapter",
277
+ source="upload",
278
+ type="pil",
 
 
279
  )
280
 
281
+ # Обработка событий для показа/скрытия дополнительных настроек
282
+ use_controlnet.change(lambda x: gr.update(visible=x), inputs=use_controlnet, outputs=controlnet_group)
283
+ use_ip_adapter.change(lambda x: gr.update(visible=x), inputs=use_ip_adapter, outputs=ip_adapter_group)
284
+
285
+ # Кнопка запуска
286
+ run_button = gr.Button("Run", variant="primary")
287
+
288
+ # Поле для отображения результата
289
+ result = gr.Image(label="Result", show_label=False)
290
+
291
+ # Продвинутые настройки (Accordion)
292
+ with gr.Accordion("Advanced Settings", open=False):
293
+ with gr.Row():
294
+ width = gr.Slider(
295
+ label="Width",
296
+ minimum=256,
297
+ maximum=MAX_IMAGE_SIZE,
298
+ step=32,
299
+ value=512,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  )
301
+ height = gr.Slider(
302
+ label="Height",
303
+ minimum=256,
304
+ maximum=MAX_IMAGE_SIZE,
305
+ step=32,
306
+ value=512,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  )
308
 
309
+ # Примеры
310
+ gr.Examples(examples=examples, inputs=[prompt])
311
+
312
+ # Связка кнопки "Run" с функцией "infer"
313
+ run_button.click(
314
+ infer,
315
+ inputs=[
316
+ model,
317
+ prompt,
318
+ negative_prompt,
319
+ seed,
320
+ width,
321
+ height,
322
+ guidance_scale,
323
+ num_inference_steps,
324
+ use_controlnet,
325
+ control_strength,
326
+ controlnet_mode,
327
+ controlnet_image,
328
+ use_ip_adapter,
329
+ ip_adapter_scale,
330
+ ip_adapter_image,
331
+ ],
332
+ outputs=[result, seed],
333
+ )
 
 
334
 
335
+ # Запуск приложения
336
  if __name__ == "__main__":
337
+ demo.launch()