# webtoon-gen / app.py
# Page header captured along with the source (not code): author openfree,
# commit "Update app.py" 43a7c4d (verified), links: raw | history | blame, size 15 kB.
import os
import gc
import uuid
import random
import tempfile
import time
from datetime import datetime
from typing import Any
from huggingface_hub import login, hf_hub_download
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from diffusers import FluxPipeline
from transformers import pipeline
# Memory cleanup helper
def clear_memory():
    """Free unreferenced Python objects, then cached CUDA memory (best effort).

    Runs ``gc.collect()`` first and afterwards, if CUDA is available, empties the
    CUDA caching allocator for device 0. A CUDA failure is printed and ignored,
    so this is safe to call before and after every generation.
    """
    gc.collect()
    try:
        if torch.cuda.is_available():
            with torch.cuda.device(0):
                torch.cuda.empty_cache()
    except Exception as e:  # best effort: cleanup must never break a request
        print(f"Warning: CUDA cache cleanup failed: {e}")
# GPU setup: use the first CUDA device when present, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    try:
        with torch.cuda.device(0):
            torch.cuda.empty_cache()
        # Let cuDNN autotune its kernels and allow TF32 for CUDA matmuls.
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
    except Exception as e:  # configuration is optional; report why it failed
        print("Warning: Could not configure CUDA settings")
        print(f"  details: {e}")
# HF ํ† ํฐ ์„ค์ •
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
raise ValueError("Please set the HF_TOKEN environment variable")
try:
login(token=HF_TOKEN)
except Exception as e:
raise ValueError(f"Failed to login to Hugging Face: {str(e)}")
# Korean -> English translation pipeline; translate_to_english() uses it to turn
# Korean prompts into English. device=-1 keeps the model on the CPU.
translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    device=-1,
)
def translate_to_english(text: str) -> str:
    """Translate *text* to English when it contains Korean (Hangul) syllables.

    Text with no precomposed Hangul syllable (U+AC00..U+D7A3) is returned
    unchanged. If translation fails, the error is printed and the original
    text is returned so generation can still proceed.
    """
    try:
        # Unicode escapes are used for the Hangul syllable range so the bounds
        # survive any re-encoding of this file. The previous literals had been
        # garbled into multi-character strings, which made ord() raise and
        # silently disabled translation.
        if any("\uac00" <= char <= "\ud7a3" for char in text):
            translated = translator(text, max_length=128)[0]['translation_text']
            print(f"Translated '{text}' to '{translated}'")
            return translated
        return text
    except Exception as e:
        print(f"Translation error: {str(e)}")
        return text
# FLUX ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” ๋ถ€๋ถ„ ์ˆ˜์ •
print("Initializing FLUX pipeline...")
try:
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float16,
use_auth_token=HF_TOKEN
)
print("FLUX pipeline initialized successfully")
# ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™” ์„ค์ •
pipe.enable_attention_slicing(slice_size=1)
# GPU ์„ค์ •
if torch.cuda.is_available():
pipe = pipe.to("cuda:0")
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
print("Pipeline optimization settings applied")
except Exception as e:
print(f"Error initializing FLUX pipeline: {str(e)}")
raise
# LoRA weight loading
print("Loading LoRA weights...")
try:
    # Resolve the LoRA file next to this script, independent of the working directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    lora_path = os.path.join(current_dir, "myt-flux-fantasy.safetensors")
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA file not found at: {lora_path}")
    print(f"Loading LoRA weights from: {lora_path}")

    # Load the adapter and fuse it into the pipeline weights with scale 0.75.
    pipe.load_lora_weights(lora_path)
    pipe.fuse_lora(lora_scale=0.75)

    # Memory cleanup: clear_memory() runs gc.collect() *before* emptying the CUDA
    # cache (the reverse order could not return memory freed by gc) and skips
    # CUDA when it is unavailable.
    clear_memory()
    print("LoRA weights loaded and fused successfully")
    print(f"Current device: {pipe.device}")
except Exception as e:
    print(f"Error loading LoRA weights: {str(e)}")
    print(f"Full error details: {repr(e)}")
    raise ValueError(f"Failed to load LoRA weights: {str(e)}") from e
