import gradio as gr import aiohttp import asyncio from PIL import Image from io import BytesIO from asyncio import Semaphore from dotenv import load_dotenv import os # Загрузка токена из .env файла load_dotenv() API_TOKEN = os.getenv("HF_API_TOKEN") # Конфигурация API HEADERS = {"Authorization": f"Bearer {API_TOKEN}"} MODELS = { "Stable Diffusion v1.5": "Yntec/stable-diffusion-v1-5", "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1", "Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large", } # Настройки MAX_CONCURRENT_REQUESTS = 3 # Асинхронная функция для отправки запроса к API async def query_model(prompt, model_name, model_url, semaphore): async with semaphore: # Ограничиваем количество одновременно выполняемых задач try: async with aiohttp.ClientSession() as session: async with session.post( f"https://api-inference.huggingface.co/models/{model_url}", headers=HEADERS, json={"inputs": prompt}, ) as response: if response.status == 200: image_data = await response.read() return model_name, Image.open(BytesIO(image_data)) else: error_message = await response.json() warnings = error_message.get("warnings", []) print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}") if warnings: print(f"Предупреждения для модели {model_name}: {warnings}") return model_name, None except Exception as e: print(f"Ошибка соединения с моделью {model_name}: {e}") return model_name, None # Асинхронная обработка запросов async def handle(prompt): semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор tasks = [ query_model(prompt, model_name, model_url, semaphore) for model_name, model_url in MODELS.items() ] results = await asyncio.gather(*tasks) return {model_name: image for model_name, image in results if image} # Интерфейс Gradio with gr.Blocks() as demo: gr.Markdown("## Генерация изображений с использованием моделей Hugging Face") # Поле ввода user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'") # Вывод изображений with gr.Row(): outputs = {name: gr.Image(label=name) for name in MODELS.keys()} # Кнопка генерации generate_button = gr.Button("Сгенерировать") # Асинхронная обработка ввода async def on_submit(prompt): results = await handle(prompt) return [results.get(name, None) for name in MODELS.keys()] generate_button.click( fn=on_submit, inputs=[user_input], outputs=list(outputs.values()), ) user_input.submit( fn=on_submit, inputs=[user_input], outputs=list(outputs.values()), ) # Ссылки на соцсети with gr.Row(): gr.Markdown( """ ### Поддержка проекта - [Telegram](https://t.me/mlphys) - [GitHub](https://github.com/freQuensy23-coder) """ ) demo.launch()