import gradio as gr import numpy as np import torch import re from diffusers import ( StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, DDIMScheduler, ) from peft import PeftModel from PIL import Image # ------------------------------------------------------------------ # Пример «заготовки» для IP-Adapter: # Предполагается, что у вас есть некий класс, умеющий: # 1) Загружать веса IP-Adapter # 2) Преобразовывать дополнительное «референс-изображение» в эмбеддинг # 3) Подмешивать этот эмбеддинг в процесс диффузии или текстовые эмбеддинги # ------------------------------------------------------------------ class IPAdapterModel: def __init__(self, path_to_weights: str, device="cpu"): """ Инициализация и загрузка весов IP-Adapter. path_to_weights - путь к файлам модели """ # Здесь должен быть код инициализации вашей модели. # Например, что-то вроде: # self.model = torch.load(path_to_weights, map_location=device) # self.model.eval() # ... self.device = device self.dummy_weights_loaded = True # признак, что "что-то" загрузили def encode_reference_image(self, image: Image.Image): """ Преобразовать референс-изображение в некий вектор (embedding), который затем можно использовать для модификации генерации. """ # В реальном коде будет извлечение фич. # Для примера вернём фиктивный тензор. dummy_embedding = torch.zeros((1, 768)).to(self.device) return dummy_embedding def blend_latents_with_adapter(self, latents: torch.Tensor, adapter_embedding: torch.Tensor, scale: float): """ Примерная функция, которая «подмешивает» признаки из адаптера в латенты перед декодированием. latents: (batch, channels, height, width) adapter_embedding: (1, embedding_dim) scale: сила влияния адаптера """ # Для демонстрации просто прибавим (scale * mean(adapter_embedding)) # В реальном IP-Adapter это гораздо сложнее. if adapter_embedding is not None: # Возьмём скаляр (к примеру) mean_val = adapter_embedding.mean() latents = latents + scale * mean_val return latents # ------------------------------------------------------------------ # Регулярное выражение для проверки корректности модели # ------------------------------------------------------------------ VALID_REPO_ID_REGEX = re.compile(r"^[a-zA-Z0-9._\-]+/[a-zA-Z0-9._\-]+$") def is_valid_repo_id(repo_id): return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(('-', '.')) # ------------------------------------------------------------------ # Аппаратные настройки # ------------------------------------------------------------------ device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # ------------------------------------------------------------------ # Константы # ------------------------------------------------------------------ MAX_SEED = np.iinfo(np.int32).max MAX_IMAGE_SIZE = 1024 # ------------------------------------------------------------------ # Базовая модель (Stable Diffusion) по умолчанию # ------------------------------------------------------------------ model_repo_id = "CompVis/stable-diffusion-v1-4" # Загрузка базового пайплайна (без ControlNet) pipe = StableDiffusionPipeline.from_pretrained( model_repo_id, torch_dtype=torch_dtype, safety_checker=None ).to(device) # Применим DDIM-схему как пример pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) # Пробуем подгрузить LoRA (unet + text_encoder) try: pipe.unet = PeftModel.from_pretrained(pipe.unet, "./unet") pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "./text_encoder") except Exception as e: print(f"Не удалось подгрузить LoRA по умолчанию: {e}") # ------------------------------------------------------------------ # Инициализация «IP-Adapter» (для примера укажем вымышленный путь). # Предположим, что IP-Adapter мы храним в ./ip_adapter_weights # ------------------------------------------------------------------ ip_adapter_model = None try: ip_adapter_model = IPAdapterModel("./ip_adapter_weights", device=device) except Exception as e: print(f"Не удалось загрузить IP-Adapter: {e}") # ------------------------------------------------------------------ # Функция генерации # ------------------------------------------------------------------ def infer( model, # Текстовое поле: модель (repo) напр. "CompVis/stable-diffusion-v1-4" prompt, # Текст: позитивный промпт negative_prompt, # Текст: негативный промпт seed, # Сид генератора width, # Ширина height, # Высота guidance_scale, # guidance scale num_inference_steps, # Количество шагов диффузии use_controlnet, # Чекбокс: включать ли ControlNet control_strength, # Слайдер: сила влияния ControlNet controlnet_mode, # Выпадающий список: edge_detection, pose_estimation, depth_estimation controlnet_image, # Изображение для ControlNet use_ip_adapter, # Чекбокс: включать ли IP-adapter ip_adapter_scale, # Слайдер: сила влияния IP-adapter ip_adapter_image, # Изображение для IP-adapter progress=gr.Progress(track_tqdm=True), ): global model_repo_id, pipe, ip_adapter_model # --------------------------- # 1) Проверяем, не сменил ли пользователь модель # --------------------------- if model != model_repo_id: if not is_valid_repo_id(model): raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.") try: # Подгружаем модель (без ControlNet) new_pipe = StableDiffusionPipeline.from_pretrained( model, torch_dtype=torch_dtype, safety_checker=None ).to(device) new_pipe.scheduler = DDIMScheduler.from_config(new_pipe.scheduler.config) # Повторно загружаем LoRA try: new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet") new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder") except Exception as e: print(f"Не удалось подгрузить LoRA для новой модели: {e}") pipe = new_pipe model_repo_id = model except Exception as e: raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}") # --------------------------- # 2) Если включён ControlNet — создаём ControlNetPipeline # --------------------------- local_pipe = pipe # по умолчанию используем базовый pipe if use_controlnet: # Выбираем репозиторий ControlNet в зависимости от режима if controlnet_mode == "edge_detection": controlnet_repo = "lllyasviel/sd-controlnet-canny" elif controlnet_mode == "pose_estimation": controlnet_repo = "lllyasviel/sd-controlnet-openpose" elif controlnet_mode == "depth_estimation": controlnet_repo = "lllyasviel/sd-controlnet-depth" else: raise gr.Error(f"Неизвестный режим ControlNet: {controlnet_mode}") try: controlnet_model = ControlNetModel.from_pretrained( controlnet_repo, torch_dtype=torch_dtype ).to(device) # Создаём новый pipeline, указывая ControlNet local_pipe = StableDiffusionControlNetPipeline( vae=pipe.vae, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer, unet=pipe.unet, controlnet=controlnet_model, scheduler=pipe.scheduler, safety_checker=None, feature_extractor=pipe.feature_extractor, requires_safety_checker=False, ).to(device) except Exception as e: raise gr.Error(f"Ошибка загрузки ControlNet ({controlnet_mode}): {e}") # --------------------------- # 3) Генератор случайных чисел для детерминированности # --------------------------- generator = torch.Generator(device=device).manual_seed(seed) # --------------------------- # 4) Если есть IP-Adapter, подгружаем фичи из референс-изображения # --------------------------- ip_adapter_embedding = None if use_ip_adapter and ip_adapter_model is not None and ip_adapter_model.dummy_weights_loaded: if ip_adapter_image is not None: ip_adapter_embedding = ip_adapter_model.encode_reference_image(ip_adapter_image) else: print("IP-Adapter включён, но не загружено референс-изображение.") elif use_ip_adapter: print("IP-Adapter включён, но модель не загружена или не инициализирована.") # --------------------------- # 5) Выполняем диффузию # (с учётом ControlNet, если включён) # --------------------------- # Параметры для ControlNetPipeline # - Для edge/pose/depth обычно передают control_image через параметр "image" # - Дополнительно можно задать "controlnet_conditioning_scale" (aka strength) # чтобы указать вес ControlNet. # - Документация: https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet extra_kwargs = {} if use_controlnet and controlnet_image is not None: extra_kwargs["image"] = controlnet_image extra_kwargs["controlnet_conditioning_scale"] = control_strength elif use_controlnet: print("ControlNet включён, но не загружено изображение для ControlNet.") # Запуск генерации try: output = local_pipe( prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, width=width, height=height, generator=generator, **extra_kwargs ) image = output.images[0] latents = getattr(output, "latents", None) # не во всех версиях diffusers есть latents except Exception as e: raise gr.Error(f"Ошибка при генерации изображения: {e}") # --------------------------- # 6) Применяем IP-Adapter к результату (если нужно). # В реальных библиотеках IP-Adapter может вмешиваться раньше (до/во время диффузии). # Для примера демонстрируем "пост-обработку latents" (если latents сохранились). # --------------------------- if use_ip_adapter and ip_adapter_embedding is not None and latents is not None: try: # Простейший «пример» подмешивания в латенты new_latents = ip_adapter_model.blend_latents_with_adapter(latents, ip_adapter_embedding, ip_adapter_scale) # Теперь нужно декодировать latents в картинку заново # (подразумеваем, что local_pipe поддерживает .vae.decode()) new_latents = new_latents.to(dtype=pipe.vae.dtype) image = pipe.vae.decode(new_latents / 0.18215) image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().permute(0, 2, 3, 1).numpy()[0] image = (image * 255).astype(np.uint8) image = Image.fromarray(image) except Exception as e: raise gr.Error(f"Ошибка при применении IP-Adapter: {e}") return image, seed # ------------------------------------------------------------------ # Примеры для удобного тестирования # ------------------------------------------------------------------ examples = [ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "An astronaut riding a green horse", "A delicious ceviche cheesecake slice", ] # ------------------------------------------------------------------ # CSS (дополнительно, опционально) # ------------------------------------------------------------------ css = """ #col-container { margin: 0 auto; max-width: 640px; } """ # ------------------------------------------------------------------ # Создаём Gradio-приложение # ------------------------------------------------------------------ import sys def run_app(): with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.Markdown("# Text-to-Image App (ControlNet + IP-Adapter)") # Поле для ввода/смены модели model = gr.Textbox( label="Model (HuggingFace repo)", value="CompVis/stable-diffusion-v1-4", interactive=True ) # Основные поля для Prompt и Negative Prompt prompt = gr.Text( label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, ) negative_prompt = gr.Text( label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", visible=True, ) # Слайдер для выбора seed seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, ) # Слайдеры guidance_scale = gr.Slider( label="Guidance scale", minimum=0.0, maximum=15.0, step=0.1, value=7.0, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=1, maximum=100, step=1, value=20, ) # Кнопка запуска run_button = gr.Button("Run", variant="primary") # Поле для отображения результата result = gr.Image(label="Result", show_label=False) # Продвинутые настройки with gr.Accordion("Advanced Settings", open=False): with gr.Row(): width = gr.Slider( label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=512, ) height = gr.Slider( label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=512, ) # Блоки ControlNet use_controlnet = gr.Checkbox(label="Use ControlNet", value=False) with gr.Group(visible=False) as controlnet_group: control_strength = gr.Slider( label="ControlNet Strength (Conditioning Scale)", minimum=0.0, maximum=2.0, step=0.1, value=1.0, ) controlnet_mode = gr.Dropdown( label="ControlNet Mode", choices=["edge_detection", "pose_estimation", "depth_estimation"], value="edge_detection", ) controlnet_image = gr.Image( label="ControlNet Image (map / pose / edges)", type="pil" ) def update_controlnet_group(use_controlnet): return {"visible": use_controlnet} use_controlnet.change( update_controlnet_group, inputs=[use_controlnet], outputs=[controlnet_group] ) # Блоки IP-adapter use_ip_adapter = gr.Checkbox(label="Use IP-adapter", value=False) with gr.Group(visible=False) as ip_adapter_group: ip_adapter_scale = gr.Slider( label="IP-adapter Scale", minimum=0.0, maximum=2.0, step=0.1, value=1.0, ) ip_adapter_image = gr.Image( label="IP-adapter Image (reference)", type="pil" ) def update_ip_adapter_group(use_ip_adapter): return {"visible": use_ip_adapter} use_ip_adapter.change( update_ip_adapter_group, inputs=[use_ip_adapter], outputs=[ip_adapter_group] ) # Примеры gr.Examples(examples=examples, inputs=[prompt]) # Связка кнопки "Run" с функцией "infer" run_button.click( infer, inputs=[ model, prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps, use_controlnet, control_strength, controlnet_mode, controlnet_image, use_ip_adapter, ip_adapter_scale, ip_adapter_image ], outputs=[result, seed], ) demo.launch() if __name__ == "__main__": run_app()