File size: 12,394 Bytes
0adeb3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# app.py
import gradio as gr
import torch
import requests
from PIL import Image
from io import BytesIO
import numpy as np
import os
from tqdm import tqdm # Добавляем импорт tqdm

# Импорты из diffusers
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionPipeline
# from diffusers.utils import load_image # Не нужен для этого кода
# from huggingface_hub import hf_hub_download # Не нужен для этого кода

# --- Вспомогательная функция для скачивания файлов (например, с Civitai) ---
# Эта функция будет скачивать модель SafeTensor внутри Space при первом запуске
def download_file(url, local_filename):
    """Скачивает файл по URL с индикатором прогресса."""
    print(f"Скачиваю {url} в {local_filename}...")
    # Проверяем, существует ли файл, чтобы не скачивать его каждый раз
    if os.path.exists(local_filename):
        print(f"Файл уже существует: {local_filename}. Пропускаю скачивание.")
        return local_filename

    try:
        response = requests.get(url, stream=True)
        response.raise_for_status() # Проверка на ошибки HTTP

        total_size_in_bytes = int(response.headers.get('content-length', 0))
        block_size = 8192 # 8 Kibibytes

        with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Скачивание {local_filename}") as progress_bar:
            with open(local_filename, 'wb') as f:
                for chunk in response.iter_content(chunk_size=block_size):
                    progress_bar.update(len(chunk))
                    f.write(chunk)

        print(f"Скачивание завершено: {local_filename}")
        return local_filename
    except requests.exceptions.RequestException as e:
        print(f"Ошибка скачивания {url}: {e}")
        return None
    except Exception as e:
        print(f"Произошла другая ошибка при скачивании: {e}")
        return None


# --- Определение путей/ID моделей ---
# URL вашей SafeTensor модели с Civitai
CIVITAI_SAFETENSOR_URL = "https://civitai.com/api/download/models/1413133?type=Model&format=SafeTensor&size=full&fp=fp8"
# Локальное имя файла для сохранения SafeTensor модели внутри Space
LOCAL_SAFETENSOR_FILENAME = "ultrareal_fine_tune_fp8_full.safetensors"

# ControlNet модель с Hugging Face
CONTROLNET_MODEL_ID = "ABDALLALSWAITI/FLUX.1-dev-ControlNet-Union-Pro-2.0-fp8"

# --- Скачиваем SafeTensor модель (выполнится при запуске скрипта в Space) ---
print("Начинаю скачивание базовой модели...")
downloaded_base_model_path = download_file(CIVITAI_SAFETENSOR_URL, LOCAL_SAFETENSOR_FILENAME)

if not downloaded_base_model_path or not os.path.exists(downloaded_base_model_path):
    # Если скачивание не удалось или файл не существует после попытки
    print(f"Критическая ошибка: Не удалось получить файл базовой модели по пути: {LOCAL_SAFETENSOR_FILENAME}")
    print("Проверьте логи Space на наличие ошибок скачивания.")
    # Возможно, здесь стоит выбросить исключение или как-то иначе остановить приложение
    # Для примера, просто присвоим None и приложение не сможет загрузить пайплайн
    pipeline = None
else:
    # --- Загрузка моделей и создание пайплайна ---
    def load_pipeline_components(base_model_path, controlnet_model_id):
        """Загружает базовую модель из локального файла, ControlNet и собирает пайплайн."""
        print(f"Загрузка ControlNet модели: {controlnet_model_id}")
        # Загрузка ControlNet с Hugging Face Hub - кешируется автоматически Space
        controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)

        print(f"Загрузка базовой модели из локального файла: {base_model_path}")
        # Загружаем базовую модель из локального SafeTensor файла
        # diffusers умеет загружать локальные файлы .safetensors
        pipe = StableDiffusionPipeline.from_pretrained(
            base_model_path, # Указываем путь к локальному файлу внутри Space
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            safety_checker=None # Отключение safety checker для скорости (используйте с осторожностью!)
        )

        # Теперь объединяем базовый пайплайн с ControlNet
        # Создаем StableDiffusionControlNetPipeline на основе загруженного базового пайплайна
        print("Создание пайплайна StableDiffusionControlNetPipeline...")
        controlnet_pipe = StableDiffusionControlNetPipeline(
            vae=pipe.vae,
            text_encoder=pipe.text_encoder,
            tokenizer=pipe.tokenizer,
            unet=pipe.unet,
            controlnet=controlnet, # Передаем загруженный ControlNet
            scheduler=pipe.scheduler, # Используем планировщик из базового пайплайна
            safety_checker=None,
            feature_extractor=pipe.feature_extractor
        )

        # Рекомендуется использовать планировщик UniPC для ControlNet (или обновить существующий)
        # Обновляем планировщик в новом ControlNet пайплайне
        controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)

        # Удаляем старый базовый пайплайн для освобождения памяти
        del pipe
        if torch.cuda.is_available():
             torch.cuda.empty_cache()

        # Перемещаем ControlNet пайплайн на GPU, если доступно
        if torch.cuda.is_available():
            controlnet_pipe = controlnet_pipe.to("cuda")
            print("Пайплайн перемещен на GPU.")
        else:
            print("GPU не найдено. Пайплайн будет работать на CPU (будет медленно).") # В Space на CPU работать не будет эффективно

        return controlnet_pipe

    # Загружаем пайплайн при запуске скрипта, только если файл модели успешно скачан
    pipeline = load_pipeline_components(downloaded_base_model_path, CONTROLNET_MODEL_ID)


