TextToImages / app.py
Anonym26's picture
Update app.py
5fb8ff9 verified
raw
history blame
5.84 kB
import gradio
import gradio as gr
import aiohttp
import asyncio
from PIL import Image
from io import BytesIO
from asyncio import Semaphore
from dotenv import load_dotenv
import os
# Загрузка токена из .env файла
load_dotenv()
API_TOKEN = os.getenv("HF_API_TOKEN")
# Конфигурация API
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
MODELS = {
"Midjourney": "Jovie/Midjourney",
"FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev",
"Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
"Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
"Stable Diffusion v1.0 Large": "stabilityai/stable-diffusion-xl-base-1.0",
"Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration",
}
# Настройки
MAX_CONCURRENT_REQUESTS = 3
GROUP_DELAY = 1
# Асинхронная функция для отправки запроса к API
async def query_model(prompt, model_name, model_url, semaphore):
async with semaphore: # Ограничиваем количество одновременно выполняемых задач
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://api-inference.huggingface.co/models/{model_url}",
headers=HEADERS,
json={"inputs": prompt},
) as response:
if response.status == 200:
image_data = await response.read()
return model_name, Image.open(BytesIO(image_data))
else:
error_message = await response.json()
warnings = error_message.get("warnings", [])
print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}")
if warnings:
print(f"Предупреждения для модели {model_name}: {warnings}")
return model_name, None
except Exception as e:
print(f"Ошибка соединения с моделью {model_name}: {e}")
return model_name, None
# Асинхронная обработка запросов первой группы
async def handle_first_group(prompt):
semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор
tasks = [
query_model(prompt, model_name, model_url, semaphore)
for model_name, model_url in list(MODELS.items())[:3]
]
results = await asyncio.gather(*tasks)
return {model_name: image for model_name, image in results if image}
# Асинхронная обработка запросов второй группы
async def handle_second_group(prompt):
await asyncio.sleep(GROUP_DELAY) # Задержка перед запросами ко второй группе
semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор
tasks = [
query_model(prompt, model_name, model_url, semaphore)
for model_name, model_url in list(MODELS.items())[3:]
]
results = await asyncio.gather(*tasks)
return {model_name: image for model_name, image in results if image}
# Асинхронная обработка запросов
async def handle(prompt):
# Обработка двух групп моделей
first_group_results = await handle_first_group(prompt)
second_group_results = await handle_second_group(prompt)
return {**first_group_results, **second_group_results}
# Интерфейс Gradio
with gr.Blocks() as demo:
gr.Markdown("## Генерация изображений с использованием моделей Hugging Face")
# Поле ввода
user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'")
# Вывод изображений
with gr.Row():
outputs = {name: gr.Image(label=name) for name in MODELS.keys()}
# Кнопка генерации
generate_button = gr.Button("Сгенерировать")
# Асинхронная обработка ввода
async def on_submit(prompt):
results = await handle(prompt)
return [results.get(name, None) for name in MODELS.keys()]
generate_button.click(
fn=on_submit,
inputs=[user_input],
outputs=list(outputs.values()),
)
user_input.submit(
fn=on_submit,
inputs=[user_input],
outputs=list(outputs.values()),
)
# Ссылки на соцсети
with gr.Row():
with gr.Column(scale=1):
gr.Image(value='icon.jpg')
with gr.Column(scale=4):
gradio.HTML("""<div style="text-align: center; font-family: 'Helvetica Neue', sans-serif; padding: 10px; color: #333333;">
<p style="font-size: 18px; font-weight: 600; margin-bottom: 8px;">
Эта демка была создана телеграм каналом <strong style="color: #007ACC;"><a href='https://t.me/mlphys'> mlphys</a></strong>. Другие мои социальные сети:
</p>
<p style="font-size: 16px;">
<a href="https://t.me/mlphys" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">Telegram</a> |
<a href="https://x.com/quensy23" target="_blank" style="color: #1DA1F2; text-decoration: none; font-weight: 500;">Twitter</a> |
<a href="https://github.com/freQuensy23-coder" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">GitHub</a>
</p>
</div>""")
demo.launch()