import spaces
import os
import time
from os import path
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
from diffusers import FluxPipeline
# Cache and environment setup: keep model downloads inside the Space's local models directory
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
torch.backends.cuda.matmul.allow_tf32 = True
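# Small context manager that logs how long a wrapped block of work takes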
class timer:
def __init__(self, method_name="timed process"):
self.method = method_name
    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")
# Model initialization
if not path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
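# Load the FLUX.1-dev base pipeline in bfloat16 to keep the memory footprint manageable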
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
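# Download and attach the ByteDance Hyper-SD 8-step LoRA for fast, few-step generation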
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
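# Fuse the LoRA into the base weights (0.125 is the scale commonly paired with the 8-step Hyper-FLUX LoRA)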
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
# Custom CSS for enhanced visual design
css = """
footer {display: none !important}
.container {max-width: 1200px; margin: auto;}
.gr-form {border-radius: 12px; padding: 20px; background: rgba(255, 255, 255, 0.05);}
.gr-box {border-radius: 8px; border: 1px solid rgba(255, 255, 255, 0.1);}
.gr-button {
border-radius: 8px;
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
border: none;
color: white;
transition: transform 0.2s ease;
}
.gr-button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
.gr-input {background: rgba(255, 255, 255, 0.05) !important;}
.gr-input:focus {border-color: #4B79A1 !important;}
.title-text {
text-align: center;
font-size: 2.5em;
font-weight: bold;
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 1em;
}
"""
# Create Gradio interface with enhanced design
with gr.Blocks(theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="slate",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter")
), css=css) as demo:
gr.HTML("""
<div class="title-text">AI Image Generator</div>
<div style="text-align: center; margin-bottom: 2em; color: #666;">
Create stunning images from your descriptions using advanced AI
</div>
""")
    with gr.Row(equal_height=True):
with gr.Column(scale=3):
with gr.Group():
prompt = gr.Textbox(
label="Image Description",
placeholder="Describe the image you want to create...",
lines=3,
elem_classes="gr-input"
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Group():
with gr.Row():
with gr.Column(scale=1):
height = gr.Slider(
label="Height",
minimum=256,
maximum=1152,
step=64,
value=1024,
elem_classes="gr-input"
)
with gr.Column(scale=1):
width = gr.Slider(
label="Width",
minimum=256,
maximum=1152,
step=64,
value=1024,
elem_classes="gr-input"
)
with gr.Row():
with gr.Column(scale=1):
steps = gr.Slider(
label="Inference Steps",
minimum=6,
maximum=25,
step=1,
value=8,
elem_classes="gr-input"
)
with gr.Column(scale=1):
scales = gr.Slider(
label="Guidance Scale",
minimum=0.0,
maximum=5.0,
step=0.1,
value=3.5,
elem_classes="gr-input"
)
seed = gr.Number(
label="Seed (for reproducibility)",
value=3413,
precision=0,
elem_classes="gr-input"
)
generate_btn = gr.Button(
"✨ Generate Image",
variant="primary",
scale=1,
elem_classes="gr-button"
)
gr.HTML("""
<div style="margin-top: 1em; padding: 1em; border-radius: 8px; background: rgba(255, 255, 255, 0.05);">
<h4 style="margin: 0 0 0.5em 0;">Tips for best results:</h4>
<ul style="margin: 0; padding-left: 1.2em;">
<li>Be specific in your descriptions</li>
<li>Include details about style, lighting, and mood</li>
<li>Experiment with different guidance scales</li>
</ul>
</div>
""")
with gr.Column(scale=4):
output = gr.Image(
label="Generated Image",
elem_classes="gr-box",
height=512
)
with gr.Group(visible=False) as loading_info:
gr.HTML("""
<div style="text-align: center; padding: 1em;">
<div style="display: inline-block; animation: spin 1s linear infinite;">⚙️</div>
<p>Generating your image...</p>
</div>
""")
@spaces.GPU
def process_image(height, width, steps, scales, prompt, seed):
global pipe
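        # inference_mode avoids autograd overhead; bf16 autocast matches the model's dtype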
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
return pipe(
prompt=[prompt],
generator=torch.Generator().manual_seed(int(seed)),
num_inference_steps=int(steps),
guidance_scale=float(scales),
height=int(height),
width=int(width),
max_sequence_length=256
).images[0]
    # Button click chain: show the loading indicator, run inference, then hide the indicator
generate_btn.click(
fn=lambda: gr.update(visible=True),
outputs=[loading_info],
queue=False
).then(
process_image,
inputs=[height, width, steps, scales, prompt, seed],
outputs=output
).then(
fn=lambda: gr.update(visible=False),
outputs=[loading_info]
)
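# Launch the Gradio app when the script is executed directly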
if __name__ == "__main__":
demo.launch()