File size: 3,423 Bytes
4f91ffe
 
51f8f41
 
 
 
 
4ec6616
6dcb6b3
51f8f41
 
 
4f91ffe
6dcb6b3
 
 
4f91ffe
 
 
 
 
 
 
 
 
 
 
1ae6c5e
6dcb6b3
 
 
4fbc46c
6dcb6b3
 
c1497a6
3aadc38
d5f11d4
 
6dcb6b3
 
 
 
 
b2d0aef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51f8f41
 
 
d5f11d4
6dcb6b3
 
 
e129330
 
 
 
 
 
 
 
51f8f41
 
e129330
 
 
 
 
 
 
 
 
 
 
 
 
 
51f8f41
 
e129330
 
6dcb6b3
 
 
d09f5de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import requests
import torch
import gradio as gr
import spaces
from PIL import Image
from huggingface_hub import login
import torchvision.transforms as T
from diffusers.utils import load_image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

# ----------------------------
# Step 1: Download IP Adapter if not exists
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"

if not os.path.exists(file_path):
    print("File not found, downloading...")
    # Stream into a sibling temporary file and move it into place only after
    # the whole body was received. Otherwise an interrupted or failed download
    # would leave a truncated ip-adapter.bin that the existence check above
    # would accept on every later start.
    tmp_path = file_path + ".part"
    # timeout: avoid hanging startup forever on a stalled connection.
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on 4xx/5xx instead of saving an error page as weights.
        response.raise_for_status()
        with open(tmp_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                if chunk:  # skip keep-alive chunks
                    file.write(chunk)
    os.replace(tmp_path, file_path)
    print("Download completed!")

# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment
# variable. An unset or empty variable aborts startup with a ValueError.
token = os.environ.get("HF_TOKEN") or None  # normalise "" to None
if token is None:
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
login(token=token)



# ----------------------------
# Step 6: Gradio Function
# ----------------------------
def _reference_image_path(ref_img):
    """Return the local filesystem path of a file uploaded through ``gr.File``.

    Gradio 3.x passes the callback a tempfile wrapper whose ``.name`` holds
    the path, while Gradio 4.x passes the path itself as a ``str``. Both
    shapes are accepted here.

    Raises:
        gr.Error: if no file was uploaded. The UI then shows a readable
            message instead of an ``AttributeError`` on ``None``.
    """
    if ref_img is None:
        raise gr.Error("Please upload a reference image.")
    if isinstance(ref_img, (str, os.PathLike)):
        return os.fspath(ref_img)
    return ref_img.name


@spaces.GPU
def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
    """Generate one 1024x1024 image from a text prompt and a reference image.

    Args:
        prompt: Text prompt for the generation.
        ref_img: Value produced by the ``gr.File`` input. This is either a
            tempfile-like object with ``.name`` or a path string.
        guidance_scale: Text guidance scale forwarded to the pipeline
            (the UI describes it as adherence to the text prompt).
        ipadapter_scale: IP-Adapter strength forwarded to the pipeline
            (the UI describes it as the influence of the image prompt).

    Returns:
        The first entry of the pipeline's ``.images`` output (presumably a
        PIL image, which Gradio's "image" output accepts).

    Raises:
        gr.Error: if no reference image was uploaded.
    """
    # Validate the upload before the expensive model load below.
    ref_path = _reference_image_path(ref_img)

    model_path = 'stabilityai/stable-diffusion-3.5-large'
    ip_adapter_path = './ip-adapter.bin'
    image_encoder_path = "google/siglip-so400m-patch14-384"

    # NOTE(review): the transformer, pipeline and IP-Adapter are rebuilt on
    # every request. Consider caching them if the Space's GPU runtime allows
    # keeping the pipeline alive between calls.
    transformer = SD3Transformer2DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )

    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_path, transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    pipe.init_ipadapter(
        ip_adapter_path=ip_adapter_path,
        image_encoder_path=image_encoder_path,
        nb_token=64,
    )

    reference = load_image(ref_path).convert('RGB')

    # please note that SD3.5 Large is sensitive to highres generation like 1536x1536
    image = pipe(
        width=1024,
        height=1024,
        prompt=prompt,
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=guidance_scale,
        # Fixed seed: identical inputs reproduce the same image.
        generator=torch.Generator("cuda").manual_seed(42),
        clip_image=reference,
        ipadapter_scale=ipadapter_scale,
    ).images[0]

    return image


# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
# Restrict the picker to image files. Any other upload would only fail later
# inside load_image with an unhelpful error.
ref_img = gr.File(label="Upload Reference Image", file_types=["image"])
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt"
)

ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt"
)

# Inputs are passed positionally to gui_generation in this order:
# (prompt, ref_img, guidance_scale, ipadapter_scale).
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter."
)

# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
# Start the web server only when this file is executed directly, so that
# importing the module (e.g. from tooling or tests) does not block on it.
if __name__ == "__main__":
    interface.launch()