# Hugging Face file-view header (scraped page residue, kept as a comment):
# author: michellemoorre | last commit 412b3d8 "Fix ui and add apex" | size 4.05 kB
import gradio as gr
import numpy as np
import random
import spaces
from models import TVARPipeline
import torch
# Pick the compute device: CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Repository id handed to TVARPipeline.from_pretrained below.
model_repo_id = "michellemoorre/var-test"

# The pipeline is built once at import time; `infer` reuses this object.
pipe = TVARPipeline.from_pretrained(model_repo_id, device=device)

# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    guidance_scale=4.0,
    top_k=450,
    top_p=0.95,
    re=True,
    re_max_depth=10,
    re_start_iter=2,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Forwarded as ``null_prompt`` (presumably the
            unconditional/negative branch of guidance -- confirm against
            ``TVARPipeline``).
        seed: Seed forwarded as ``g_seed``. May arrive as a float (e.g. ``42.0``)
            or ``None`` from a ``gr.Number`` component.
        randomize_seed: When true, ignore ``seed`` and draw one uniformly from
            ``[0, MAX_SEED]``.
        guidance_scale: Forwarded as ``cfg``.
        top_k: Forwarded unchanged as ``top_k``.
        top_p: Forwarded unchanged as ``top_p``.
        re: Rejection-sampling switch, forwarded unchanged as ``re``.
        re_max_depth: Forwarded unchanged as ``re_max_depth``.
        re_start_iter: Forwarded unchanged as ``re_start_iter``.
        progress: Gradio progress tracker; ``track_tqdm=True`` surfaces tqdm
            bars in the UI.

    Returns:
        ``(image, seed)``: the first element of the pipeline's output and the
        integer seed that was actually used, so the UI can display it.
    """
    # An empty seed field yields None; treat that like "pick one for me"
    # instead of passing None into the pipeline.
    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # gr.Number returns floats unless precision=0; normalise to int so the
    # pipeline always receives an integer seed and the UI echoes an integer.
    seed = int(seed)
    image = pipe(
        prompt=prompt,
        null_prompt=negative_prompt,
        cfg=guidance_scale,
        top_p=top_p,
        top_k=top_k,
        re=re,
        re_max_depth=re_max_depth,
        re_start_iter=re_start_iter,
        g_seed=seed,
    )[0]
    return image, seed
# Prompts offered as clickable examples under the UI.
# TODO: add examples from preview
examples = ["A capybara wearing a suit holding a sign that reads Hello World"]

# Page stylesheet: centre the main column and cap its width at 640px.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# Build the Gradio UI: prompt box with a Run button, the result image,
# seed/guidance controls, an "Advanced Settings" accordion with sampling and
# rejection-sampling options, example prompts, and the wiring to `infer`.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Link to the checkpoint this app actually loads (`model_repo_id`)
        # rather than unrelated Stable Diffusion 3.5 pages.
        gr.Markdown(f" # [OpenTVAR](https://huggingface.co/{model_repo_id})")
        gr.Markdown(
            f"[Learn more](https://huggingface.co/{model_repo_id}) about the OpenTVAR."
        )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        # precision=0 makes the component deliver an int instead of a float.
        seed = gr.Number(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            value=0,
            precision=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        guidance_scale = gr.Slider(
            label="Guidance scale",
            minimum=0.0,
            maximum=7.5,
            step=0.5,
            value=4.,
        )
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )
            with gr.Row():
                top_k = gr.Slider(
                    label="Sampling top k",
                    minimum=10,
                    maximum=1000,
                    step=20,
                    value=450,
                )
                top_p = gr.Slider(
                    label="Sampling top p",
                    minimum=0.0,
                    maximum=1.,
                    step=0.05,
                    value=0.95,
                )
            re = gr.Checkbox(label="Rejection Sampling (RE)", value=True)
            with gr.Row():
                # step=2 keeps the default of 10 on the slider's step grid
                # (step=4 from 0 only allows 0, 4, 8, 12, ...), so the UI can
                # reproduce infer's default re_max_depth=10.
                re_max_depth = gr.Slider(
                    label="RE Depth",
                    minimum=0,
                    maximum=20,
                    step=2,
                    value=10,
                )
                re_start_iter = gr.Slider(
                    label="RE Start Scale",
                    minimum=0,
                    maximum=9,
                    step=1,
                    value=2,
                )
        # Only the prompt is supplied for examples; infer's defaults fill the
        # rest. NOTE(review): cache_examples=True precomputes example outputs;
        # cache_mode="lazy" was previously left commented out -- confirm which
        # caching strategy is wanted on this Space.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            outputs=[result, seed],
            fn=infer,
            cache_examples=True,
        )
    # Run generation on button click or when Enter is pressed in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            top_k,
            top_p,
            re,
            re_max_depth,
            re_start_iter,
        ],
        outputs=[result, seed],
    )
def _launch_app():
    """Start the Gradio server for the module-level `demo` app."""
    demo.launch()


# Serve only when executed as a script, not when imported.
if __name__ == "__main__":
    _launch_app()