@spaces.GPU(duration=60)
def generate_image(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress()
):
    """Generate one image with the LoRA-fused FLUX pipeline and save it to disk.

    Args:
        prompt: Text prompt; Korean text is translated to English first.
        seed: Seed for the torch generator (ignored when *randomize_seed* is true).
        randomize_seed: Draw a fresh seed in [0, MAX_SEED] instead of using *seed*.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed straight to the pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (not used in the body).

    Returns:
        ``(image, seed)``: the generated image and the integer seed actually used.

    Raises:
        gr.Error: on any failure, so Gradio shows the message in the UI.
    """
    try:
        clear_memory()
        translated_prompt = translate_to_english(prompt)
        print(f"Processing prompt: {translated_prompt}")

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Gradio Number/Slider components may deliver floats; manual_seed() and
        # the pipeline's size/step arguments need ints.
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)

        print(f"Current device: {pipe.device}")
        print("Starting image generation...")
        # torch.autocast("cuda") is the non-deprecated form of torch.cuda.amp.autocast().
        with torch.inference_mode(), torch.autocast(device_type="cuda", enabled=True):
            image = pipe(
                prompt=translated_prompt,
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                generator=generator,
                num_images_per_prompt=1,
            ).images[0]

        filepath = save_generated_image(image, translated_prompt)
        print(f"Image generated and saved to: {filepath}")
        return image, seed
    except Exception as e:
        print(f"Generation error: {str(e)}")
        print(f"Full error details: {repr(e)}")
        raise gr.Error(f"Image generation failed: {str(e)}") from e
    finally:
        clear_memory()
# Directory (relative to the working directory) where generated images are saved.
SAVE_DIR = "saved_images"
# exist_ok=True already makes this a no-op when the directory exists; a separate
# os.path.exists() pre-check is redundant and open to a check-then-create race.
os.makedirs(SAVE_DIR, exist_ok=True)

