Spaces:
Running
Running
import gradio | |
import gradio as gr | |
import aiohttp | |
import asyncio | |
from PIL import Image | |
from io import BytesIO | |
from asyncio import Semaphore | |
from dotenv import load_dotenv | |
import os | |
# Загрузка токена из .env файла | |
load_dotenv() | |
API_TOKEN = os.getenv("HF_API_TOKEN") | |
# Конфигурация API | |
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"} | |
MODELS = { | |
"Midjourney": "Jovie/Midjourney", | |
"FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev", | |
"Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1", | |
"Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large", | |
"Stable Diffusion v1.0 Large": "stabilityai/stable-diffusion-xl-base-1.0", | |
"Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration", | |
} | |
# Настройки | |
MAX_CONCURRENT_REQUESTS = 3 | |
GROUP_DELAY = 1 | |
# Асинхронная функция для отправки запроса к API | |
async def query_model(prompt, model_name, model_url, semaphore): | |
async with semaphore: # Ограничиваем количество одновременно выполняемых задач | |
try: | |
async with aiohttp.ClientSession() as session: | |
async with session.post( | |
f"https://api-inference.huggingface.co/models/{model_url}", | |
headers=HEADERS, | |
json={"inputs": prompt}, | |
) as response: | |
if response.status == 200: | |
image_data = await response.read() | |
return model_name, Image.open(BytesIO(image_data)) | |
else: | |
error_message = await response.json() | |
warnings = error_message.get("warnings", []) | |
print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}") | |
if warnings: | |
print(f"Предупреждения для модели {model_name}: {warnings}") | |
return model_name, None | |
except Exception as e: | |
print(f"Ошибка соединения с моделью {model_name}: {e}") | |
return model_name, None | |
# Асинхронная обработка запросов первой группы | |
async def handle_first_group(prompt): | |
semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор | |
tasks = [ | |
query_model(prompt, model_name, model_url, semaphore) | |
for model_name, model_url in list(MODELS.items())[:3] | |
] | |
results = await asyncio.gather(*tasks) | |
return {model_name: image for model_name, image in results if image} | |
# Асинхронная обработка запросов второй группы | |
async def handle_second_group(prompt): | |
await asyncio.sleep(GROUP_DELAY) # Задержка перед запросами ко второй группе | |
semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор | |
tasks = [ | |
query_model(prompt, model_name, model_url, semaphore) | |
for model_name, model_url in list(MODELS.items())[3:] | |
] | |
results = await asyncio.gather(*tasks) | |
return {model_name: image for model_name, image in results if image} | |
# Асинхронная обработка запросов | |
async def handle(prompt): | |
# Обработка двух групп моделей | |
first_group_results = await handle_first_group(prompt) | |
second_group_results = await handle_second_group(prompt) | |
return {**first_group_results, **second_group_results} | |
# Интерфейс Gradio | |
with gr.Blocks() as demo: | |
gr.Markdown("## Генерация изображений с использованием моделей Hugging Face") | |
# Поле ввода | |
user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'") | |
# Вывод изображений | |
with gr.Row(): | |
outputs = {name: gr.Image(label=name) for name in MODELS.keys()} | |
# Кнопка генерации | |
generate_button = gr.Button("Сгенерировать") | |
# Асинхронная обработка ввода | |
async def on_submit(prompt): | |
results = await handle(prompt) | |
return [results.get(name, None) for name in MODELS.keys()] | |
generate_button.click( | |
fn=on_submit, | |
inputs=[user_input], | |
outputs=list(outputs.values()), | |
) | |
user_input.submit( | |
fn=on_submit, | |
inputs=[user_input], | |
outputs=list(outputs.values()), | |
) | |
# Ссылки на соцсети | |
with gr.Row(): | |
with gr.Column(scale=1): | |
gr.Image(value='icon.jpg') | |
with gr.Column(scale=4): | |
gradio.HTML("""<div style="text-align: center; font-family: 'Helvetica Neue', sans-serif; padding: 10px; color: #333333;"> | |
<p style="font-size: 18px; font-weight: 600; margin-bottom: 8px;"> | |
Эта демка была создана телеграм каналом <strong style="color: #007ACC;"><a href='https://t.me/mlphys'> mlphys</a></strong>. Другие мои социальные сети: | |
</p> | |
<p style="font-size: 16px;"> | |
<a href="https://t.me/mlphys" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">Telegram</a> | | |
<a href="https://x.com/quensy23" target="_blank" style="color: #1DA1F2; text-decoration: none; font-weight: 500;">Twitter</a> | | |
<a href="https://github.com/freQuensy23-coder" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">GitHub</a> | |
</p> | |
</div>""") | |
demo.launch() | |