File size: 20,198 Bytes
7d3e201
 
 
52a10bb
7d3e201
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a10bb
ba1587e
52a10bb
 
 
7d3e201
ba1587e
 
 
 
 
 
 
 
 
7d3e201
 
 
ba1587e
 
 
52a10bb
 
ba1587e
 
 
 
 
 
 
 
 
52a10bb
 
 
 
 
7d3e201
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
7d3e201
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3e201
 
ba1587e
 
 
 
 
52a10bb
 
 
7d3e201
52a10bb
ba1587e
 
 
 
 
 
 
52a10bb
 
 
 
ba1587e
52a10bb
 
 
 
 
 
 
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a10bb
7d3e201
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a10bb
ba1587e
52a10bb
 
 
ba1587e
52a10bb
 
 
ba1587e
 
 
 
52a10bb
 
 
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a10bb
 
ba1587e
52a10bb
ba1587e
7d3e201
 
 
 
 
 
ba1587e
 
 
7d3e201
 
 
 
 
 
 
ba1587e
52a10bb
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3e201
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3e201
ba1587e
 
 
 
 
 
 
 
52a10bb
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
52a10bb
 
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3e201
 
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a10bb
ba1587e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a10bb
 
ba1587e
52a10bb
7d3e201
ba1587e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
import gradio as gr
import numpy as np
import torch
import re

from diffusers import (
    StableDiffusionPipeline,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    DDIMScheduler,
)
from peft import PeftModel
from PIL import Image

# ------------------------------------------------------------------
# Пример «заготовки» для IP-Adapter:
# Предполагается, что у вас есть некий класс, умеющий:
# 1) Загружать веса IP-Adapter
# 2) Преобразовывать дополнительное «референс-изображение» в эмбеддинг
# 3) Подмешивать этот эмбеддинг в процесс диффузии или текстовые эмбеддинги
# ------------------------------------------------------------------
class IPAdapterModel:
    def __init__(self, path_to_weights: str, device="cpu"):
        """
        Инициализация и загрузка весов IP-Adapter.
        path_to_weights - путь к файлам модели
        """
        # Здесь должен быть код инициализации вашей модели.
        # Например, что-то вроде:
        # self.model = torch.load(path_to_weights, map_location=device)
        # self.model.eval()
        # ...
        self.device = device
        self.dummy_weights_loaded = True  # признак, что "что-то" загрузили

    def encode_reference_image(self, image: Image.Image):
        """
        Преобразовать референс-изображение в некий вектор (embedding),
        который затем можно использовать для модификации генерации.
        """
        # В реальном коде будет извлечение фич.
        # Для примера вернём фиктивный тензор.
        dummy_embedding = torch.zeros((1, 768)).to(self.device)
        return dummy_embedding

    def blend_latents_with_adapter(self, latents: torch.Tensor, adapter_embedding: torch.Tensor, scale: float):
        """
        Примерная функция, которая «подмешивает» признаки из адаптера
        в латенты перед декодированием.
        latents: (batch, channels, height, width)
        adapter_embedding: (1, embedding_dim)
        scale: сила влияния адаптера
        """
        # Для демонстрации просто прибавим (scale * mean(adapter_embedding))
        # В реальном IP-Adapter это гораздо сложнее.
        if adapter_embedding is not None:
            # Возьмём скаляр (к примеру)
            mean_val = adapter_embedding.mean()
            latents = latents + scale * mean_val
        return latents


# ------------------------------------------------------------------
# Регулярное выражение для проверки корректности модели
# ------------------------------------------------------------------
VALID_REPO_ID_REGEX = re.compile(r"^[a-zA-Z0-9._\-]+/[a-zA-Z0-9._\-]+$")
def is_valid_repo_id(repo_id):
    return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(('-', '.'))

# ------------------------------------------------------------------
# Аппаратные настройки
# ------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# ------------------------------------------------------------------
# Константы
# ------------------------------------------------------------------
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# ------------------------------------------------------------------
# Базовая модель (Stable Diffusion) по умолчанию
# ------------------------------------------------------------------
model_repo_id = "CompVis/stable-diffusion-v1-4"

# Загрузка базового пайплайна (без ControlNet)
pipe = StableDiffusionPipeline.from_pretrained(
    model_repo_id, torch_dtype=torch_dtype, safety_checker=None
).to(device)

# Применим DDIM-схему как пример
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Пробуем подгрузить LoRA (unet + text_encoder)
try:
    pipe.unet = PeftModel.from_pretrained(pipe.unet, "./unet")
    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "./text_encoder")
except Exception as e:
    print(f"Не удалось подгрузить LoRA по умолчанию: {e}")