# Largest seed offered for randomization: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): MAX_IMAGE_SIZE is not referenced in the visible code; the UI
# sliders hard-code 1024 as their maximum.
MAX_IMAGE_SIZE = 1024
def save_generated_image(image, prompt):
    """Write *image* as a PNG into SAVE_DIR and return the file path.

    The file name is ``<YYYYmmdd_HHMMSS>_<first 8 hex chars of a uuid4>.png``.
    *prompt* is accepted by the signature but is not used.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    short_id = uuid.uuid4().hex[:8]
    target = os.path.join(SAVE_DIR, f"{stamp}_{short_id}.png")
    image.save(target)
    return target
def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
    """Draw *text* once for every offset (dx, dy) in [-stroke_width, stroke_width]^2.

    Every copy uses the same *text_color*, so the effect is a thicker glyph rather
    than a contrasting outline. stroke_width=0 draws the text exactly once.
    Copies are drawn with dx as the outer loop and dy as the inner loop.
    """
    offsets = range(-stroke_width, stroke_width + 1)
    shifted_positions = [(x + dx, y + dy) for dx in offsets for dy in offsets]
    for position in shifted_positions:
        draw.text(position, text, font=font, fill=text_color)
def add_text_to_image(
    input_image,
    text,
    font_size,
    color,
    opacity,
    x_position,
    y_position,
    thickness,
    text_position_type,
    font_choice
):
    """Overlay *text* on *input_image* and return the result as an RGB image.

    Args:
        input_image: PIL image or numpy array; ``None`` is returned unchanged.
        text: Text to draw; ``None`` or blank text returns *input_image* unchanged.
        font_size: Font size passed to ``ImageFont.truetype``.
        color: A name from ``color_map`` below; unknown names fall back to white.
        opacity: Alpha of the text, 0 (invisible) .. 255 (opaque).
        x_position: 0..100, percent of the free horizontal space (0 = left edge,
            100 = right edge of the visible text touches the image edge).
        y_position: 0..100, same for the vertical axis (0 = top, 100 = bottom).
        thickness: Stroke radius in pixels handed to ``add_text_with_stroke``.
        text_position_type: Accepted for UI compatibility; not used.
        font_choice: Key of ``font_files`` below; unknown keys use DejaVuSans.

    Returns:
        A new RGB image with the text composited on top, or *input_image*
        unchanged when there is nothing to draw or an error occurs (the error
        is printed).
    """
    try:
        if input_image is None or not (text or "").strip():
            return input_image

        if isinstance(input_image, Image.Image):
            image = input_image.copy()
        elif isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image)
        else:
            raise ValueError("Unsupported image type")
        if image.mode != 'RGBA':
            image = image.convert('RGBA')

        font_files = {
            "Default": "DejaVuSans.ttf",
            "Korean Regular": "ko-Regular.ttf"
        }
        try:
            font_file = font_files.get(font_choice, "DejaVuSans.ttf")
            font = ImageFont.truetype(font_file, int(font_size))
        except Exception as e:
            print(f"Font loading error ({font_choice}): {str(e)}")
            font = ImageFont.load_default()

        color_map = {
            'White': (255, 255, 255),
            'Black': (0, 0, 0),
            'Red': (255, 0, 0),
            'Green': (0, 255, 0),
            'Blue': (0, 0, 255),
            'Yellow': (255, 255, 0),
            'Purple': (128, 0, 128)
        }
        rgb_color = color_map.get(color, (255, 255, 255))
        text_color = (*rgb_color, int(opacity))
        stroke = int(thickness)

        # textbbox() at origin (0, 0) reports where the glyphs actually land.
        # left/top can be > 0 (side bearing, space above the glyphs), so they
        # are subtracted from the draw origin; otherwise the text is pushed
        # right/down and gets clipped at the 100% positions.
        measure = ImageDraw.Draw(image)
        left, top, right, bottom = measure.textbbox((0, 0), text, font=font)
        # add_text_with_stroke smears the glyphs by `stroke` pixels in every
        # direction, so reserve that margin on each side as well.
        margin = max(stroke, 0)
        box_width = (right - left) + 2 * margin
        box_height = (bottom - top) + 2 * margin
        actual_x = int((image.width - box_width) * (x_position / 100)) - left + margin
        actual_y = int((image.height - box_height) * (y_position / 100)) - top + margin

        # Draw on a transparent overlay, then alpha-composite, so *opacity* blends
        # with the picture instead of overwriting its pixels.
        txt_overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
        draw = ImageDraw.Draw(txt_overlay)
        add_text_with_stroke(draw, text, actual_x, actual_y, font, text_color, stroke)

        output_image = Image.alpha_composite(image, txt_overlay)
        return output_image.convert('RGB')
    except Exception as e:
        print(f"Error in add_text_to_image: {str(e)}")
        return input_image
# Custom CSS passed to gr.Blocks(css=...): hides the Gradio footer and styles the
# .main-title banner (used by the gr.HTML header), a .container wrapper, and
# .input-panel/.output-panel cards.
# NOTE(review): .container, .input-panel and .output-panel are not attached to any
# component via elem_classes in the visible UI code — confirm they are still needed.
css = """
footer {display: none}
.main-title {
text-align: center;
margin: 1em 0;
padding: 1.5em;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.main-title h1 {
color: #2196F3;
font-size: 2.8em;
margin-bottom: 0.3em;
font-weight: 700;
}
.main-title p {
color: #555;
font-size: 1.3em;
line-height: 1.4;
}
.container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
.input-panel, .output-panel {
background: white;
padding: 1.5em;
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
margin-bottom: 1em;
}
"""
import requests
def enhance_prompt(prompt: str) -> str:
    """Wrap *prompt* in an anime-style template and append fixed quality tags.

    Result shape:
    ``"an animated <prompt>, detailed anime art style, <tag>, <tag>, ..."``.
    The final prompt is printed; on an unexpected error the input is returned as-is.
    """
    quality_tags = (
        "masterpiece, best quality, highly detailed",
        "anime style, animation style",
        "vibrant colors, perfect lighting",
        "professional composition",
        "dynamic pose, expressive features",
        "detailed background, perfect shadows",
        # NOTE(review): "[trigger]" is appended literally; it looks like a
        # placeholder for the LoRA trigger word — confirm the intended token.
        "[trigger]",
    )
    try:
        styled_base = f"an animated {prompt}, detailed anime art style"
        final_prompt = ", ".join((styled_base, *quality_tags))
        print(f"Enhanced prompt: {final_prompt}")
        return final_prompt
    except Exception as e:
        print(f"Prompt enhancement failed: {str(e)}")
        return prompt
# ๊ธฐ์กด์˜ pipeline ์ดˆ๊ธฐํ™” ๋ถ€๋ถ„ ์ œ๊ฑฐ
# try:
# prompt_enhancer = pipeline(...)
# except Exception as e:
# print(f"Error initializing prompt enhancer: {str(e)}")
# prompt_enhancer = None
# Gradio UI: generation controls, generated-image output, and text-overlay options.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML("""
<div class="main-title">
<h1>🎨 Webtoon Studio</h1>
<p>Generate webtoon-style images and add text with various styles and positions.</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gen_prompt = gr.Textbox(
                label="Generation Prompt",
                placeholder="Enter your image generation prompt..."
            )
            enhance_btn = gr.Button("✨ Enhance Prompt", variant="secondary")
            with gr.Row():
                gen_width = gr.Slider(512, 1024, 768, step=64, label="Width")
                gen_height = gr.Slider(512, 1024, 768, step=64, label="Height")
            with gr.Row():
                guidance_scale = gr.Slider(1, 20, 7.5, step=0.5, label="Guidance Scale")
                num_steps = gr.Slider(1, 50, 30, step=1, label="Number of Steps")
            with gr.Row():
                # precision=0 makes Gradio deliver an int, as torch's manual_seed() requires.
                seed = gr.Number(label="Seed", value=-1, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            generate_btn = gr.Button("Generate Image", variant="primary")
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                show_download_button=True
            )
            output_seed = gr.Number(label="Used Seed", interactive=False)
            # Text overlay section
            with gr.Accordion("Text Options", open=False):
                text_input = gr.Textbox(
                    label="Text Content",
                    placeholder="Enter text to add..."
                )
                text_position_type = gr.Radio(
                    choices=["Text Over Image"],
                    value="Text Over Image",
                    label="Text Position",
                    visible=True
                )
                with gr.Row():
                    font_choice = gr.Dropdown(
                        choices=["Default", "Korean Regular"],
                        value="Default",
                        label="Font Selection",
                        interactive=True
                    )
                    font_size = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=40,
                        step=5,
                        label="Font Size"
                    )
                with gr.Row():
                    color_dropdown = gr.Dropdown(
                        choices=["White", "Black", "Red", "Green", "Blue", "Yellow", "Purple"],
                        value="White",
                        label="Text Color"
                    )
                    thickness = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=1,
                        step=1,
                        label="Text Thickness"
                    )
                with gr.Row():
                    opacity_slider = gr.Slider(
                        minimum=0,
                        maximum=255,
                        value=255,
                        step=1,
                        label="Opacity"
                    )
                with gr.Row():
                    x_position = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Left(0%)~Right(100%)"
                    )
                    y_position = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=1,
                        label="High(0%)~Low(100%)"
                    )
                add_text_btn = gr.Button("Apply Text", variant="primary")

    # Event bindings
    generate_btn.click(
        fn=generate_image,
        inputs=[
            gen_prompt,
            seed,
            randomize_seed,
            gen_width,
            gen_height,
            guidance_scale,
            num_steps,
        ],
        outputs=[output_image, output_seed]
    )
    # NOTE(review): output_image is both the input and the output here, so each
    # "Apply Text" click draws onto the already-annotated image.
    add_text_btn.click(
        fn=add_text_to_image,
        inputs=[
            output_image,
            text_input,
            font_size,
            color_dropdown,
            opacity_slider,
            x_position,
            y_position,
            thickness,
            text_position_type,
            font_choice
        ],
        outputs=output_image
    )

    # Additional binding: the Enhance button rewrites the prompt box in place.
    def update_prompt(prompt):
        enhanced = enhance_prompt(prompt)
        return enhanced

    enhance_btn.click(
        fn=update_prompt,
        inputs=[gen_prompt],
        outputs=[gen_prompt]
    )
# Cap the pending-request queue at 5 jobs, then serve on all interfaces, port 7860,
# without a public share link and with at most 2 worker threads.
demo.queue(max_size=5).launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    max_threads=2,
)