Anonym26 commited on
Commit
4095011
·
verified ·
1 Parent(s): 9b23bee

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -96
app.py CHANGED
@@ -1,96 +1,132 @@
1
- import gradio as gr
2
- import aiohttp
3
- import asyncio
4
- from PIL import Image
5
- from io import BytesIO
6
- from dotenv import load_dotenv
7
- import os
8
-
9
- # Загрузка токена из .env файла
10
- load_dotenv()
11
- API_TOKEN = os.getenv("HF_API_TOKEN")
12
-
13
- # Конфигурация API
14
- HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
15
- MODELS = {
16
- "Midjourney": "Jovie/Midjourney",
17
- "FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev",
18
- "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
19
- "Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
20
- "Stable Diffusion v1.0 Large": "stabilityai/stable-diffusion-xl-base-1.0",
21
- "Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration",
22
- }
23
-
24
- # Асинхронная функция для отправки запроса к API
25
- async def query_model(prompt, model_name, model_url, session):
26
- try:
27
- async with session.post(
28
- f"https://api-inference.huggingface.co/models/{model_url}",
29
- headers=HEADERS,
30
- json={"inputs": prompt},
31
- ) as response:
32
- if response.status == 200:
33
- image_data = await response.read()
34
- return model_name, Image.open(BytesIO(image_data))
35
- else:
36
- error_message = await response.json()
37
- warnings = error_message.get("warnings", [])
38
- print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}")
39
- if warnings:
40
- print(f"Предупреждения для модели {model_name}: {warnings}")
41
- return model_name, None
42
- except Exception as e:
43
- print(f"Ошибка соединения с моделью {model_name}: {e}")
44
- return model_name, None
45
-
46
- # Асинхронная обработка всех запросов
47
- async def handle(prompt):
48
- async with aiohttp.ClientSession() as session: # Создаём единый клиент для всех запросов
49
- tasks = [
50
- query_model(prompt, model_name, model_url, session)
51
- for model_name, model_url in MODELS.items()
52
- ]
53
- results = await asyncio.gather(*tasks)
54
- return {model_name: image for model_name, image in results if image}
55
-
56
- # Интерфейс Gradio
57
- with gr.Blocks() as demo:
58
- gr.Markdown("## Генерация изображений с использованием моделей Hugging Face")
59
-
60
- # Поле ввода
61
- user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'")
62
-
63
- # Вывод изображений
64
- with gr.Row():
65
- outputs = {name: gr.Image(label=name) for name in MODELS.keys()}
66
-
67
- # Асинхронная обработка ввода
68
- async def on_submit(prompt):
69
- results = await handle(prompt)
70
- return [results.get(name, None) for name in MODELS.keys()]
71
-
72
- # Кнопка генерации
73
- generate_button = gr.Button("Сгенерировать")
74
- generate_button.click(
75
- fn=on_submit,
76
- inputs=[user_input],
77
- outputs=list(outputs.values()),
78
- )
79
- user_input.submit(
80
- fn=on_submit,
81
- inputs=[user_input],
82
- outputs=list(outputs.values()),
83
- )
84
-
85
- # Ссылки на соцсети
86
- with gr.Row():
87
- with gr.Column(scale=1):
88
- gr.Image(value='icon.jpg')
89
- with gr.Column(scale=4):
90
- gr.Markdown("""
91
- ### Поддержка проекта
92
- - [Telegram](https://t.me/mlphys)
93
- - [GitHub](https://github.com/freQuensy23-coder)
94
- """)
95
-
96
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio
2
+ import gradio as gr
3
+ import aiohttp
4
+ import asyncio
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ from asyncio import Semaphore
8
+ from dotenv import load_dotenv
9
+ import os
10
+
11
+ # Загрузка токена из .env файла
12
+ load_dotenv()
13
+ API_TOKEN = os.getenv("HF_API_TOKEN")
14
+
15
+ # Конфигурация API
16
+ HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
17
+ MODELS = {
18
+ "Midjourney": "Jovie/Midjourney",
19
+ "FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev",
20
+ "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
21
+ "Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
22
+ "Stable Diffusion v1.0 Large": "stabilityai/stable-diffusion-xl-base-1.0",
23
+ "Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration",
24
+ }
25
+
26
+ # Настройки
27
+ MAX_CONCURRENT_REQUESTS = 3
28
+ GROUP_DELAY = 61
29
+
30
+ # Асинхронная функция для отправки запроса к API
31
+ async def query_model(prompt, model_name, model_url, semaphore):
32
+ async with semaphore: # Ограничиваем количество одновременно выполняемых задач
33
+ try:
34
+ async with aiohttp.ClientSession() as session:
35
+ async with session.post(
36
+ f"https://api-inference.huggingface.co/models/{model_url}",
37
+ headers=HEADERS,
38
+ json={"inputs": prompt},
39
+ ) as response:
40
+ if response.status == 200:
41
+ image_data = await response.read()
42
+ return model_name, Image.open(BytesIO(image_data))
43
+ else:
44
+ error_message = await response.json()
45
+ warnings = error_message.get("warnings", [])
46
+ print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}")
47
+ if warnings:
48
+ print(f"Предупреждения для модели {model_name}: {warnings}")
49
+ return model_name, None
50
+ except Exception as e:
51
+ print(f"Ошибка соединения с моделью {model_name}: {e}")
52
+ return model_name, None
53
+
54
+
55
+ # Асинхронная обработка запросов первой группы
56
+ async def handle_first_group(prompt):
57
+ semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор
58
+ tasks = [
59
+ query_model(prompt, model_name, model_url, semaphore)
60
+ for model_name, model_url in list(MODELS.items())[:3]
61
+ ]
62
+ results = await asyncio.gather(*tasks)
63
+ return {model_name: image for model_name, image in results if image}
64
+
65
+
66
+ # Асинхронная обработка запросов второй группы
67
+ async def handle_second_group(prompt):
68
+ await asyncio.sleep(GROUP_DELAY) # Задержка перед запросами ко второй группе
69
+ semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор
70
+ tasks = [
71
+ query_model(prompt, model_name, model_url, semaphore)
72
+ for model_name, model_url in list(MODELS.items())[3:]
73
+ ]
74
+ results = await asyncio.gather(*tasks)
75
+ return {model_name: image for model_name, image in results if image}
76
+
77
+
78
+ # Асинхронная обработка запросов
79
+ async def handle(prompt):
80
+ # Обработка двух групп моделей
81
+ first_group_results = await handle_first_group(prompt)
82
+ second_group_results = await handle_second_group(prompt)
83
+ return {**first_group_results, **second_group_results}
84
+
85
+
86
+ # Интерфейс Gradio
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown("## Генерация изображений с использованием моделей Hugging Face")
89
+
90
+ # Поле ввода
91
+ user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'")
92
+
93
+ # Вывод изображений
94
+ with gr.Row():
95
+ outputs = {name: gr.Image(label=name) for name in MODELS.keys()}
96
+
97
+ # Кнопка генерации
98
+ generate_button = gr.Button("Сгенерировать")
99
+
100
+ # Асинхронная обработка ввода
101
+ async def on_submit(prompt):
102
+ results = await handle(prompt)
103
+ return [results.get(name, None) for name in MODELS.keys()]
104
+
105
+ generate_button.click(
106
+ fn=on_submit,
107
+ inputs=[user_input],
108
+ outputs=list(outputs.values()),
109
+ )
110
+ user_input.submit(
111
+ fn=on_submit,
112
+ inputs=[user_input],
113
+ outputs=list(outputs.values()),
114
+ )
115
+
116
+ # Ссылки на соцсети
117
+ with gr.Row():
118
+ with gr.Column(scale=1):
119
+ gr.Image(value='icon.jpg')
120
+ with gr.Column(scale=4):
121
+ gradio.HTML("""<div style="text-align: center; font-family: 'Helvetica Neue', sans-serif; padding: 10px; color: #333333;">
122
+ <p style="font-size: 18px; font-weight: 600; margin-bottom: 8px;">
123
+ Эта демка была создана телеграм каналом <strong style="color: #007ACC;"><a href='https://t.me/mlphys'> mlphys</a></strong>. Другие мои социальные сети:
124
+ </p>
125
+ <p style="font-size: 16px;">
126
+ <a href="https://t.me/mlphys" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">Telegram</a> |
127
+ <a href="https://x.com/quensy23" target="_blank" style="color: #1DA1F2; text-decoration: none; font-weight: 500;">Twitter</a> |
128
+ <a href="https://github.com/freQuensy23-coder" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">GitHub</a>
129
+ </p>
130
+ </div>""")
131
+
132
+ demo.launch()