# ------------------------------------------------------------------
# Инициализация «IP-Adapter» (для примера укажем вымышленный путь).
# Предположим, что IP-Adapter мы храним в ./ip_adapter_weights
# ------------------------------------------------------------------
ip_adapter_model = None
try:
    ip_adapter_model = IPAdapterModel("./ip_adapter_weights", device=device)
except Exception as e:
    print(f"Не удалось загрузить IP-Adapter: {e}")

# ------------------------------------------------------------------
# Функция генерации
# ------------------------------------------------------------------
def infer(
    model,                 # Текстовое поле: модель (repo) напр. "CompVis/stable-diffusion-v1-4"
    prompt,                # Текст: позитивный промпт
    negative_prompt,       # Текст: негативный промпт
    seed,                  # Сид генератора
    width,                 # Ширина
    height,                # Высота
    guidance_scale,        # guidance scale
    num_inference_steps,   # Количество шагов диффузии
    use_controlnet,        # Чекбокс: включать ли ControlNet
    control_strength,      # Слайдер: сила влияния ControlNet
    controlnet_mode,       # Выпадающий список: edge_detection, pose_estimation, depth_estimation
    controlnet_image,      # Изображение для ControlNet
    use_ip_adapter,        # Чекбокс: включать ли IP-adapter
    ip_adapter_scale,      # Слайдер: сила влияния IP-adapter
    ip_adapter_image,      # Изображение для IP-adapter
    progress=gr.Progress(track_tqdm=True),
):
    global model_repo_id, pipe, ip_adapter_model

    # ---------------------------
    #  1) Проверяем, не сменил ли пользователь модель
    # ---------------------------
    if model != model_repo_id:
        if not is_valid_repo_id(model):
            raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.")

        try:
            # Подгружаем модель (без ControlNet)
            new_pipe = StableDiffusionPipeline.from_pretrained(
                model, torch_dtype=torch_dtype, safety_checker=None
            ).to(device)
            new_pipe.scheduler = DDIMScheduler.from_config(new_pipe.scheduler.config)

            # Повторно загружаем LoRA
            try:
                new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet")
                new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder")
            except Exception as e:
                print(f"Не удалось подгрузить LoRA для новой модели: {e}")

            pipe = new_pipe
            model_repo_id = model

        except Exception as e:
            raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}")

    # ---------------------------
    #  2) Если включён ControlNet — создаём ControlNetPipeline
    # ---------------------------
    local_pipe = pipe  # по умолчанию используем базовый pipe

    if use_controlnet:
        # Выбираем репозиторий ControlNet в зависимости от режима
        if controlnet_mode == "edge_detection":
            controlnet_repo = "lllyasviel/sd-controlnet-canny"
        elif controlnet_mode == "pose_estimation":
            controlnet_repo = "lllyasviel/sd-controlnet-openpose"
        elif controlnet_mode == "depth_estimation":
            controlnet_repo = "lllyasviel/sd-controlnet-depth"
        else:
            raise gr.Error(f"Неизвестный режим ControlNet: {controlnet_mode}")

        try:
            controlnet_model = ControlNetModel.from_pretrained(
                controlnet_repo,
                torch_dtype=torch_dtype
            ).to(device)

            # Создаём новый pipeline, указывая ControlNet
            local_pipe = StableDiffusionControlNetPipeline(
                vae=pipe.vae,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                unet=pipe.unet,
                controlnet=controlnet_model,
                scheduler=pipe.scheduler,
                safety_checker=None,
                feature_extractor=pipe.feature_extractor,
                requires_safety_checker=False,
            ).to(device)

        except Exception as e:
            raise gr.Error(f"Ошибка загрузки ControlNet ({controlnet_mode}): {e}")

    # ---------------------------
    #  3) Генератор случайных чисел для детерминированности
    # ---------------------------
    generator = torch.Generator(device=device).manual_seed(seed)

    # ---------------------------
    #  4) Если есть IP-Adapter, подгружаем фичи из референс-изображения
    # ---------------------------
    ip_adapter_embedding = None
    if use_ip_adapter and ip_adapter_model is not None and ip_adapter_model.dummy_weights_loaded:
        if ip_adapter_image is not None:
            ip_adapter_embedding = ip_adapter_model.encode_reference_image(ip_adapter_image)
        else:
            print("IP-Adapter включён, но не загружено референс-изображение.")
    elif use_ip_adapter:
        print("IP-Adapter включён, но модель не загружена или не инициализирована.")

    # ---------------------------
    #  5) Выполняем диффузию
    #    (с учётом ControlNet, если включён)
    # ---------------------------

    # Параметры для ControlNetPipeline
    #   - Для edge/pose/depth обычно передают control_image через параметр "image"
    #   - Дополнительно можно задать "controlnet_conditioning_scale" (aka strength)
    #     чтобы указать вес ControlNet.
    #   - Документация: https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
    extra_kwargs = {}
    if use_controlnet and controlnet_image is not None:
        extra_kwargs["image"] = controlnet_image
        extra_kwargs["controlnet_conditioning_scale"] = control_strength
    elif use_controlnet:
        print("ControlNet включён, но не загружено изображение для ControlNet.")

    # Запуск генерации
    try:
        output = local_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            generator=generator,
            **extra_kwargs
        )
        image = output.images[0]
        latents = getattr(output, "latents", None)  # не во всех версиях diffusers есть latents
    except Exception as e:
        raise gr.Error(f"Ошибка при генерации изображения: {e}")

    # ---------------------------
    #  6) Применяем IP-Adapter к результату (если нужно).
    #    В реальных библиотеках IP-Adapter может вмешиваться раньше (до/во время диффузии).
    #    Для примера демонстрируем "пост-обработку latents" (если latents сохранились).
    # ---------------------------
    if use_ip_adapter and ip_adapter_embedding is not None and latents is not None:
        try:
            # Простейший «пример» подмешивания в латенты
            new_latents = ip_adapter_model.blend_latents_with_adapter(latents, ip_adapter_embedding, ip_adapter_scale)

            # Теперь нужно декодировать latents в картинку заново
            # (подразумеваем, что local_pipe поддерживает .vae.decode())
            new_latents = new_latents.to(dtype=pipe.vae.dtype)
            image = pipe.vae.decode(new_latents / 0.18215)
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
            image = Image.fromarray(image)

        except Exception as e:
            raise gr.Error(f"Ошибка при применении IP-Adapter: {e}")

    return image, seed

