File size: 5,838 Bytes
5fb8ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio
import gradio as gr
import aiohttp
import asyncio
from PIL import Image
from io import BytesIO
from asyncio import Semaphore
from dotenv import load_dotenv
import os

# Загрузка токена из .env файла
load_dotenv()
API_TOKEN = os.getenv("HF_API_TOKEN")

# Конфигурация API
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
MODELS = {
    "Midjourney": "Jovie/Midjourney",
    "FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev",
    "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
    "Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
    "Stable Diffusion v1.0 Large": "stabilityai/stable-diffusion-xl-base-1.0",
    "Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration",
}

# Настройки
MAX_CONCURRENT_REQUESTS = 3
GROUP_DELAY = 1

# Асинхронная функция для отправки запроса к API
async def query_model(prompt, model_name, model_url, semaphore):
    async with semaphore:  # Ограничиваем количество одновременно выполняемых задач
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"https://api-inference.huggingface.co/models/{model_url}",
                    headers=HEADERS,
                    json={"inputs": prompt},
                ) as response:
                    if response.status == 200:
                        image_data = await response.read()
                        return model_name, Image.open(BytesIO(image_data))
                    else:
                        error_message = await response.json()
                        warnings = error_message.get("warnings", [])
                        print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}")
                        if warnings:
                            print(f"Предупреждения для модели {model_name}: {warnings}")
                        return model_name, None
        except Exception as e:
            print(f"Ошибка соединения с моделью {model_name}: {e}")
            return model_name, None


# Асинхронная обработка запросов первой группы
async def handle_first_group(prompt):
    semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)  # Создаём локальный семафор
    tasks = [
        query_model(prompt, model_name, model_url, semaphore)
        for model_name, model_url in list(MODELS.items())[:3]
    ]
    results = await asyncio.gather(*tasks)
    return {model_name: image for model_name, image in results if image}


# Асинхронная обработка запросов второй группы
async def handle_second_group(prompt):
    await asyncio.sleep(GROUP_DELAY)  # Задержка перед запросами ко второй группе
    semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)  # Создаём локальный семафор
    tasks = [
        query_model(prompt, model_name, model_url, semaphore)
        for model_name, model_url in list(MODELS.items())[3:]
    ]
    results = await asyncio.gather(*tasks)
    return {model_name: image for model_name, image in results if image}


# Асинхронная обработка запросов
async def handle(prompt):
    # Обработка двух групп моделей
    first_group_results = await handle_first_group(prompt)
    second_group_results = await handle_second_group(prompt)
    return {**first_group_results, **second_group_results}


# Интерфейс Gradio
with gr.Blocks() as demo:
    gr.Markdown("## Генерация изображений с использованием моделей Hugging Face")

    # Поле ввода
    user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'")

    # Вывод изображений
    with gr.Row():
        outputs = {name: gr.Image(label=name) for name in MODELS.keys()}

    # Кнопка генерации
    generate_button = gr.Button("Сгенерировать")

    # Асинхронная обработка ввода
    async def on_submit(prompt):
        results = await handle(prompt)
        return [results.get(name, None) for name in MODELS.keys()]

    generate_button.click(
        fn=on_submit,
        inputs=[user_input],
        outputs=list(outputs.values()),
    )
    user_input.submit(
        fn=on_submit,
        inputs=[user_input],
        outputs=list(outputs.values()),
    )

    # Ссылки на соцсети
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(value='icon.jpg')
        with gr.Column(scale=4):
            gradio.HTML("""<div style="text-align: center; font-family: 'Helvetica Neue', sans-serif; padding: 10px; color: #333333;">
        <p style="font-size: 18px; font-weight: 600; margin-bottom: 8px;">
            Эта демка была создана телеграм каналом <strong style="color: #007ACC;"><a href='https://t.me/mlphys'> mlphys</a></strong>. Другие мои социальные сети:
        </p>
        <p style="font-size: 16px;">
            <a href="https://t.me/mlphys" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">Telegram</a> |
            <a href="https://x.com/quensy23" target="_blank" style="color: #1DA1F2; text-decoration: none; font-weight: 500;">Twitter</a> |
            <a href="https://github.com/freQuensy23-coder"  target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">GitHub</a>
        </p>
    </div>""")

demo.launch()