Spaces:
Paused
Paused
File size: 4,428 Bytes
4f91ffe 51f8f41 545ba28 6dcb6b3 51f8f41 4f91ffe 6dcb6b3 4f91ffe 1ae6c5e 6dcb6b3 4fbc46c 6dcb6b3 c1497a6 3aadc38 6dcb6b3 3aadc38 6dcb6b3 d8f1f69 6dcb6b3 36e35d5 d8f1f69 6dcb6b3 d8f1f69 d5f11d4 6dcb6b3 d5f11d4 6dcb6b3 d5f11d4 d09f5de d5f11d4 d09f5de 545ba28 6dcb6b3 545ba28 d5f11d4 6dcb6b3 d5f11d4 bf8b15f 6dcb6b3 36e35d5 545ba28 51f8f41 d5f11d4 6dcb6b3 e129330 51f8f41 e129330 51f8f41 e129330 6dcb6b3 d09f5de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import requests
import torch
import gradio as gr
import spaces
from PIL import Image
from huggingface_hub import login
from torchvision import transforms
from diffusers.utils import load_image
from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
# ----------------------------
# Step 1: Download the IP-Adapter checkpoint if it does not exist yet
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"


def _download_file(src_url, dest_path, chunk_size=1024 * 1024, timeout=60):
    """Stream ``src_url`` into ``dest_path`` without ever leaving a partial file.

    The body is written to ``dest_path + ".part"`` and only renamed onto
    ``dest_path`` after the whole response has been received. Without this,
    an interrupted or failed download would leave a truncated file that the
    ``os.path.exists`` check below would accept as complete on the next start.

    Args:
        src_url: URL to fetch.
        dest_path: Final location of the downloaded file.
        chunk_size: Bytes requested per streamed chunk.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Any exception raised while downloading or writing, after the
        temporary file has been removed.
    """
    tmp_path = dest_path + ".part"
    try:
        with requests.get(src_url, stream=True, timeout=timeout) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as the checkpoint.
            response.raise_for_status()
            with open(tmp_path, "wb") as fh:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip empty keep-alive chunks
                        fh.write(chunk)
        # Same directory => same filesystem, so the rename is atomic.
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Remove whatever was partially written, then propagate the original error.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise


if not os.path.exists(file_path):
    print("File not found, downloading...")
    _download_file(url, file_path)
    print("Download completed!")
# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
# Authenticate with the Hub using the HF_TOKEN environment variable and abort
# startup when it is unset or empty.
token = os.environ.get("HF_TOKEN")
if token is None or token == "":
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
login(token=token)
# ----------------------------
# Step 3: Model Paths
# ----------------------------
# Hub repository id of the base model; both the transformer and the pipeline load from it.
model_path = "stabilityai/stable-diffusion-3.5-large"
# Local IP-Adapter checkpoint (the file Step 1 downloads into the working directory).
ip_adapter_path = "./ip-adapter.bin"
# Hub repository id of the image encoder handed to init_ipadapter.
image_encoder_path = "google/siglip-so400m-patch14-384"
# ----------------------------
# Step 4: Load Transformer and Pipeline
# ----------------------------
# The transformer is built from the project's own SD3Transformer2DModel class and
# injected into the pipeline; both are loaded in float16.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.float16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Attach the IP-Adapter checkpoint and its image encoder.
# nb_token=64: presumably the number of image tokens the adapter emits — TODO confirm
# against init_ipadapter in pipeline_stable_diffusion_3_ipa.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
# ----------------------------
# Step 5: Preprocess Reference Image
# ----------------------------
def preprocess_image(image_path, size=384, device="cuda", dtype=torch.float16):
    """Load an image and turn it into a batched ``size x size`` tensor.

    The image is converted to RGB and resized straight to ``size x size``.
    The aspect ratio is not preserved, which matches the previous behaviour of
    stretching to a square first. The earlier intermediate stretch to
    ``max(w, h)`` was dropped: it only added an extra resampling pass, and
    could upscale the short side, before the final resize.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        size: Output edge length in pixels (default 384).
        device: Device the returned tensor is moved to (default ``"cuda"``).
        dtype: dtype of the returned tensor (default ``torch.float16``).

    Returns:
        A tensor of shape ``(1, 3, size, size)`` with values in [0, 1]
        (as produced by ``transforms.ToTensor``), on ``device``.
    """
    # Close the file handle as soon as the pixels are decoded; convert()
    # returns an independent image that outlives the `with` block.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(dtype),
    ])
    return preprocess(image).unsqueeze(0).to(device)
# ----------------------------
# Step 6: Gradio Function
# ----------------------------
@spaces.GPU
def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
    """Generate an image using Stable Diffusion 3.5 Large with IP-Adapter.

    Args:
        prompt: Text prompt for the generation.
        ref_img: Value of the ``gr.File`` input. This is either a path string
            or an object exposing the path as ``.name``, and ``None`` when
            nothing was uploaded.
        guidance_scale: Adherence to the text prompt ("Guidance Scale" slider).
        ipadapter_scale: Influence of the reference image ("IP-Adapter Scale"
            slider).

    Returns:
        The first image of the pipeline output (requested at 1024x1024 with a
        fixed seed of 42).

    Raises:
        gr.Error: If no reference image was uploaded or it cannot be opened
            as an image. Gradio displays the message in the UI.
    """
    if ref_img is None:
        raise gr.Error("Please upload a reference image.")

    # Accept both a plain path and a file wrapper carrying the path in `.name`.
    ref_path = getattr(ref_img, "name", ref_img)
    try:
        # Decode eagerly and release the file; convert() returns an independent copy.
        with Image.open(ref_path) as opened:
            ref_image = opened.convert("RGB")
    except Exception as e:
        raise gr.Error(f"Error loading reference image: {e}") from e

    # Pass the RGB PIL image directly. The reference usage of this IP-Adapter
    # feeds a PIL image as `clip_image` and lets the pipeline's SigLIP
    # processor resize and normalize it. A ToTensor() output already in [0, 1]
    # would be rescaled a second time by such a processor.
    # NOTE(review): confirm against encode logic in pipeline_stable_diffusion_3_ipa.
    with torch.no_grad():
        image = pipe(
            width=1024,
            height=1024,
            prompt=prompt,
            negative_prompt="lowres, low quality, worst quality",
            num_inference_steps=24,
            guidance_scale=guidance_scale,
            generator=torch.Generator("cuda").manual_seed(42),
            clip_image=ref_image,
            ipadapter_scale=ipadapter_scale,
        ).images[0]
    return image
# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
# Limit uploads to image files. Otherwise a non-image upload is only detected
# inside the GPU-decorated handler, after a GPU has been requested.
ref_img = gr.File(label="Upload Reference Image", file_types=["image"])
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt",
)
ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt",
)
# Input order must match gui_generation's positional parameters.
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter.",
)
# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
# Start the web server only when this file is executed as a script, so
# importing the module does not also launch the app.
if __name__ == "__main__":
    interface.launch()
|