# ------------------------------------------------------------------
# Примеры для удобного тестирования
# ------------------------------------------------------------------
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# ------------------------------------------------------------------
# CSS (дополнительно, опционально)
# ------------------------------------------------------------------
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# ------------------------------------------------------------------
# Создаём Gradio-приложение
# ------------------------------------------------------------------
import sys

def run_app():
    with gr.Blocks(css=css) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown("# Text-to-Image App (ControlNet + IP-Adapter)")

            # Поле для ввода/смены модели
            model = gr.Textbox(
                label="Model (HuggingFace repo)",
                value="CompVis/stable-diffusion-v1-4",
                interactive=True
            )

            # Основные поля для Prompt и Negative Prompt
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )

            # Слайдер для выбора seed
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )

            # Слайдеры
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=15.0,
                step=0.1,
                value=7.0,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=100,
                step=1,
                value=20,
            )

            # Кнопка запуска
            run_button = gr.Button("Run", variant="primary")

            # Поле для отображения результата
            result = gr.Image(label="Result", show_label=False)

            # Продвинутые настройки
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=64,
                        value=512,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=64,
                        value=512,
                    )

                # Блоки ControlNet
                use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
                with gr.Group(visible=False) as controlnet_group:
                    control_strength = gr.Slider(
                        label="ControlNet Strength (Conditioning Scale)",
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                    )
                    controlnet_mode = gr.Dropdown(
                        label="ControlNet Mode",
                        choices=["edge_detection", "pose_estimation", "depth_estimation"],
                        value="edge_detection",
                    )
                    controlnet_image = gr.Image(
                        label="ControlNet Image (map / pose / edges)",
                        type="pil"
                    )

                def update_controlnet_group(use_controlnet):
                    return {"visible": use_controlnet}

                use_controlnet.change(
                    update_controlnet_group,
                    inputs=[use_controlnet],
                    outputs=[controlnet_group]
                )

                # Блоки IP-adapter
                use_ip_adapter = gr.Checkbox(label="Use IP-adapter", value=False)
                with gr.Group(visible=False) as ip_adapter_group:
                    ip_adapter_scale = gr.Slider(
                        label="IP-adapter Scale",
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                    )
                    ip_adapter_image = gr.Image(
                        label="IP-adapter Image (reference)",
                        type="pil"
                    )

                def update_ip_adapter_group(use_ip_adapter):
                    return {"visible": use_ip_adapter}

                use_ip_adapter.change(
                    update_ip_adapter_group,
                    inputs=[use_ip_adapter],
                    outputs=[ip_adapter_group]
                )

            # Примеры
            gr.Examples(examples=examples, inputs=[prompt])

            # Связка кнопки "Run" с функцией "infer"
            run_button.click(
                infer,
                inputs=[
                    model,
                    prompt,
                    negative_prompt,
                    seed,
                    width,
                    height,
                    guidance_scale,
                    num_inference_steps,
                    use_controlnet,
                    control_strength,
                    controlnet_mode,
                    controlnet_image,
                    use_ip_adapter,
                    ip_adapter_scale,
                    ip_adapter_image
                ],
                outputs=[result, seed],
            )

    demo.launch()

if __name__ == "__main__":
    run_app()