Rooni's picture
Update app.py
d0dfd30 verified
raw
history blame
5.45 kB
import os
import random
import sys
import requests
from typing import Sequence, Mapping, Any, Union
import gradio as gr
from deep_translator import GoogleTranslator
from langdetect import detect
from gradio_client import Client, handle_file
# Функция для получения случайного API ключа
def get_random_api_key():
keys = os.getenv("KEYS", "").split(",")
if keys and keys[0]: # Проверяем, установлены ли ключи и не пусты ли они
return random.choice(keys).strip()
else:
raise ValueError("API ключи не найдены. Пожалуйста, установите переменную окружения KEYS.")
# Ссылка на файл CSS
css_url = "https://neurixyufi-aihub.static.hf.space/style.css"
# Получение CSS по ссылке
try:
response = requests.get(css_url)
response.raise_for_status()
css = response.text + " h1{text-align:center}"
except requests.exceptions.RequestException as e:
print(f"Ошибка при загрузке CSS: {e}")
css = " h1{text-align:center}"
# Функция для перевода текста на английский
def translate_to_english(prompt):
language = detect(prompt)
if language != 'en':
prompt = GoogleTranslator(source=language, target='en').translate(prompt)
return prompt
# Функция для загрузки изображений в кеш и отправки ссылки на API
def upload_image_to_hf_cache(image):
if isinstance(image, dict) and 'url' in image:
return image['url']
elif isinstance(image, str):
return image
else:
raise ValueError("Неподдерживаемый формат изображения")
# Функция для генерации изображения через API
def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
"""Основная функция генерации изображения."""
prompt = translate_to_english(prompt)
structure_image_url = upload_image_to_hf_cache(structure_image)
style_image_url = upload_image_to_hf_cache(style_image)
client = Client("multimodalart/flux-style-shaping", hf_token=get_random_api_key())
result = client.predict(
prompt=prompt,
structure_image=handle_file(structure_image_url),
style_image=handle_file(style_image_url),
depth_strength=depth_strength,
style_strength=style_strength,
api_name="/generate_image",
timeout=3000
)
if isinstance(result, str) and os.path.exists(result):
output_image = Image.open(result)
elif isinstance(result, bytes):
output_image = Image.open(BytesIO(result))
else:
raise ValueError(f"Неожиданный тип результата API: {type(result)}")
return output_image
# Примеры для Gradio
examples = [
["", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/mona.png", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/receita-tacos.webp", 15, 0.6],
["Девочка смотрит на дом, который горит", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/disaster_girl.png", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/abaporu.jpg", 15, 0.15],
["Город Истанбул с высоты птичьего полёта", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/natasha.png", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/istambul.jpg", 15, 0.5],
]
output_image = gr.Image(label="Сгенерированное изображение", show_share_button=False)
with gr.Blocks(css=css) as app:
gr.Markdown("# Структуратор")
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(label="Запрос", placeholder="Введите ваш запрос здесь...")
with gr.Row():
with gr.Group():
structure_image = gr.Image(label="Изображение структуры", type="filepath")
depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Сила глубины")
with gr.Group():
style_image = gr.Image(label="Изображение стиля", type="filepath")
style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Сила стиля")
generate_btn = gr.Button("Создать", variant='primary')
gr.Examples(
examples=examples,
inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
outputs=[output_image],
fn=generate_image,
label="Примеры",
cache_examples=False,
)
with gr.Column():
output_image.render()
generate_btn.click(
fn=generate_image,
inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
outputs=[output_image],
concurrency_limit=250
)
if __name__ == "__main__":
app.launch(show_api=False, share=False)