File size: 4,428 Bytes
4f91ffe
 
51f8f41
 
 
 
 
545ba28
6dcb6b3
51f8f41
 
 
4f91ffe
6dcb6b3
 
 
4f91ffe
 
 
 
 
 
 
 
 
 
 
1ae6c5e
6dcb6b3
 
 
4fbc46c
6dcb6b3
 
c1497a6
3aadc38
6dcb6b3
 
 
3aadc38
 
 
 
6dcb6b3
 
 
d8f1f69
6dcb6b3
36e35d5
d8f1f69
 
6dcb6b3
d8f1f69
 
 
 
 
 
 
 
d5f11d4
6dcb6b3
d5f11d4
6dcb6b3
 
d5f11d4
d09f5de
d5f11d4
 
 
 
d09f5de
545ba28
 
 
 
 
6dcb6b3
545ba28
d5f11d4
6dcb6b3
 
 
 
 
 
d5f11d4
 
 
 
 
bf8b15f
6dcb6b3
36e35d5
 
 
 
 
 
 
 
 
545ba28
51f8f41
 
 
 
 
d5f11d4
6dcb6b3
 
 
e129330
 
 
 
 
 
 
 
51f8f41
 
e129330
 
 
 
 
 
 
 
 
 
 
 
 
 
51f8f41
 
e129330
 
6dcb6b3
 
 
d09f5de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import requests
import torch
import gradio as gr
import spaces
from PIL import Image
from huggingface_hub import login
from torchvision import transforms
from diffusers.utils import load_image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

# ----------------------------
# Step 1: Download IP Adapter if not exists
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"

if not os.path.exists(file_path):
    print("File not found, downloading...")
    # Stream into a sibling ".part" file and rename it only after the whole
    # body has been written. An interrupted or failed download therefore never
    # leaves a truncated file at `file_path` that would pass the existence
    # check above on the next start.
    partial_path = file_path + ".part"
    try:
        # (connect timeout, read timeout) so a stalled connection cannot hang
        # startup forever; the context manager releases the connection.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as weights.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                # 1 MiB chunks keep the Python-level loop overhead low.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        file.write(chunk)
        os.replace(partial_path, file_path)
    finally:
        # After a successful rename the partial file is gone and this is a
        # no-op; after a failure it removes the incomplete data.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download completed!")

# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
# Read the access token from the environment and authenticate; a missing or
# empty token aborts startup.
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")

# ----------------------------
# Step 3: Model Paths
# ----------------------------
# Identifiers handed to the loaders below: the base model, the local
# IP-Adapter weights file, and the image encoder.
model_path = "stabilityai/stable-diffusion-3.5-large"
ip_adapter_path = "./ip-adapter.bin"
image_encoder_path = "google/siglip-so400m-patch14-384"

# ----------------------------
# Step 4: Load Transformer and Pipeline
# ----------------------------
# Load the project-local transformer class from the model's "transformer"
# subfolder in float16.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.float16,
)

# Build the pipeline around that transformer, then move it to the GPU.
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Attach the IP-Adapter weights and the image encoder to the pipeline.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)


# ----------------------------
# Step 5: Preprocess Reference Image
# ----------------------------
def preprocess_image(image_path, size=384, device="cuda", dtype=torch.float16):
    """Load a reference image and convert it into a batched tensor.

    Args:
        image_path: Path or file object accepted by ``PIL.Image.open``.
        size: Edge length, in pixels, of the square output.
        device: Device the resulting tensor is moved to.
        dtype: Floating-point dtype of the resulting tensor.

    Returns:
        A tensor of shape ``(1, 3, size, size)`` located on ``device``.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically; convert() yields an independent in-memory image.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Stretch to a square based on the longer edge (aspect ratio is not kept).
    # NOTE(review): the Resize below already forces a square, so this step is
    # an extra resampling pass; it is kept so existing outputs do not change.
    longest_edge = max(image.size)
    image = image.resize((longest_edge, longest_edge), Image.BILINEAR)

    to_tensor = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(dtype),
    ])
    # unsqueeze(0) adds the leading batch dimension.
    # NOTE(review): confirm that the pipeline's `clip_image` argument expects a
    # pre-scaled tensor rather than a raw PIL image.
    return to_tensor(image).unsqueeze(0).to(device)


# ----------------------------
# Step 6: Gradio Function
# ----------------------------
@spaces.GPU
def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale, seed=42):
    """Generate an image using Stable Diffusion 3.5 Large with IP-Adapter.

    Args:
        prompt: Text prompt for the generation.
        ref_img: Uploaded reference image, either a filesystem path or an
            object exposing the path as ``.name``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        ipadapter_scale: Forwarded to the pipeline as ``ipadapter_scale``.
        seed: Seed for the CUDA random generator; the default of 42 keeps the
            previous fixed-seed behaviour.

    Returns:
        The first image of the pipeline output.

    Raises:
        ValueError: If no reference image was provided or it cannot be loaded.
    """
    if ref_img is None:
        raise ValueError("Error loading reference image: no reference image was uploaded.")

    try:
        # The upload may arrive as a plain path string or as a temp-file
        # wrapper carrying the path in `.name`; accept both.
        if isinstance(ref_img, (str, os.PathLike)):
            ref_path = ref_img
        else:
            ref_path = ref_img.name
        ref_img_tensor = preprocess_image(ref_path)
    except Exception as e:
        # Keep the ValueError contract but preserve the original traceback.
        raise ValueError(f"Error loading reference image: {e}") from e

    # Run the pipeline without tracking gradients.
    with torch.no_grad():
        result = pipe(
            width=1024,
            height=1024,
            prompt=prompt,
            negative_prompt="lowres, low quality, worst quality",
            num_inference_steps=24,
            guidance_scale=guidance_scale,
            generator=torch.Generator("cuda").manual_seed(seed),
            clip_image=ref_img_tensor,
            ipadapter_scale=ipadapter_scale,
        )

    return result.images[0]


# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
# Input widgets, listed in the same order as gui_generation's parameters.
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
# Restrict the picker to image files so non-image uploads are rejected in the
# UI instead of failing later while the reference image is being loaded.
ref_img = gr.File(label="Upload Reference Image", file_types=["image"])
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt"
)

ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt"
)

# Wire the widgets to the generation callback; the result is shown as an image.
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter."
)

# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
interface.launch()