|
|
|
import gradio as gr |
|
import torch |
|
import requests |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
from tqdm import tqdm |
|
|
|
|
|
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionPipeline |
|
|
|
|
|
|
|
|
|
|
|
def download_file(url, local_filename): |
|
"""Скачивает файл по URL с индикатором прогресса.""" |
|
print(f"Скачиваю {url} в {local_filename}...") |
|
|
|
if os.path.exists(local_filename): |
|
print(f"Файл уже существует: {local_filename}. Пропускаю скачивание.") |
|
return local_filename |
|
|
|
try: |
|
response = requests.get(url, stream=True) |
|
response.raise_for_status() |
|
|
|
total_size_in_bytes = int(response.headers.get('content-length', 0)) |
|
block_size = 8192 |
|
|
|
|
|
if total_size_in_bytes > 0: |
|
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Скачивание {local_filename}") |
|
else: |
|
print("Размер файла неизвестен, скачивание без индикатора прогресса.") |
|
progress_bar = None |
|
|
|
|
|
with open(local_filename, 'wb') as f: |
|
for chunk in response.iter_content(chunk_size=block_size): |
|
if progress_bar: |
|
progress_bar.update(len(chunk)) |
|
f.write(chunk) |
|
|
|
if progress_bar: |
|
progress_bar.close() |
|
|
|
print(f"Скачивание завершено: {local_filename}") |
|
return local_filename |
|
except requests.exceptions.RequestException as e: |
|
print(f"Ошибка скачивания {url}: {e}") |
|
return None |
|
except Exception as e: |
|
print(f"Произошла другая ошибка при скачивании: {e}") |
|
return None |
|
|
|
|
|
|
|
|
|
CIVITAI_SAFETENSOR_URL = "https://civitai.com/api/download/models/1413133?type=Model&format=SafeTensor&size=full&fp=fp8" |
|
|
|
LOCAL_SAFETENSOR_FILENAME = "ultrareal_fine_tune_fp8_full.safetensors" |
|
|
|
|
|
CONTROLNET_MODEL_ID = "ABDALLALSWAITI/FLUX.1-dev-ControlNet-Union-Pro-2.0-fp8" |
|
|
|
|
|
pipeline = None |
|
downloaded_base_model_path = None |
|
|
|
|
|
print("Начинаю скачивание базовой модели...") |
|
downloaded_base_model_path = download_file(CIVITAI_SAFETENSOR_URL, LOCAL_SAFETENSOR_FILENAME) |
|
|
|
|
|
|
|
def load_pipeline_components(base_model_path, controlnet_model_id): |
|
"""Загружает базовую модель из локального файла, ControlNet и собирает пайплайн.""" |
|
if not base_model_path or not os.path.exists(base_model_path): |
|
print(f"Ошибка загрузки: Файл базовой модели не найден по пути: {base_model_path}") |
|
return None |
|
|
|
print(f"Загрузка ControlNet модели: {controlnet_model_id}") |
|
|
|
try: |
|
controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32) |
|
except Exception as e: |
|
print(f"Ошибка загрузки ControlNet модели с HF Hub: {controlnet_model_id}. Проверьте ID или соединение.") |
|
print(f"Ошибка: {e}") |
|
return None |
|
|
|
print(f"Загрузка базовой модели из локального файла: {base_model_path} с использованием from_single_file") |
|
|
|
try: |
|
pipe = StableDiffusionPipeline.from_single_file( |
|
base_model_path, |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
|
|
|
|
|
|
|
|
|
) |
|
|
|
if hasattr(pipe, 'safety_checker') and pipe.safety_checker is not None: |
|
print("Отключаю safety checker...") |
|
pipe.safety_checker = None |
|
print("Safety checker отключен.") |
|
|
|
except Exception as e: |
|
print(f"Ошибка при загрузке базовой модели из файла {base_model_path}: {e}") |
|
print("Убедитесь, что файл не поврежден, соответствует формату StableDiffusion и from_single_file может его обработать.") |
|
return None |
|
|
|
|
|
|
|
print("Создание финального пайплайна StableDiffusionControlNetPipeline...") |
|
try: |
|
controlnet_pipe = StableDiffusionControlNetPipeline( |
|
vae=pipe.vae, |
|
text_encoder=pipe.text_encoder, |
|
tokenizer=pipe.tokenizer, |
|
unet=pipe.unet, |
|
controlnet=controlnet, |
|
scheduler=pipe.scheduler, |
|
safety_checker=None, |
|
feature_extractor=pipe.feature_extractor |
|
) |
|
|
|
|
|
|
|
controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config) |
|
|
|
|
|
del pipe |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
print("Память GPU очищена после создания ControlNet пайплайна.") |
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
controlnet_pipe = controlnet_pipe.to("cuda") |
|
print("Финальный пайплайн перемещен на GPU.") |
|
else: |
|
print("GPU не найдено. Пайплайн будет работать на CPU (крайне медленно).") |
|
|
|
return controlnet_pipe |
|
|
|
except Exception as e: |
|
print(f"Ошибка при создании финального StableDiffusionControlNetPipeline: {e}") |
|
print("Проверьте совместимость компонентов (базовая модель и ControlNet).") |
|
return None |
|
|
|
|
|
|
|
|
|
if downloaded_base_model_path and os.path.exists(downloaded_base_model_path): |
|
pipeline = load_pipeline_components(downloaded_base_model_path, CONTROLNET_MODEL_ID) |
|
else: |
|
print("Пропуск загрузки пайплайна из-за ошибки скачивания или отсутствия файла.") |
|
pipeline = None |
|
|
|
|
|
|
|
|
|
def generate_image_gradio(controlnet_image: np.ndarray, prompt: str, negative_prompt: str = "", guidance_scale: float = 7.5, num_inference_steps: int = 30, controlnet_conditioning_scale: float = 1.0): |
|
""" |
|
Генерирует изображение с использованием Stable Diffusion ControlNet. |
|
Принимает изображение NumPy, текст промта и другие параметры. |
|
Возвращает сгенерированное изображение в формате PIL Image. |
|
""" |
|
|
|
if pipeline is None: |
|
print("Попытка генерации, но пайплайн модели не загружен.") |
|
return None, "Ошибка: Пайплайн модели не загружен. Проверьте логи Space." |
|
|
|
if controlnet_image is None: |
|
return None, "Ошибка: необходимо загрузить изображение для ControlNet." |
|
|
|
print(f"Генерация изображения с промтом: '{prompt}'") |
|
print(f"Размер входного изображения: {controlnet_image.shape}") |
|
|
|
|
|
|
|
input_image_pil = Image.fromarray(controlnet_image).convert("RGB") |
|
|
|
|
|
try: |
|
|
|
|
|
output = pipeline( |
|
prompt=prompt, |
|
image=input_image_pil, |
|
negative_prompt=negative_prompt, |
|
guidance_scale=guidance_scale, |
|
num_inference_steps=num_inference_steps, |
|
controlnet_conditioning_scale=controlnet_conditioning_scale |
|
) |
|
|
|
|
|
generated_image_pil = output.images[0] |
|
|
|
print("Генерация завершена.") |
|
return generated_image_pil, "Успех!" |
|
except Exception as e: |
|
print(f"Ошибка при генерации: {e}") |
|
|
|
return None, f"Ошибка при генерации: {e}" |
|
|
|
|
|
|
|
|
|
input_image_comp = gr.Image(type="numpy", label="Изображение для ControlNet (набросок, карта глубины и т.д.)") |
|
prompt_comp = gr.Textbox(label="Промт (Prompt)") |
|
negative_prompt_comp = gr.Textbox(label="Негативный промт (Negative Prompt)") |
|
guidance_scale_comp = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Степень соответствия промту (Guidance Scale)") |
|
num_inference_steps_comp = gr.Slider(minimum=10, maximum=150, value=30, step=1, label="Количество шагов (Inference Steps)") |
|
controlnet_conditioning_scale_comp = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Вес ControlNet (ControlNet Scale)") |
|
|
|
output_image_comp = gr.Image(type="pil", label="Сгенерированное изображение") |
|
status_text_comp = gr.Textbox(label="Статус") |
|
|
|
|
|
|
|
|
|
|
|
interface = gr.Interface( |
|
fn=generate_image_gradio, |
|
inputs=[ |
|
input_image_comp, |
|
prompt_comp, |
|
negative_prompt_comp, |
|
guidance_scale_comp, |
|
num_inference_steps_comp, |
|
controlnet_conditioning_scale_comp |
|
], |
|
outputs=[output_image_comp, status_text_comp], |
|
title="Stable Diffusion ControlNet Interface (SafeTensor Base Model on HF Space)", |
|
description="Загрузите изображение для ControlNet, введите промт и нажмите 'Generate'. Используется локальная SafeTensor модель и ControlNet с Hugging Face." |
|
) |
|
|
|
|
|
|