File size: 12,394 Bytes
0adeb3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# app.py
import gradio as gr
import torch
import requests
from PIL import Image
from io import BytesIO
import numpy as np
import os
from tqdm import tqdm # Добавляем импорт tqdm
# Импорты из diffusers
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionPipeline
# from diffusers.utils import load_image # Не нужен для этого кода
# from huggingface_hub import hf_hub_download # Не нужен для этого кода
# --- Вспомогательная функция для скачивания файлов (например, с Civitai) ---
# Эта функция будет скачивать модель SafeTensor внутри Space при первом запуске
def download_file(url, local_filename):
"""Скачивает файл по URL с индикатором прогресса."""
print(f"Скачиваю {url} в {local_filename}...")
# Проверяем, существует ли файл, чтобы не скачивать его каждый раз
if os.path.exists(local_filename):
print(f"Файл уже существует: {local_filename}. Пропускаю скачивание.")
return local_filename
try:
response = requests.get(url, stream=True)
response.raise_for_status() # Проверка на ошибки HTTP
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 8192 # 8 Kibibytes
with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Скачивание {local_filename}") as progress_bar:
with open(local_filename, 'wb') as f:
for chunk in response.iter_content(chunk_size=block_size):
progress_bar.update(len(chunk))
f.write(chunk)
print(f"Скачивание завершено: {local_filename}")
return local_filename
except requests.exceptions.RequestException as e:
print(f"Ошибка скачивания {url}: {e}")
return None
except Exception as e:
print(f"Произошла другая ошибка при скачивании: {e}")
return None
# --- Определение путей/ID моделей ---
# URL вашей SafeTensor модели с Civitai
CIVITAI_SAFETENSOR_URL = "https://civitai.com/api/download/models/1413133?type=Model&format=SafeTensor&size=full&fp=fp8"
# Локальное имя файла для сохранения SafeTensor модели внутри Space
LOCAL_SAFETENSOR_FILENAME = "ultrareal_fine_tune_fp8_full.safetensors"
# ControlNet модель с Hugging Face
CONTROLNET_MODEL_ID = "ABDALLALSWAITI/FLUX.1-dev-ControlNet-Union-Pro-2.0-fp8"
# --- Скачиваем SafeTensor модель (выполнится при запуске скрипта в Space) ---
print("Начинаю скачивание базовой модели...")
downloaded_base_model_path = download_file(CIVITAI_SAFETENSOR_URL, LOCAL_SAFETENSOR_FILENAME)
if not downloaded_base_model_path or not os.path.exists(downloaded_base_model_path):
# Если скачивание не удалось или файл не существует после попытки
print(f"Критическая ошибка: Не удалось получить файл базовой модели по пути: {LOCAL_SAFETENSOR_FILENAME}")
print("Проверьте логи Space на наличие ошибок скачивания.")
# Возможно, здесь стоит выбросить исключение или как-то иначе остановить приложение
# Для примера, просто присвоим None и приложение не сможет загрузить пайплайн
pipeline = None
else:
# --- Загрузка моделей и создание пайплайна ---
def load_pipeline_components(base_model_path, controlnet_model_id):
"""Загружает базовую модель из локального файла, ControlNet и собирает пайплайн."""
print(f"Загрузка ControlNet модели: {controlnet_model_id}")
# Загрузка ControlNet с Hugging Face Hub - кешируется автоматически Space
controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
print(f"Загрузка базовой модели из локального файла: {base_model_path}")
# Загружаем базовую модель из локального SafeTensor файла
# diffusers умеет загружать локальные файлы .safetensors
pipe = StableDiffusionPipeline.from_pretrained(
base_model_path, # Указываем путь к локальному файлу внутри Space
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
safety_checker=None # Отключение safety checker для скорости (используйте с осторожностью!)
)
# Теперь объединяем базовый пайплайн с ControlNet
# Создаем StableDiffusionControlNetPipeline на основе загруженного базового пайплайна
print("Создание пайплайна StableDiffusionControlNetPipeline...")
controlnet_pipe = StableDiffusionControlNetPipeline(
vae=pipe.vae,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
unet=pipe.unet,
controlnet=controlnet, # Передаем загруженный ControlNet
scheduler=pipe.scheduler, # Используем планировщик из базового пайплайна
safety_checker=None,
feature_extractor=pipe.feature_extractor
)
# Рекомендуется использовать планировщик UniPC для ControlNet (или обновить существующий)
# Обновляем планировщик в новом ControlNet пайплайне
controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
# Удаляем старый базовый пайплайн для освобождения памяти
del pipe
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Перемещаем ControlNet пайплайн на GPU, если доступно
if torch.cuda.is_available():
controlnet_pipe = controlnet_pipe.to("cuda")
print("Пайплайн перемещен на GPU.")
else:
print("GPU не найдено. Пайплайн будет работать на CPU (будет медленно).") # В Space на CPU работать не будет эффективно
return controlnet_pipe
# Загружаем пайплайн при запуске скрипта, только если файл модели успешно скачан
pipeline = load_pipeline_components(downloaded_base_model_path, CONTROLNET_MODEL_ID)
# --- Функция рендеринга для Gradio ---
# Эта функция будет вызываться интерфейсом Gradio в Space
def generate_image_gradio(controlnet_image: np.ndarray, prompt: str, negative_prompt: str = "", guidance_scale: float = 7.5, num_inference_steps: int = 30, controlnet_conditioning_scale: float = 1.0):
"""
Генерирует изображение с использованием Stable Diffusion ControlNet.
Принимает изображение NumPy, текст промта и другие параметры.
Возвращает сгенерированное изображение в формате PIL Image.
"""
# Проверяем, успешно ли загрузился пайплайн
if pipeline is None:
return None, "Ошибка: Пайплайн модели не загружен. Проверьте логи Space."
if controlnet_image is None:
return None, "Ошибка: необходимо загрузить изображение для ControlNet."
print(f"Генерация изображения с промтом: '{prompt}'")
print(f"Размер входного изображения: {controlnet_image.shape}")
# Gradio возвращает изображение как numpy array. Преобразуем в PIL Image для пайплайна.
input_image_pil = Image.fromarray(controlnet_image).convert("RGB")
# Выполняем рендеринг с помощью пайплайна
try:
output = pipeline(
prompt=prompt,
image=input_image_pil, # Входное изображение для ControlNet
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
controlnet_conditioning_scale=controlnet_conditioning_scale
# Здесь можно добавить generator=... (для сидов), width=..., height=..., etc.
)
# Результат находится в output.images[0]
generated_image_pil = output.images[0]
print("Генерация завершена.")
return generated_image_pil, "Успех!"
except Exception as e:
print(f"Ошибка при генерации: {e}")
return None, f"Ошибка при генерации: {e}"
# --- Настройка интерфейса Gradio ---
# Определяем входные и выходные элементы
input_image_comp = gr.Image(type="numpy", label="Изображение для ControlNet (набросок, карта глубины и т.д.)")
prompt_comp = gr.Textbox(label="Промт (Prompt)")
negative_prompt_comp = gr.Textbox(label="Негативный промт (Negative Prompt)")
guidance_scale_comp = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Степень соответствия промту (Guidance Scale)")
num_inference_steps_comp = gr.Slider(minimum=10, maximum=150, value=30, step=1, label="Количество шагов (Inference Steps)")
controlnet_conditioning_scale_comp = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Вес ControlNet (ControlNet Scale)")
output_image_comp = gr.Image(type="pil", label="Сгенерированное изображение")
status_text_comp = gr.Textbox(label="Статус")
# Создаем интерфейс Gradio
# Поскольку мы в Space, Gradio SDK сам вызовет interface.launch()
# Нам просто нужно определить интерфейс
interface = gr.Interface(
fn=generate_image_gradio,
inputs=[
input_image_comp,
prompt_comp,
negative_prompt_comp,
guidance_scale_comp,
num_inference_steps_comp,
controlnet_conditioning_scale_comp
],
outputs=[output_image_comp, status_text_comp],
title="Stable Diffusion ControlNet Interface (SafeTensor Base Model)",
description="Загрузите изображение для ControlNet, введите промт и нажмите 'Generate'. Используется локальная SafeTensor модель и ControlNet с Hugging Face."
)
# Важно: Не вызывайте interface.launch() в блоке if __name__ == "__main__":
# Gradio SDK в Space сделает это автоматически.
# Если вы оставите if __name__ == "__main__": interface.launch(), оно тоже будет работать,
# но в среде Space это менее критично, чем при локальном запуске.
# Для ясности в Space можно убрать блок if __name__ == "__main__":
# Я оставил его в коде выше, но знайте, что SDK вызовет interface.launch() независимо от него. |