Spaces:
Paused
Paused
import os | |
import requests | |
import torch | |
import gradio as gr | |
import spaces | |
from PIL import Image | |
from huggingface_hub import login | |
from torchvision import transforms | |
from diffusers.utils import load_image | |
from models.transformer_sd3 import SD3Transformer2DModel | |
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline | |
# ----------------------------
# Step 1: Download IP Adapter if not exists
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"
if not os.path.exists(file_path):
    print("File not found, downloading...")
    # Stream into a side file and rename only once the full body has arrived.
    # A failed or interrupted download must not leave a truncated
    # ip-adapter.bin behind: the existence check above would otherwise skip
    # re-downloading it on every later start.
    partial_path = file_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on HTTP errors (401, 404, ...) instead of saving the
        # error page as the weights file.
        response.raise_for_status()
        with open(partial_path, "wb") as out_file:
            # 1 MiB chunks: the weights file is large, and 1 KiB chunks mean
            # needlessly many Python-level iterations.
            for chunk in response.iter_content(chunk_size=1 << 20):
                if chunk:  # skip empty chunks
                    out_file.write(chunk)
    # Atomic rename on the same filesystem.
    os.replace(partial_path, file_path)
    print("Download completed!")
# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
# Read the access token from the environment. Authenticate when it is present;
# otherwise stop immediately with an actionable message.
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
# ----------------------------
# Step 3: Model Paths
# ----------------------------
model_path, ip_adapter_path, image_encoder_path = (
    # Base checkpoint on the Hugging Face Hub.
    'stabilityai/stable-diffusion-3.5-large',
    # Local IP-Adapter weights (the ip-adapter.bin file fetched in Step 1).
    './ip-adapter.bin',
    # Vision encoder handed to pipe.init_ipadapter for the reference image.
    "google/siglip-so400m-patch14-384",
)
# ----------------------------
# Step 4: Load Transformer and Pipeline
# ----------------------------
# The transformer comes from this project's models.transformer_sd3 module
# (not the stock diffusers class) and is injected into the project's
# IP-Adapter-aware pipeline. Both are loaded in float16.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.float16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Attach the IP-Adapter weights and the image encoder to the pipeline.
pipe.init_ipadapter(
    nb_token=64,
    image_encoder_path=image_encoder_path,
    ip_adapter_path=ip_adapter_path,
)
# ----------------------------
# Step 5: Preprocess Reference Image
# ----------------------------
def preprocess_image(image_path, size=384, device="cuda"):
    """Load an image file as a batched float16 tensor.

    Steps:
    1. Convert the image to RGB.
    2. Stretch it to a square whose side is its larger dimension. The aspect
       ratio is NOT preserved.
    3. Resize it to ``size`` x ``size``.
    4. Convert it to a float16 tensor with values in [0, 1].

    NOTE(review): the values are already scaled to [0, 1]. If this tensor is
    handed to a consumer that runs its own image processor (resize /
    rescale / normalize), the pixels get rescaled a second time. Such
    consumers should receive a PIL image instead.

    Args:
        image_path: Path of the image file to load.
        size: Output side length in pixels (default 384, the original
            hard-coded value).
        device: Device the tensor is moved to (default "cuda").

    Returns:
        Tensor of shape (1, 3, size, size), dtype float16, on ``device``.
    """
    # Use the context manager so the underlying file handle is closed.
    # convert() loads the pixels and returns an independent in-memory image.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")
    # Stretch into a square first (same two-step resize as before).
    side = max(image.size)
    image = image.resize((side, side), Image.BILINEAR)
    preprocess = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float16),
    ])
    return preprocess(image).unsqueeze(0).to(device)
# ----------------------------
# Step 6: Gradio Function
# ----------------------------
def _resolve_upload_path(upload):
    """Return the filesystem path held by a ``gr.File`` input value.

    The value may be a plain path (str / os.PathLike) or a temp-file wrapper
    exposing the path as ``.name``; which one arrives depends on the Gradio
    version and component settings, so accept both.

    Raises:
        ValueError: if nothing was uploaded (``upload`` is None).
    """
    if upload is None:
        raise ValueError("no reference image was uploaded")
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
    """Generate an image using Stable Diffusion 3.5 Large with IP-Adapter.

    Args:
        prompt: Text prompt for the generation.
        ref_img: Value of the ``gr.File`` reference-image input (a path string
            or a temp-file wrapper with ``.name``).
        guidance_scale: Controls adherence to the text prompt.
        ipadapter_scale: Controls influence of the image prompt.

    Returns:
        The first image in ``pipe(...).images``, generated at 1024x1024.

    Raises:
        ValueError: if the reference image is missing or cannot be opened.
    """
    # NOTE(review): `spaces` is imported at the top of the file but unused. On
    # ZeroGPU hardware this function presumably needs the `@spaces.GPU`
    # decorator; confirm which hardware the Space runs on.
    try:
        # Load a plain RGB PIL image. `with` closes the file handle once
        # convert() has produced an in-memory copy.
        with Image.open(_resolve_upload_path(ref_img)) as opened:
            ref_image = opened.convert("RGB")
    except Exception as err:
        raise ValueError(f"Error loading reference image: {err}") from err

    # Pass the PIL image rather than a tensor pre-scaled to [0, 1]. The
    # upstream InstantX SD3.5-Large-IP-Adapter example calls the pipeline with
    # a PIL `clip_image`, which the pipeline is expected to run through its own
    # SigLIP image processor (resize + rescale + normalize). A pre-scaled
    # tensor would therefore be rescaled twice.
    # NOTE(review): confirm against pipeline_stable_diffusion_3_ipa.
    with torch.no_grad():
        image = pipe(
            width=1024,
            height=1024,
            prompt=prompt,
            negative_prompt="lowres, low quality, worst quality",
            num_inference_steps=24,
            guidance_scale=guidance_scale,
            # Fixed seed: identical inputs reproduce the same image.
            generator=torch.Generator("cuda").manual_seed(42),
            clip_image=ref_image,
            ipadapter_scale=ipadapter_scale,
        ).images[0]
    return image
# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
# Restrict uploads to image files so that non-images are refused at upload
# time instead of failing later when the file is opened as an image.
ref_img = gr.File(label="Upload Reference Image", file_types=["image"])
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt"
)
ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt"
)
# Input order must match gui_generation(prompt, ref_img, guidance_scale,
# ipadapter_scale).
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter."
)
# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
interface.launch()