File size: 1,817 Bytes
683afc3
 
52d3f89
c1497a6
0737dc8
4fbc46c
0737dc8
4fbc46c
c1497a6
683afc3
6fc2fae
 
52d3f89
683afc3
52d3f89
 
 
 
 
0737dc8
 
8d2ed6a
52d3f89
c1497a6
0737dc8
52d3f89
 
683afc3
8d2ed6a
 
 
 
 
 
683afc3
8d2ed6a
 
 
683afc3
8d2ed6a
 
683afc3
8d2ed6a
52d3f89
 
683afc3
 
8d2ed6a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os

import gradio as gr
import torch
from diffusers import (
    ControlNetModel,
    SD3ControlNetModel,
    StableDiffusion3ControlNetPipeline,
    StableDiffusion3Pipeline,
    UniPCMultistepScheduler,
)
from huggingface_hub import login

# Log in to Hugging Face with the token from the environment, if one is set.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered in a headless deployment, so the call is skipped without a token.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)

# Model IDs for the base Stable Diffusion 3.5 model and its ControlNet.
# The ControlNet has to be trained for the SD3 architecture; SD 1.5
# checkpoints (e.g. "lllyasviel/control_v11p_sd15_inpaint") cannot be attached
# to an SD3 pipeline.
model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
controlnet_id = "stabilityai/stable-diffusion-3.5-large-controlnet-blur"  # Adjust based on ControlNet needs

# Load the ControlNet and the ControlNet-aware SD3 pipeline. The plain
# StableDiffusion3Pipeline has no `controlnet` component and its __call__ does
# not accept a conditioning image.
controlnet = SD3ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.bfloat16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
# The scheduler shipped with the checkpoint is kept on purpose: SD3.5 is a
# flow-matching model, and a UniPCMultistepScheduler built from that config
# does not sample it correctly.
pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe


# Gradio interface function
def generate_image(prompt, reference_image):
    """Generate an image from a text prompt, conditioned on a reference image.

    Args:
        prompt: Text prompt describing the desired image.
        reference_image: PIL image supplied by the Gradio image input, or
            ``None`` when the user has not provided one.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no reference image was provided.
    """
    # Gradio passes None for an empty image input; report that to the user
    # instead of failing with an AttributeError on `.convert`.
    if reference_image is None:
        raise gr.Error("Please provide a reference image.")

    # Prepare the reference image used as the ControlNet conditioning input
    control_image = reference_image.convert("RGB").resize((512, 512))

    # NOTE(review): guidance_scale and num_inference_steps are carried over
    # unchanged; turbo checkpoints are usually sampled with far fewer steps
    # and little or no guidance -- confirm against the model card.
    generated_image = pipe(
        prompt=prompt,
        control_image=control_image,  # the SD3 ControlNet pipeline names this `control_image`
        controlnet_conditioning_scale=1.0,
        guidance_scale=7.5,
        num_inference_steps=50,
    ).images[0]
    return generated_image

# Input components, in the order generate_image receives its arguments
_prompt_input = gr.Textbox(label="Prompt")
_reference_input = gr.Image(type="pil", label="Reference Image (Style)")

# Build the Gradio app around generate_image
interface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_input, _reference_input],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 and ControlNet",
    description=(
        "Generates an image based on a text prompt and style reference image "
        "using Stable Diffusion 3.5 and ControlNet."
    ),
)

# Start serving the app
interface.launch()