ign / app.py
Superigni's picture
fix path model
59ec49b verified
raw
history blame
15.3 kB
# app.py
import gradio as gr
import torch
import requests
from PIL import Image
import numpy as np
import os
from tqdm import tqdm # Добавляем импорт tqdm
# Импорты из diffusers
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionPipeline
# from diffusers.utils import load_image # Не нужен для этого кода
# from huggingface_hub import hf_hub_download # Не нужен для этого кода
# --- Вспомогательная функция для скачивания файлов (например, с Civitai) ---
# Эта функция будет скачивать модель SafeTensor внутри Space при первом запуске
def download_file(url, local_filename):
"""Скачивает файл по URL с индикатором прогресса."""
print(f"Скачиваю {url} в {local_filename}...")
# Проверяем, существует ли файл, чтобы не скачивать его каждый раз
if os.path.exists(local_filename):
print(f"Файл уже существует: {local_filename}. Пропускаю скачивание.")
return local_filename
try:
response = requests.get(url, stream=True)
response.raise_for_status() # Проверка на ошибки HTTP
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 8192 # 8 Kibibytes
# Используем tqdm для индикатора прогресса, только если размер известен
if total_size_in_bytes > 0:
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Скачивание {local_filename}")
else:
print("Размер файла неизвестен, скачивание без индикатора прогресса.")
progress_bar = None
with open(local_filename, 'wb') as f:
for chunk in response.iter_content(chunk_size=block_size):
if progress_bar:
progress_bar.update(len(chunk))
f.write(chunk)
if progress_bar:
progress_bar.close()
print(f"Скачивание завершено: {local_filename}")
return local_filename
except requests.exceptions.RequestException as e:
print(f"Ошибка скачивания {url}: {e}")
return None
except Exception as e:
print(f"Произошла другая ошибка при скачивании: {e}")
return None
# --- Определение путей/ID моделей ---
# URL вашей SafeTensor модели с Civitai
CIVITAI_SAFETENSOR_URL = "https://civitai.com/api/download/models/1413133?type=Model&format=SafeTensor&size=full&fp=fp8"
# Локальное имя файла для сохранения SafeTensor модели внутри Space
LOCAL_SAFETENSOR_FILENAME = "ultrareal_fine_tune_fp8_full.safetensors"
# ControlNet модель с Hugging Face
CONTROLNET_MODEL_ID = "ABDALLALSWAITI/FLUX.1-dev-ControlNet-Union-Pro-2.0-fp8"
# Переменная для хранения пайплайна (будет загружен при запуске скрипта)
pipeline = None
downloaded_base_model_path = None # Переменная для пути к скачанному файлу
# --- Скачиваем SafeTensor модель (выполнится при запуске скрипта в Space) ---
print("Начинаю скачивание базовой модели...")
downloaded_base_model_path = download_file(CIVITAI_SAFETENSOR_URL, LOCAL_SAFETENSOR_FILENAME)
# --- Загрузка моделей и создание пайплайна ---
# Эта функция вызывается один раз при запуске скрипта
def load_pipeline_components(base_model_path, controlnet_model_id):
"""Загружает базовую модель из локального файла, ControlNet и собирает пайплайн."""
if not base_model_path or not os.path.exists(base_model_path):
print(f"Ошибка загрузки: Файл базовой модели не найден по пути: {base_model_path}")
return None # Не можем загрузить пайплайн без файла модели
print(f"Загрузка ControlNet модели: {controlnet_model_id}")
# Загрузка ControlNet с Hugging Face Hub - кешируется автоматически Space
try:
controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
except Exception as e:
print(f"Ошибка загрузки ControlNet модели с HF Hub: {controlnet_model_id}. Проверьте ID или соединение.")
print(f"Ошибка: {e}")
return None # Не можем загрузить пайплайн без ControlNet
print(f"Загрузка базовой модели из локального файла: {base_model_path} с использованием from_single_file")
# Используем from_single_file для загрузки пайплайна из одного SafeTensor файла
try:
pipe = StableDiffusionPipeline.from_single_file(
base_model_path, # Указываем путь к локальному файлу .safetensors
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
# from_single_file пытается найти конфигурацию VAE, tokenizer и scheduler.
# Если ваша модель требует специфической конфигурации, возможно,
# потребуется указать путь к папке с конфигом или загрузить их отдельно.
# Для большинства Safetensor SD 1.5/2.x from_single_file работает из коробки.
)
# Отключение safety checker после загрузки, если он был загружен
if hasattr(pipe, 'safety_checker') and pipe.safety_checker is not None:
print("Отключаю safety checker...")
pipe.safety_checker = None
print("Safety checker отключен.")
except Exception as e:
print(f"Ошибка при загрузке базовой модели из файла {base_model_path}: {e}")
print("Убедитесь, что файл не поврежден, соответствует формату StableDiffusion и from_single_file может его обработать.")
return None # Возвращаем None, если загрузка базовой модели не удалась
# --- Создание пайплайна StableDiffusionControlNetPipeline из компонентов ---
# Этот блок выполняется ТОЛЬКО если базовая модель и ControlNet успешно загружены
print("Создание финального пайплайна StableDiffusionControlNetPipeline...")
try:
controlnet_pipe = StableDiffusionControlNetPipeline(
vae=pipe.vae,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
unet=pipe.unet,
controlnet=controlnet, # Передаем загруженный ControlNet
scheduler=pipe.scheduler, # Используем планировщик из базового пайплайна
safety_checker=None, # Убираем safety_checker здесь при создании нового пайплайна
feature_extractor=pipe.feature_extractor
)
# Рекомендуется использовать планировщик UniPC для ControlNet (или обновить существующий)
# Обновляем планировщик в новом ControlNet пайплайне
controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
# Удаляем старый объект пайплайна для освобождения памяти GPU
del pipe
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("Память GPU очищена после создания ControlNet пайплайна.")
# Перемещаем ControlNet пайплайн на GPU, если доступно
if torch.cuda.is_available():
controlnet_pipe = controlnet_pipe.to("cuda")
print("Финальный пайплайн перемещен на GPU.")
else:
print("GPU не найдено. Пайплайн будет работать на CPU (крайне медленно).") # В Space на CPU работать не будет эффективно
return controlnet_pipe # Возвращаем готовый пайплайн
except Exception as e:
print(f"Ошибка при создании финального StableDiffusionControlNetPipeline: {e}")
print("Проверьте совместимость компонентов (базовая модель и ControlNet).")
return None # Возвращаем None, если собрать финальный пайплайн не удалось
# --- Загружаем пайплайн при запуске скрипта, только если файл модели успешно скачан ---
# Этот код выполняется после скачивания файла
if downloaded_base_model_path and os.path.exists(downloaded_base_model_path):
pipeline = load_pipeline_components(downloaded_base_model_path, CONTROLNET_MODEL_ID)
else:
print("Пропуск загрузки пайплайна из-за ошибки скачивания или отсутствия файла.")
pipeline = None # Убеждаемся, что pipeline равен None при ошибке
# --- Функция рендеринга для Gradio ---
# Эта функция будет вызываться интерфейсом Gradio в Space
def generate_image_gradio(controlnet_image: np.ndarray, prompt: str, negative_prompt: str = "", guidance_scale: float = 7.5, num_inference_steps: int = 30, controlnet_conditioning_scale: float = 1.0):
"""
Генерирует изображение с использованием Stable Diffusion ControlNet.
Принимает изображение NumPy, текст промта и другие параметры.
Возвращает сгенерированное изображение в формате PIL Image.
"""
# Проверяем, успешно ли загрузился пайплайн
if pipeline is None:
print("Попытка генерации, но пайплайн модели не загружен.")
return None, "Ошибка: Пайплайн модели не загружен. Проверьте логи Space."
if controlnet_image is None:
return None, "Ошибка: необходимо загрузить изображение для ControlNet."
print(f"Генерация изображения с промтом: '{prompt}'")
print(f"Размер входного изображения: {controlnet_image.shape}")
# Gradio возвращает изображение как numpy array. Преобразуем в PIL Image для пайплайна.
# diffusers ControlNet ожидают изображение в формате PIL Image или PyTorch Tensor в RGB
input_image_pil = Image.fromarray(controlnet_image).convert("RGB")
# Выполняем рендеринг с помощью пайплайна
try:
# Здесь вы можете добавить generator=... (для сидов), width=..., height=..., etc.
# Передаем все параметры в вызов пайплайна
output = pipeline(
prompt=prompt,
image=input_image_pil, # Входное изображение для ControlNet
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
controlnet_conditioning_scale=controlnet_conditioning_scale
)
# Результат находится в output.images[0]
generated_image_pil = output.images[0]
print("Генерация завершена.")
return generated_image_pil, "Успех!"
except Exception as e:
print(f"Ошибка при генерации: {e}")
# Возвращаем None и сообщение об ошибке в интерфейс Gradio
return None, f"Ошибка при генерации: {e}"
# --- Настройка интерфейса Gradio ---
# Определяем входные и выходные элементы
input_image_comp = gr.Image(type="numpy", label="Изображение для ControlNet (набросок, карта глубины и т.д.)")
prompt_comp = gr.Textbox(label="Промт (Prompt)")
negative_prompt_comp = gr.Textbox(label="Негативный промт (Negative Prompt)")
guidance_scale_comp = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Степень соответствия промту (Guidance Scale)")
num_inference_steps_comp = gr.Slider(minimum=10, maximum=150, value=30, step=1, label="Количество шагов (Inference Steps)")
controlnet_conditioning_scale_comp = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Вес ControlNet (ControlNet Scale)")
output_image_comp = gr.Image(type="pil", label="Сгенерированное изображение")
status_text_comp = gr.Textbox(label="Статус")
# Создаем интерфейс Gradio
# Поскольку мы в Space, Gradio SDK сам вызовет interface.launch()
# Нам просто нужно определить объект интерфейса
interface = gr.Interface(
fn=generate_image_gradio,
inputs=[
input_image_comp,
prompt_comp,
negative_prompt_comp,
guidance_scale_comp,
num_inference_steps_comp,
controlnet_conditioning_scale_comp
],
outputs=[output_image_comp, status_text_comp],
title="Stable Diffusion ControlNet Interface (SafeTensor Base Model on HF Space)",
description="Загрузите изображение для ControlNet, введите промт и нажмите 'Generate'. Используется локальная SafeTensor модель и ControlNet с Hugging Face."
)
# Нет необходимости вызывать interface.launch() в блоке if __name__ == "__main__":
# Gradio SDK в Space сделает это автоматически при запуске скрипта.