# --- Функция рендеринга для Gradio ---
# Эта функция будет вызываться интерфейсом Gradio в Space
def generate_image_gradio(controlnet_image: np.ndarray, prompt: str, negative_prompt: str = "", guidance_scale: float = 7.5, num_inference_steps: int = 30, controlnet_conditioning_scale: float = 1.0):
    """
    Генерирует изображение с использованием Stable Diffusion ControlNet.
    Принимает изображение NumPy, текст промта и другие параметры.
    Возвращает сгенерированное изображение в формате PIL Image.
    """
    # Проверяем, успешно ли загрузился пайплайн
    if pipeline is None:
         return None, "Ошибка: Пайплайн модели не загружен. Проверьте логи Space."

    if controlnet_image is None:
        return None, "Ошибка: необходимо загрузить изображение для ControlNet."

    print(f"Генерация изображения с промтом: '{prompt}'")
    print(f"Размер входного изображения: {controlnet_image.shape}")

    # Gradio возвращает изображение как numpy array. Преобразуем в PIL Image для пайплайна.
    input_image_pil = Image.fromarray(controlnet_image).convert("RGB")

    # Выполняем рендеринг с помощью пайплайна
    try:
        output = pipeline(
            prompt=prompt,
            image=input_image_pil, # Входное изображение для ControlNet
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale
            # Здесь можно добавить generator=... (для сидов), width=..., height=..., etc.
        )

        # Результат находится в output.images[0]
        generated_image_pil = output.images[0]

        print("Генерация завершена.")
        return generated_image_pil, "Успех!"
    except Exception as e:
        print(f"Ошибка при генерации: {e}")
        return None, f"Ошибка при генерации: {e}"


# --- Настройка интерфейса Gradio ---
# Определяем входные и выходные элементы
input_image_comp = gr.Image(type="numpy", label="Изображение для ControlNet (набросок, карта глубины и т.д.)")
prompt_comp = gr.Textbox(label="Промт (Prompt)")
negative_prompt_comp = gr.Textbox(label="Негативный промт (Negative Prompt)")
guidance_scale_comp = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Степень соответствия промту (Guidance Scale)")
num_inference_steps_comp = gr.Slider(minimum=10, maximum=150, value=30, step=1, label="Количество шагов (Inference Steps)")
controlnet_conditioning_scale_comp = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Вес ControlNet (ControlNet Scale)")

output_image_comp = gr.Image(type="pil", label="Сгенерированное изображение")
status_text_comp = gr.Textbox(label="Статус")


# Создаем интерфейс Gradio
# Поскольку мы в Space, Gradio SDK сам вызовет interface.launch()
# Нам просто нужно определить интерфейс
interface = gr.Interface(
    fn=generate_image_gradio,
    inputs=[
        input_image_comp,
        prompt_comp,
        negative_prompt_comp,
        guidance_scale_comp,
        num_inference_steps_comp,
        controlnet_conditioning_scale_comp
    ],
    outputs=[output_image_comp, status_text_comp],
    title="Stable Diffusion ControlNet Interface (SafeTensor Base Model)",
    description="Загрузите изображение для ControlNet, введите промт и нажмите 'Generate'. Используется локальная SafeTensor модель и ControlNet с Hugging Face."
)

# Важно: Не вызывайте interface.launch() в блоке if __name__ == "__main__":
# Gradio SDK в Space сделает это автоматически.
# Если вы оставите if __name__ == "__main__": interface.launch(), оно тоже будет работать,
# но в среде Space это менее критично, чем при локальном запуске.
# Для ясности в Space можно убрать блок if __name__ == "__main__":
# Я оставил его в коде выше, но знайте, что SDK вызовет interface.launch() независимо от него.