# Hugging Face Space (status: Paused)
import os

import gradio as gr
import torch
from huggingface_hub import login
from PIL import Image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable
# (presumably needed because the model repo loaded below requires access — confirm).
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
else:
    # Calling login() without a token falls back to an interactive prompt, which a
    # headless Space cannot answer. Skip it and rely on any cached credentials instead.
    print("HF_TOKEN is not set; skipping Hugging Face Hub login.")
# --- Model configuration ---------------------------------------------------
# Base SD3.5 checkpoint on the Hub, the local IP-Adapter weights, the image
# encoder repo used by the adapter, and the reference image that conditions
# every generation.
model_path = "stabilityai/stable-diffusion-3.5-large"
ip_adapter_path = "./ip-adapter.bin"
image_encoder_path = "google/siglip-so400m-patch14-384"
ref_img_path = "./assets/1.jpg"

# --- Pipeline construction -------------------------------------------------
# The transformer is loaded on its own (from the checkpoint's "transformer"
# subfolder, using the project's SD3Transformer2DModel class) and then injected
# into the pipeline. Both are loaded in bfloat16.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

# Attach the IP-Adapter weights and image encoder to the pipeline.
# nb_token=64 presumably sets the number of image-prompt tokens — confirm
# against pipeline_stable_diffusion_3_ipa.init_ipadapter.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
def gui_generation(prompt: str, negative_prompt: str, ipadapter_scale: float, num_imgs: int) -> list:
    """
    Generate images from a text prompt, guided by the fixed reference image via IP-Adapter.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        ipadapter_scale: IP-Adapter strength, forwarded as ``ipadapter_scale``.
        num_imgs: Number of images to generate. Gradio sliders may deliver this
            as a float (e.g. ``2.0``), so it is coerced to ``int`` before use.

    Returns:
        A list with ``num_imgs`` images, each being ``.images[0]`` of one
        1024x1024, 24-step, guidance-5.0 pipeline call.
    """
    # range() rejects floats such as 2.0, which a Gradio Slider can produce.
    num_imgs = int(num_imgs)

    # Use a context manager so the file handle is released deterministically;
    # convert() loads the pixel data before the file is closed.
    with Image.open(ref_img_path) as src:
        ref_img = src.convert('RGB')

    # Seeded once per request, so a given input reproduces the same set of images.
    # The same generator object is reused across iterations, so (assuming the
    # pipeline draws its noise from it) each image in the set gets different noise.
    generator = torch.Generator("cuda").manual_seed(42)

    images = []
    for _ in range(num_imgs):
        output = pipe(
            width=1024,
            height=1024,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=24,
            guidance_scale=5.0,
            generator=generator,
            clip_image=ref_img,
            ipadapter_scale=ipadapter_scale,
        ).images[0]
        images.append(output)
    return images
# --- Gradio UI ---------------------------------------------------------------
prompt_box = gr.Textbox(
    label="Prompt",
    placeholder="Enter your generation prompt here",
)
negative_prompt_box = gr.Textbox(
    label="Negative Prompt",
    placeholder="e.g., lowres, worst quality",
)
ipadapter_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.5,
    step=0.1,
    label="IP-Adapter Scale",
)
number_slider = gr.Slider(
    minimum=1,
    maximum=5,
    value=1,
    step=1,
    label="Number of Images",
)
gallery = gr.Gallery(
    label="Generated Images",
    columns=[3],
    rows=[1],
    object_fit="contain",
    height="auto",
)

# The order of `inputs` must match gui_generation's positional parameters:
# (prompt, negative_prompt, ipadapter_scale, num_imgs).
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, negative_prompt_box, ipadapter_slider, number_slider],
    outputs=gallery,
    title="Stable Diffusion 3.5 Image Generation with IP-Adapter",
    description="Generate high-quality images with Stable Diffusion 3.5 Large and IP-Adapter guidance.",
)
interface.launch()