Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,029 Bytes
886f105 7d40369 9461fdc 18c290b 9461fdc 7d40369 9461fdc 7d40369 9461fdc a0e45a8 9461fdc 8f2de47 7d40369 7054a36 9461fdc 7054a36 99d1063 7054a36 7d40369 9461fdc 7d40369 99d1063 7d40369 9461fdc 7d40369 99d1063 7d40369 9461fdc 7d40369 886f105 9461fdc 886f105 18c290b 7d40369 9461fdc 7d40369 9461fdc 18c290b 8f2de47 7d40369 9461fdc 7d40369 d8e7562 9461fdc 7d40369 9461fdc 7d40369 9461fdc 7d40369 a554d27 886f105 7d40369 9461fdc 7d40369 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import spaces
import torch
from diffusers import (
FluxPipeline,
StableDiffusion3Pipeline,
PixArtSigmaPipeline,
SanaPipeline,
AuraFlowPipeline,
Kandinsky3Pipeline,
HunyuanDiTPipeline,
LuminaText2ImgPipeline
)
import gradio as gr
cache_dir = '/workspace/hf_cache'
# Registry of supported text-to-image models: display name -> Hub repo id,
# the diffusers pipeline class that loads it, and the weights cache dir.
# Kept as (name, repo, class) triples and expanded into the dict shape the
# rest of the app expects.
_MODEL_SPECS = [
    ("FLUX", "black-forest-labs/FLUX.1-dev", FluxPipeline),
    ("Stable Diffusion 3.5", "stabilityai/stable-diffusion-3.5-large", StableDiffusion3Pipeline),
    ("PixArt", "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", PixArtSigmaPipeline),
    ("SANA", "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", SanaPipeline),
    ("AuraFlow", "fal/AuraFlow", AuraFlowPipeline),
    ("Kandinsky", "kandinsky-community/kandinsky-3", Kandinsky3Pipeline),
    ("Hunyuan", "Tencent-Hunyuan/HunyuanDiT-Diffusers", HunyuanDiTPipeline),
    ("Lumina", "Alpha-VLLM/Lumina-Next-SFT-diffusers", LuminaText2ImgPipeline),
]

MODEL_CONFIGS = {
    name: {"repo_id": repo, "pipeline_class": cls, "cache_dir": cache_dir}
    for name, repo, cls in _MODEL_SPECS
}
def generate_image_with_progress(pipe, prompt, num_steps, guidance_scale=None, seed=None, progress=gr.Progress()):
    """Run a diffusers pipeline for *prompt* and stream step progress to Gradio.

    Args:
        pipe: A loaded diffusers pipeline, already moved to the target device.
        prompt: Text prompt to generate from.
        num_steps: Number of inference steps (also the progress denominator).
        guidance_scale: CFG scale, forwarded only when the pipeline exposes a
            ``guidance_scale`` attribute.
        seed: Optional int; when given, a seeded CUDA generator is used so the
            output is reproducible.
        progress: Gradio progress tracker (the ``gr.Progress()`` default lets
            Gradio inject one at call time).

    Returns:
        The first generated PIL image.
    """
    generator = None
    if seed is not None:
        # NOTE(review): assumes CUDA is available — callers move the pipeline
        # to "cuda" before invoking this function.
        generator = torch.Generator("cuda").manual_seed(seed)

    def callback(pipe, step_index, timestep, callback_kwargs):
        # Invoked by diffusers after each denoising step; map the step index
        # onto a 0..1 fraction for the Gradio progress bar.
        print(f" callback => {pipe}, {step_index}, {timestep}")
        if step_index is None:
            step_index = 0
        cur_prg = step_index / num_steps
        progress(cur_prg, desc=f"Step {step_index}/{num_steps}")
        return callback_kwargs

    # BUG FIX: the original only passed `generator` on the non-guidance
    # branch, so `seed` was silently ignored for pipelines that expose
    # `guidance_scale`. Build shared kwargs so the generator (and callback)
    # are applied on both paths.
    call_kwargs = {
        "num_inference_steps": num_steps,
        "generator": generator,
        "callback_on_step_end": callback,
    }
    if hasattr(pipe, "guidance_scale"):
        call_kwargs["guidance_scale"] = guidance_scale
    else:
        call_kwargs["output_type"] = "pil"

    image = pipe(prompt, **call_kwargs).images[0]
    return image
def create_pipeline_logic(model_name, config):
    """Build the Gradio click handler for one model tab.

    Args:
        model_name: Display name of the model (used for logging).
        config: Entry from ``MODEL_CONFIGS`` providing ``repo_id`` and
            ``pipeline_class``.

    Returns:
        A callable ``start_process(prompt_text) -> (status_str, image)``.
    """
    # BUG FIX: the GPU decorator must wrap the function executed per request
    # (`start_process`), not this factory — the factory runs once while the
    # UI is being built, so decorating it meant ZeroGPU never allocated a
    # GPU for the actual model load + generation call.
    @spaces.GPU(duration=170)
    def start_process(prompt_text):
        print(f"starting {model_name}")
        progress = gr.Progress()
        num_steps = 30
        guidance_scale = 7.5  # example CFG scale; could be made per-model
        seed = 42  # fixed seed so repeated runs are reproducible

        pipe_class = config["pipeline_class"]
        pipe = pipe_class.from_pretrained(
            config["repo_id"],
            # cache_dir=config["cache_dir"],  # disabled: rely on default HF cache
            torch_dtype=torch.bfloat16,
        ).to("cuda")

        image = generate_image_with_progress(
            pipe, prompt_text, num_steps=num_steps,
            guidance_scale=guidance_scale, seed=seed, progress=progress,
        )
        return f"Seed: {seed}", image

    return start_process
def main():
    """Assemble and launch the Gradio UI: one tab per configured model."""
    with gr.Blocks() as app:
        gr.Markdown("# Dynamic Multiple Model Image Generation")
        prompt_text = gr.Textbox(label="Enter prompt")

        # One tab per model, each wired to its own generation handler.
        for model_name, config in MODEL_CONFIGS.items():
            with gr.Tab(model_name):
                run_button = gr.Button(f"Run {model_name}")
                status_box = gr.Textbox(label="Status")
                image_out = gr.Image(label=model_name, height=300)

                handler = create_pipeline_logic(model_name, config)
                run_button.click(
                    fn=handler,
                    inputs=[prompt_text],
                    outputs=[status_box, image_out],
                )

    app.launch()


if __name__ == "__main__":
    main()
|