Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import torch | |
import re | |
from diffusers import ( | |
StableDiffusionPipeline, | |
ControlNetModel, | |
StableDiffusionControlNetPipeline, | |
DDIMScheduler, | |
) | |
from peft import PeftModel | |
from PIL import Image | |
# ------------------------------------------------------------------ | |
# Пример «заготовки» для IP-Adapter: | |
# Предполагается, что у вас есть некий класс, умеющий: | |
# 1) Загружать веса IP-Adapter | |
# 2) Преобразовывать дополнительное «референс-изображение» в эмбеддинг | |
# 3) Подмешивать этот эмбеддинг в процесс диффузии или текстовые эмбеддинги | |
# ------------------------------------------------------------------ | |
class IPAdapterModel: | |
def __init__(self, path_to_weights: str, device="cpu"): | |
""" | |
Инициализация и загрузка весов IP-Adapter. | |
path_to_weights - путь к файлам модели | |
""" | |
# Здесь должен быть код инициализации вашей модели. | |
# Например, что-то вроде: | |
# self.model = torch.load(path_to_weights, map_location=device) | |
# self.model.eval() | |
# ... | |
self.device = device | |
self.dummy_weights_loaded = True # признак, что "что-то" загрузили | |
def encode_reference_image(self, image: Image.Image): | |
""" | |
Преобразовать референс-изображение в некий вектор (embedding), | |
который затем можно использовать для модификации генерации. | |
""" | |
# В реальном коде будет извлечение фич. | |
# Для примера вернём фиктивный тензор. | |
dummy_embedding = torch.zeros((1, 768)).to(self.device) | |
return dummy_embedding | |
def blend_latents_with_adapter(self, latents: torch.Tensor, adapter_embedding: torch.Tensor, scale: float): | |
""" | |
Примерная функция, которая «подмешивает» признаки из адаптера | |
в латенты перед декодированием. | |
latents: (batch, channels, height, width) | |
adapter_embedding: (1, embedding_dim) | |
scale: сила влияния адаптера | |
""" | |
# Для демонстрации просто прибавим (scale * mean(adapter_embedding)) | |
# В реальном IP-Adapter это гораздо сложнее. | |
if adapter_embedding is not None: | |
# Возьмём скаляр (к примеру) | |
mean_val = adapter_embedding.mean() | |
latents = latents + scale * mean_val | |
return latents | |
# ------------------------------------------------------------------ | |
# Регулярное выражение для проверки корректности модели | |
# ------------------------------------------------------------------ | |
VALID_REPO_ID_REGEX = re.compile(r"^[a-zA-Z0-9._\-]+/[a-zA-Z0-9._\-]+$") | |
def is_valid_repo_id(repo_id): | |
return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(('-', '.')) | |
# ------------------------------------------------------------------ | |
# Аппаратные настройки | |
# ------------------------------------------------------------------ | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 | |
# ------------------------------------------------------------------ | |
# Константы | |
# ------------------------------------------------------------------ | |
MAX_SEED = np.iinfo(np.int32).max | |
MAX_IMAGE_SIZE = 1024 | |
# ------------------------------------------------------------------ | |
# Базовая модель (Stable Diffusion) по умолчанию | |
# ------------------------------------------------------------------ | |
model_repo_id = "CompVis/stable-diffusion-v1-4" | |
# Загрузка базового пайплайна (без ControlNet) | |
pipe = StableDiffusionPipeline.from_pretrained( | |
model_repo_id, torch_dtype=torch_dtype, safety_checker=None | |
).to(device) | |
# Применим DDIM-схему как пример | |
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) | |
# Пробуем подгрузить LoRA (unet + text_encoder) | |
try: | |
pipe.unet = PeftModel.from_pretrained(pipe.unet, "./unet") | |
pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "./text_encoder") | |
except Exception as e: | |
print(f"Не удалось подгрузить LoRA по умолчанию: {e}") | |
# ------------------------------------------------------------------ | |
# Инициализация «IP-Adapter» (для примера укажем вымышленный путь). | |
# Предположим, что IP-Adapter мы храним в ./ip_adapter_weights | |
# ------------------------------------------------------------------ | |
ip_adapter_model = None | |
try: | |
ip_adapter_model = IPAdapterModel("./ip_adapter_weights", device=device) | |
except Exception as e: | |
print(f"Не удалось загрузить IP-Adapter: {e}") | |
# ------------------------------------------------------------------ | |
# Функция генерации | |
# ------------------------------------------------------------------ | |
def infer( | |
model, # Текстовое поле: модель (repo) напр. "CompVis/stable-diffusion-v1-4" | |
prompt, # Текст: позитивный промпт | |
negative_prompt, # Текст: негативный промпт | |
seed, # Сид генератора | |
width, # Ширина | |
height, # Высота | |
guidance_scale, # guidance scale | |
num_inference_steps, # Количество шагов диффузии | |
use_controlnet, # Чекбокс: включать ли ControlNet | |
control_strength, # Слайдер: сила влияния ControlNet | |
controlnet_mode, # Выпадающий список: edge_detection, pose_estimation, depth_estimation | |
controlnet_image, # Изображение для ControlNet | |
use_ip_adapter, # Чекбокс: включать ли IP-adapter | |
ip_adapter_scale, # Слайдер: сила влияния IP-adapter | |
ip_adapter_image, # Изображение для IP-adapter | |
progress=gr.Progress(track_tqdm=True), | |
): | |
global model_repo_id, pipe, ip_adapter_model | |
# --------------------------- | |
# 1) Проверяем, не сменил ли пользователь модель | |
# --------------------------- | |
if model != model_repo_id: | |
if not is_valid_repo_id(model): | |
raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.") | |
try: | |
# Подгружаем модель (без ControlNet) | |
new_pipe = StableDiffusionPipeline.from_pretrained( | |
model, torch_dtype=torch_dtype, safety_checker=None | |
).to(device) | |
new_pipe.scheduler = DDIMScheduler.from_config(new_pipe.scheduler.config) | |
# Повторно загружаем LoRA | |
try: | |
new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet") | |
new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder") | |
except Exception as e: | |
print(f"Не удалось подгрузить LoRA для новой модели: {e}") | |
pipe = new_pipe | |
model_repo_id = model | |
except Exception as e: | |
raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}") | |
# --------------------------- | |
# 2) Если включён ControlNet — создаём ControlNetPipeline | |
# --------------------------- | |
local_pipe = pipe # по умолчанию используем базовый pipe | |
if use_controlnet: | |
# Выбираем репозиторий ControlNet в зависимости от режима | |
if controlnet_mode == "edge_detection": | |
controlnet_repo = "lllyasviel/sd-controlnet-canny" | |
elif controlnet_mode == "pose_estimation": | |
controlnet_repo = "lllyasviel/sd-controlnet-openpose" | |
elif controlnet_mode == "depth_estimation": | |
controlnet_repo = "lllyasviel/sd-controlnet-depth" | |
else: | |
raise gr.Error(f"Неизвестный режим ControlNet: {controlnet_mode}") | |
try: | |
controlnet_model = ControlNetModel.from_pretrained( | |
controlnet_repo, | |
torch_dtype=torch_dtype | |
).to(device) | |
# Создаём новый pipeline, указывая ControlNet | |
local_pipe = StableDiffusionControlNetPipeline( | |
vae=pipe.vae, | |
text_encoder=pipe.text_encoder, | |
tokenizer=pipe.tokenizer, | |
unet=pipe.unet, | |
controlnet=controlnet_model, | |
scheduler=pipe.scheduler, | |
safety_checker=None, | |
feature_extractor=pipe.feature_extractor, | |
requires_safety_checker=False, | |
).to(device) | |
except Exception as e: | |
raise gr.Error(f"Ошибка загрузки ControlNet ({controlnet_mode}): {e}") | |
# --------------------------- | |
# 3) Генератор случайных чисел для детерминированности | |
# --------------------------- | |
generator = torch.Generator(device=device).manual_seed(seed) | |
# --------------------------- | |
# 4) Если есть IP-Adapter, подгружаем фичи из референс-изображения | |
# --------------------------- | |
ip_adapter_embedding = None | |
if use_ip_adapter and ip_adapter_model is not None and ip_adapter_model.dummy_weights_loaded: | |
if ip_adapter_image is not None: | |
ip_adapter_embedding = ip_adapter_model.encode_reference_image(ip_adapter_image) | |
else: | |
print("IP-Adapter включён, но не загружено референс-изображение.") | |
elif use_ip_adapter: | |
print("IP-Adapter включён, но модель не загружена или не инициализирована.") | |
# --------------------------- | |
# 5) Выполняем диффузию | |
# (с учётом ControlNet, если включён) | |
# --------------------------- | |
# Параметры для ControlNetPipeline | |
# - Для edge/pose/depth обычно передают control_image через параметр "image" | |
# - Дополнительно можно задать "controlnet_conditioning_scale" (aka strength) | |
# чтобы указать вес ControlNet. | |
# - Документация: https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet | |
extra_kwargs = {} | |
if use_controlnet and controlnet_image is not None: | |
extra_kwargs["image"] = controlnet_image | |
extra_kwargs["controlnet_conditioning_scale"] = control_strength | |
elif use_controlnet: | |
print("ControlNet включён, но не загружено изображение для ControlNet.") | |
# Запуск генерации | |
try: | |
output = local_pipe( | |
prompt=prompt, | |
negative_prompt=negative_prompt, | |
num_inference_steps=num_inference_steps, | |
guidance_scale=guidance_scale, | |
width=width, | |
height=height, | |
generator=generator, | |
**extra_kwargs | |
) | |
image = output.images[0] | |
latents = getattr(output, "latents", None) # не во всех версиях diffusers есть latents | |
except Exception as e: | |
raise gr.Error(f"Ошибка при генерации изображения: {e}") | |
# --------------------------- | |
# 6) Применяем IP-Adapter к результату (если нужно). | |
# В реальных библиотеках IP-Adapter может вмешиваться раньше (до/во время диффузии). | |
# Для примера демонстрируем "пост-обработку latents" (если latents сохранились). | |
# --------------------------- | |
if use_ip_adapter and ip_adapter_embedding is not None and latents is not None: | |
try: | |
# Простейший «пример» подмешивания в латенты | |
new_latents = ip_adapter_model.blend_latents_with_adapter(latents, ip_adapter_embedding, ip_adapter_scale) | |
# Теперь нужно декодировать latents в картинку заново | |
# (подразумеваем, что local_pipe поддерживает .vae.decode()) | |
new_latents = new_latents.to(dtype=pipe.vae.dtype) | |
image = pipe.vae.decode(new_latents / 0.18215) | |
image = (image / 2 + 0.5).clamp(0, 1) | |
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()[0] | |
image = (image * 255).astype(np.uint8) | |
image = Image.fromarray(image) | |
except Exception as e: | |
raise gr.Error(f"Ошибка при применении IP-Adapter: {e}") | |
return image, seed | |
# ------------------------------------------------------------------ | |
# Примеры для удобного тестирования | |
# ------------------------------------------------------------------ | |
examples = [ | |
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", | |
"An astronaut riding a green horse", | |
"A delicious ceviche cheesecake slice", | |
] | |
# ------------------------------------------------------------------ | |
# CSS (дополнительно, опционально) | |
# ------------------------------------------------------------------ | |
css = """ | |
#col-container { | |
margin: 0 auto; | |
max-width: 640px; | |
} | |
""" | |
# ------------------------------------------------------------------ | |
# Создаём Gradio-приложение | |
# ------------------------------------------------------------------ | |
import sys | |
def run_app(): | |
with gr.Blocks(css=css) as demo: | |
with gr.Column(elem_id="col-container"): | |
gr.Markdown("# Text-to-Image App (ControlNet + IP-Adapter)") | |
# Поле для ввода/смены модели | |
model = gr.Textbox( | |
label="Model (HuggingFace repo)", | |
value="CompVis/stable-diffusion-v1-4", | |
interactive=True | |
) | |
# Основные поля для Prompt и Negative Prompt | |
prompt = gr.Text( | |
label="Prompt", | |
show_label=False, | |
max_lines=1, | |
placeholder="Enter your prompt", | |
container=False, | |
) | |
negative_prompt = gr.Text( | |
label="Negative prompt", | |
max_lines=1, | |
placeholder="Enter a negative prompt", | |
visible=True, | |
) | |
# Слайдер для выбора seed | |
seed = gr.Slider( | |
label="Seed", | |
minimum=0, | |
maximum=MAX_SEED, | |
step=1, | |
value=42, | |
) | |
# Слайдеры | |
guidance_scale = gr.Slider( | |
label="Guidance scale", | |
minimum=0.0, | |
maximum=15.0, | |
step=0.1, | |
value=7.0, | |
) | |
num_inference_steps = gr.Slider( | |
label="Number of inference steps", | |
minimum=1, | |
maximum=100, | |
step=1, | |
value=20, | |
) | |
# Кнопка запуска | |
run_button = gr.Button("Run", variant="primary") | |
# Поле для отображения результата | |
result = gr.Image(label="Result", show_label=False) | |
# Продвинутые настройки | |
with gr.Accordion("Advanced Settings", open=False): | |
with gr.Row(): | |
width = gr.Slider( | |
label="Width", | |
minimum=256, | |
maximum=MAX_IMAGE_SIZE, | |
step=64, | |
value=512, | |
) | |
height = gr.Slider( | |
label="Height", | |
minimum=256, | |
maximum=MAX_IMAGE_SIZE, | |
step=64, | |
value=512, | |
) | |
# Блоки ControlNet | |
use_controlnet = gr.Checkbox(label="Use ControlNet", value=False) | |
with gr.Group(visible=False) as controlnet_group: | |
control_strength = gr.Slider( | |
label="ControlNet Strength (Conditioning Scale)", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.1, | |
value=1.0, | |
) | |
controlnet_mode = gr.Dropdown( | |
label="ControlNet Mode", | |
choices=["edge_detection", "pose_estimation", "depth_estimation"], | |
value="edge_detection", | |
) | |
controlnet_image = gr.Image( | |
label="ControlNet Image (map / pose / edges)", | |
type="pil" | |
) | |
def update_controlnet_group(use_controlnet): | |
return {"visible": use_controlnet} | |
use_controlnet.change( | |
update_controlnet_group, | |
inputs=[use_controlnet], | |
outputs=[controlnet_group] | |
) | |
# Блоки IP-adapter | |
use_ip_adapter = gr.Checkbox(label="Use IP-adapter", value=False) | |
with gr.Group(visible=False) as ip_adapter_group: | |
ip_adapter_scale = gr.Slider( | |
label="IP-adapter Scale", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.1, | |
value=1.0, | |
) | |
ip_adapter_image = gr.Image( | |
label="IP-adapter Image (reference)", | |
type="pil" | |
) | |
def update_ip_adapter_group(use_ip_adapter): | |
return {"visible": use_ip_adapter} | |
use_ip_adapter.change( | |
update_ip_adapter_group, | |
inputs=[use_ip_adapter], | |
outputs=[ip_adapter_group] | |
) | |
# Примеры | |
gr.Examples(examples=examples, inputs=[prompt]) | |
# Связка кнопки "Run" с функцией "infer" | |
run_button.click( | |
infer, | |
inputs=[ | |
model, | |
prompt, | |
negative_prompt, | |
seed, | |
width, | |
height, | |
guidance_scale, | |
num_inference_steps, | |
use_controlnet, | |
control_strength, | |
controlnet_mode, | |
controlnet_image, | |
use_ip_adapter, | |
ip_adapter_scale, | |
ip_adapter_image | |
], | |
outputs=[result, seed], | |
) | |
demo.launch() | |
if __name__ == "__main__": | |
run_app() | |