File size: 1,775 Bytes
683afc3
c1497a6
0737dc8
74c4e79
9754bfe
f5ffe3a
 
97c3973
 
b1029c2
97c3973
b1029c2
feede18
4fbc46c
c1497a6
683afc3
b12bc82
a88d434
b1029c2
bcbf6e0
0737dc8
74c4e79
43107ac
 
 
97c3973
b1029c2
43107ac
b1029c2
43107ac
97c3973
 
 
43107ac
b1029c2
 
 
97c3973
 
 
683afc3
7968596
 
 
 
 
43107ac
7968596
 
 
 
 
9754bfe
7968596
683afc3
7968596
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
from huggingface_hub import login
import os
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

# Authenticate against the Hugging Face Hub only when a token is configured.
# Calling login() with token=None falls back to an interactive prompt, which
# cannot be answered in a headless deployment.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)


# Load the SDXL-Turbo text-to-image pipeline once at import time, in half
# precision, and place it on the GPU so every request reuses the same weights.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,
).to("cuda")



@spaces.GPU
def generate_image(prompt, reference_images, controlnet_conditioning_scale):
    """Generate an image from a text prompt, styled after reference images.

    One IP-Adapter instance is loaded per reference image, and the requested
    strength is split evenly between them so the combined influence stays at
    the value chosen by the caller.

    Args:
        prompt: Text prompt passed to the pipeline.
        reference_images: Uploaded reference files. Each entry may be a
            filesystem path or a file-like object exposing the path via
            ``.name`` (the shape depends on the Gradio version in use).
        controlnet_conditioning_scale: Total adapter strength; divided by the
            number of reference images before being applied.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no reference image was supplied.
    """
    if not reference_images:
        # Without this guard the per-image scale below divides by zero (or
        # len() fails on None) and the user only sees an opaque error.
        raise gr.Error("Please upload at least one reference image.")

    adapter_count = len(reference_images)

    # weight_name must contain one entry per adapter; a single repo id and
    # subfolder are reused for all of them.
    pipeline.load_ip_adapter(
        "h94/IP-Adapter",
        subfolder="sdxl_models",
        weight_name=["ip-adapter_sdxl.bin"] * adapter_count,
    )

    style_images = []
    for reference_image in reference_images:
        # Accept both plain paths and temp-file wrappers carrying `.name`.
        image_path = getattr(reference_image, "name", reference_image)
        # convert() returns a fully loaded copy, so the underlying file
        # handle can be closed as soon as the `with` block exits.
        with Image.open(image_path) as opened:
            style_images.append(opened.convert("RGB"))

    # Apply the adapter only to the middle layer of up-block 0; the other
    # layers of that block are zeroed out.
    scale = {
        "up": {
            "block_0": [0.0, controlnet_conditioning_scale / adapter_count, 0.0],
        },
    }
    pipeline.set_ip_adapter_scale([scale] * adapter_count)

    # NOTE(review): guidance_scale=5 with 30 steps is unusual for a "turbo"
    # checkpoint -- confirm these values are intentional.
    result = pipeline(
        prompt=prompt,
        ip_adapter_image=style_images,
        negative_prompt="",
        guidance_scale=5,
        num_inference_steps=30,
    )

    return result.images[0]

# Set up the Gradio interface: a prompt, one or more reference images, and a
# slider controlling how strongly the references influence the result.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        # gr.File replaces the deprecated gr.inputs.File namespace.
        gr.File(label="Reference Images", file_count="multiple"),
        gr.Slider(label="IP-Adapter Scale", minimum=0, maximum=1.0, step=0.1, value=0.6),
    ],
    outputs="image",
    # The texts below describe the models this app actually loads
    # (SDXL-Turbo with IP-Adapter).
    title="Image Generation with SDXL-Turbo and IP-Adapter",
    description=(
        "Generates an image based on a text prompt and one or more reference "
        "images using SDXL-Turbo with IP-Adapter."
    ),
)

interface.launch()