ControlNet / app.py
trashchenkov's picture
Update app.py
f615f7e verified
raw
history blame
14.6 kB
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from peft import PeftModel
import re
from PIL import Image
# Устройство и тип данных
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Регулярное выражение для проверки корректности модели
VALID_REPO_ID_REGEX = re.compile(r"^[a-zA-Z0-9._\-]+/[a-zA-Z0-9._\-]+$")
def is_valid_repo_id(repo_id):
return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(("-", "."))
# Базовые константы
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# Изначально загружаем модель по умолчанию (без ControlNet/IP-adapter)
model_repo_id = "CompVis/stable-diffusion-v1-4"
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype, safety_checker=None).to(device)
# Попробуем подгрузить LoRA-модификации (unet + text_encoder)
try:
pipe.unet = PeftModel.from_pretrained(pipe.unet, "./unet")
pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "./text_encoder")
except Exception as e:
print(f"Не удалось подгрузить LoRA по умолчанию: {e}")
def infer(
model,
prompt,
negative_prompt,
seed,
width,
height,
guidance_scale,
num_inference_steps,
use_controlnet,
control_strength,
controlnet_mode,
controlnet_image,
use_ip_adapter,
ip_adapter_scale,
ip_adapter_image,
progress=gr.Progress(track_tqdm=True),
):
"""
Функция генерации изображения с учётом дополнительных опций:
- Если включён ControlNet или IP‑adapter, используется пайплайн StableDiffusionControlNetPipeline.
- При включённом IP‑adapter без ControlNet создаётся пустое (заглушка) изображение для параметра controlnet.
- В остальных случаях используется стандартный пайплайн.
"""
global model_repo_id, pipe
# Если хотя бы один из режимов (ControlNet или IP‑adapter) включён, переключаемся на ControlNet‑пайплайн
if use_controlnet or use_ip_adapter:
# Если модель изменилась или текущий pipe не поддерживает IP‑adapter (нет метода load_ip_adapter),
# загружаем новый пайплайн.
if model != model_repo_id or not hasattr(pipe, "load_ip_adapter"):
try:
# Импорт необходимых классов внутри функции (если они не нужны при базовой генерации)
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
except ImportError as e:
raise gr.Error(f"Не удалось импортировать необходимые модули для ControlNet: {e}")
# Определяем, какую модель ControlNet использовать.
if use_controlnet:
if controlnet_mode == "edge_detection":
cn_model_id = "lllyasviel/sd-controlnet-canny"
elif controlnet_mode == "pose_estimation":
cn_model_id = "lllyasviel/sd-controlnet-openpose"
else:
cn_model_id = "lllyasviel/sd-controlnet-canny"
else:
# Если включён только IP‑adapter, используем модель по умолчанию (например, canny)
cn_model_id = "lllyasviel/sd-controlnet-canny"
try:
controlnet = ControlNetModel.from_pretrained(cn_model_id, torch_dtype=torch_dtype)
new_pipe = StableDiffusionControlNetPipeline.from_pretrained(
model, torch_dtype=torch_dtype, controlnet=controlnet
).to(device)
new_pipe.safety_checker = None
# Подгружаем LoRA-модификации (если они есть)
try:
new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet")
new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder")
except Exception as e:
print(f"Не удалось подгрузить LoRA: {e}")
# Если включён IP‑adapter, загружаем его и устанавливаем масштаб.
if use_ip_adapter:
new_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
new_pipe.set_ip_adapter_scale(ip_adapter_scale)
pipe = new_pipe
model_repo_id = model
except Exception as e:
raise gr.Error(f"Не удалось загрузить модель с ControlNet/IP-adapter '{model}'.\nОшибка: {e}")
# Подготавливаем изображение для передачи в ControlNet.
# Если включён ControlNet, пользователь должен загрузить изображение.
# Если нет, но включён IP‑adapter, создаём пустое изображение-заглушку.
if use_controlnet:
if controlnet_image is None:
raise gr.Error("ControlNet включён, но изображение для него не загружено.")
cn_image = controlnet_image
cn_image = cn_image.resize((width, height))
else:
cn_image = Image.new("RGB", (width, height), (255, 255, 255))
# Если включён IP‑adapter, проверяем, что изображение для него загружено.
if use_ip_adapter and ip_adapter_image is None:
raise gr.Error("IP-adapter включён, но изображение для него не загружено.")
if ip_adapter_image:
ip_adapter_image = ip_adapter_image.resize((width, height))
try:
generator = torch.Generator(device=device).manual_seed(seed)
# Вызываем пайплайн StableDiffusionControlNetPipeline.
output = pipe(
prompt=prompt,
image=cn_image,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator,
controlnet_conditioning_scale=control_strength if use_controlnet else 1.0,
ip_adapter_image=ip_adapter_image if use_ip_adapter else None,
)
image = output.images[0]
except Exception as e:
raise gr.Error(f"Ошибка при генерации изображения с ControlNet/IP-adapter: {e}")
return image, seed
else:
# Если ни один из дополнительных режимов не включён, используем стандартный пайплайн.
if model != model_repo_id:
if not is_valid_repo_id(model):
raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.")
try:
new_pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype).to(device)
try:
new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "./unet")
new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "./text_encoder")
except Exception as e:
print(f"Не удалось подгрузить LoRA: {e}")
pipe = new_pipe
model_repo_id = model
except Exception as e:
raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}")
try:
generator = torch.Generator(device=device).manual_seed(seed)
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator,
).images[0]
except Exception as e:
raise gr.Error(f"Ошибка при генерации изображения: {e}")
return image, seed
# Примеры для удобного тестирования
examples = [
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"An astronaut riding a green horse",
"A delicious ceviche cheesecake slice",
]
# Дополнительный CSS для оформления
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Создаём Gradio-приложение
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# Text-to-Image App")
# Поле для ввода/смены модели
model = gr.Textbox(
label="Model",
value="CompVis/stable-diffusion-v1-4", # Значение по умолчанию
interactive=True,
)
# Основные поля для Prompt и Negative Prompt
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
visible=True,
)
# Слайдер для выбора seed
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
)
# Слайдеры для guidance_scale и num_inference_steps
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=7.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=20,
)
# Чекбокс для включения ControlNet
use_controlnet = gr.Checkbox(label="Использовать ControlNet", value=False)
# Группа дополнительных настроек для ControlNet (будет показана только при включённом чекбоксе)
with gr.Group(visible=False) as controlnet_group:
control_strength = gr.Slider(
label="ControlNet conditioning scale",
minimum=0.0,
maximum=2.0,
step=0.1,
value=0.7,
)
controlnet_mode = gr.Dropdown(
label="Режим работы ControlNet",
choices=["edge_detection", "pose_estimation"],
value="edge_detection",
)
controlnet_image = gr.Image(
label="Изображение для ControlNet",
type="pil",
)
# Чекбокс для включения IP‑adapter
use_ip_adapter = gr.Checkbox(label="Использовать IP-adapter", value=False)
# Группа дополнительных настроек для IP‑adapter
with gr.Group(visible=False) as ip_adapter_group:
ip_adapter_scale = gr.Slider(
label="IP-adapter Scale",
minimum=0.0,
maximum=2.0,
step=0.1,
value=0.6,
)
ip_adapter_image = gr.Image(
label="Изображение для IP-adapter",
type="pil",
)
# Обработка событий для показа/скрытия дополнительных настроек
use_controlnet.change(lambda x: gr.update(visible=x), inputs=use_controlnet, outputs=controlnet_group)
use_ip_adapter.change(lambda x: gr.update(visible=x), inputs=use_ip_adapter, outputs=ip_adapter_group)
# Кнопка запуска
run_button = gr.Button("Run", variant="primary")
# Поле для отображения результата
result = gr.Image(label="Result", show_label=False)
# Продвинутые настройки (Accordion)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=512,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=512,
)
# Примеры
gr.Examples(examples=examples, inputs=[prompt])
# Связка кнопки "Run" с функцией "infer"
run_button.click(
infer,
inputs=[
model,
prompt,
negative_prompt,
seed,
width,
height,
guidance_scale,
num_inference_steps,
use_controlnet,
control_strength,
controlnet_mode,
controlnet_image,
use_ip_adapter,
ip_adapter_scale,
ip_adapter_image,
],
outputs=[result, seed],
)
# Запуск приложения
if __name__ == "__main__":
demo.launch()