# Futuretop's picture
# Update app.py
# 62e8450 verified
import os, gc
import gradio as gr
import numpy as np
import random
from transformers import CLIPTokenizer, CLIPFeatureExtractor
import spaces
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import torch
# Ask PyTorch to drop any unused cached CUDA memory before the model is loaded.
torch.cuda.empty_cache()

# Runtime configuration: target device, model repository, preferred dtype.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "tensorart/stable-diffusion-3.5-large-TurboX"
# float16 when CUDA is present, bfloat16 otherwise.
# NOTE(review): the pipeline load further down passes torch.bfloat16 directly
# and does not read this variable - confirm which dtype is intended.
torch_dtype = torch.float16 if device == "cuda" else torch.bfloat16
# CLIP tokenizer and feature extractor, both fetched from the same checkpoint.
# They are handed to DiffusionPipeline.from_pretrained when the pipeline is built.
_CLIP_REPO_ID = "openai/clip-vit-base-patch32"  # or clip-vit-large if you prefer

tokenizer = CLIPTokenizer.from_pretrained(_CLIP_REPO_ID, use_fast=True)
feature_extractor = CLIPFeatureExtractor.from_pretrained(_CLIP_REPO_ID)
# 3) Load the pipeline with bfloat16 weights, replacing its scheduler,
#    tokenizer and feature extractor with the objects built in this file.
#    (CPU offloading is switched on separately, after loading.)
_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo_id,
    subfolder="scheduler",
    shift=5,
    use_safetensors=True,
)

pipe = DiffusionPipeline.from_pretrained(
    model_repo_id,
    scheduler=_scheduler,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    torch_dtype=torch.bfloat16,  # weights are loaded as bfloat16 (not float16)
    use_safetensors=True,
)
# 4) Memory savings hooks (all on your single GPU + CPU offload)
pipe.enable_attention_slicing()  # slice big attention maps
pipe.vae.enable_slicing()  # slice VAE decode
try:
    # xformers is an optional extra. diffusers raises ModuleNotFoundError
    # (an ImportError) when it is not installed and ValueError when CUDA is
    # unavailable, so treat both as "skip this optimisation" instead of
    # aborting start-up.
    pipe.enable_xformers_memory_efficient_attention()
except (ImportError, ValueError) as exc:
    print(f"xformers memory-efficient attention not enabled: {exc}")
pipe.enable_model_cpu_offload()  # offload idle submodules to CPU
# Upper bounds used by the seed slider / random seed draw and by the
# width and height sliders.
_INT32_INFO = np.iinfo(np.int32)
MAX_SEED = _INT32_INFO.max  # largest signed 32-bit integer
MAX_IMAGE_SIZE = 1024
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt with the module-level pipeline.

    The fixed style prefix "cartoon styled korean " is prepended to the
    prompt before it is sent to the pipeline.

    Args:
        prompt: Text describing the image to generate.
        negative_prompt: Text describing what the image should avoid.
        seed: Seed for the random generator; ignored when
            ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (mirrors tqdm progress bars).

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the integer
        seed that was actually used.
    """
    # The prefix needs a trailing space, otherwise it fuses with the first
    # word of the user's prompt ("...koreanA girl").
    full_prompt = "cartoon styled korean " + prompt

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Values coming from UI sliders or API clients may arrive as floats;
    # manual_seed() and the pixel/step counts need real integers.
    seed = int(seed)

    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=full_prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]
    return image, seed
# Custom stylesheet passed to gr.Blocks(css=...): pastel page gradient, a
# translucent card for #col-container, gradient primary buttons with a hover
# lift, rounded forms/accordions, and gradient-filled h1 text.
css = """
body {
background: linear-gradient(135deg, #f9e2e6 0%, #e8f3fc 50%, #e2f9f2 100%);
background-attachment: fixed;
min-height: 100vh;
}
#col-container {
margin: 0 auto;
max-width: 640px;
background-color: rgba(255, 255, 255, 0.85);
border-radius: 16px;
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.1);
padding: 24px;
backdrop-filter: blur(10px);
}
.gradio-container {
background: transparent !important;
}
.gr-button-primary {
background: linear-gradient(90deg, #6b9dfc, #8c6bfc) !important;
border: none !important;
transition: all 0.3s ease;
}
.gr-button-primary:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(108, 99, 255, 0.3);
}
.gr-form {
border-radius: 12px;
background-color: rgba(255, 255, 255, 0.7);
}
.gr-accordion {
border-radius: 12px;
overflow: hidden;
}
h1 {
background: linear-gradient(90deg, #6b9dfc, #8c6bfc);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 800;
}
"""
# Gradio UI: prompt row, result image, collapsible advanced settings, and the
# event wiring that calls infer().
# NOTE(review): the source had no indentation; the nesting below is a
# reconstruction - confirm it matches the intended layout.
# NOTE(review): "apriel" could not be confirmed as an available Gradio theme
# name - verify it loads.
with gr.Blocks(theme="apriel", css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Prompt box and Run button share one row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt copied from the previous website",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Width and height use identical slider settings.
            _size_slider_kwargs = dict(
                minimum=512,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            with gr.Row():
                width = gr.Slider(label="Width", **_size_slider_kwargs)
                height = gr.Slider(label="Height", **_size_slider_kwargs)

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=1.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=8,
                )

    # Components in the same order as infer()'s positional parameters.
    _infer_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    ]
    # Clicking Run or pressing Enter in the prompt box triggers generation;
    # the seed actually used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=_infer_inputs,
        outputs=[result, seed],
    )
# Launch the app only when this file is run as a script; mcp_server=True asks
# Gradio to also expose the app as an MCP server.
if __name__ == "__main__":
    demo.launch(mcp_server=True)