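"""Gradio demo: text-to-image generation with Stable Diffusion 3.5 guided by a
ControlNet conditioning image. Assumes access to the gated SD3.5 weights via an
HF_TOKEN environment variable, plus the diffusers, accelerate, and gradio packages."""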
import gradio as gr
import torch
from diffusers import (
    StableDiffusion3ControlNetPipeline,  # SD3/SD3.5 pipeline with ControlNet support
    SD3ControlNetModel,  # ControlNet implementation for the SD3 transformer architecture
)
from huggingface_hub import login
import os

# Log in to Hugging Face with the token from the environment
# (the SD3.5 weights are gated, so an authorized token is required)
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)

# Model IDs for the base Stable Diffusion 3.5 model and a ControlNet variant.
# Note: SD1.5 ControlNets (e.g. lllyasviel/control_v11p_sd15_inpaint) are not compatible
# with the SD3 transformer architecture; the canny checkpoint below is assumed to be the
# SD3.5 Large ControlNet release, so swap in whichever conditioning variant you need.
model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
controlnet_id = "stabilityai/stable-diffusion-3.5-large-controlnet-canny"

# Load the ControlNet, then the full pipeline; from_pretrained pulls in the transformer,
# VAE, text encoders, tokenizers, and scheduler that ship with the base model.
controlnet = SD3ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

# Enable model CPU offloading for memory optimization when a GPU is available; it manages
# device placement itself, so no explicit .to("cuda") call is needed alongside it.
if torch.cuda.is_available():
    pipeline.enable_model_cpu_offload()

# Gradio interface function
def generate_image(prompt, reference_image):
    # Prepare the reference (conditioning) image at SD3.5's native 1024x1024 resolution
    reference_image = reference_image.convert("RGB").resize((1024, 1024))

    # Generate an image guided by the prompt and the ControlNet conditioning image
    generated_image = pipeline(
        prompt=prompt,
        control_image=reference_image,  # SD3 ControlNet pipelines take the conditioning image as control_image
        controlnet_conditioning_scale=1.0,
        guidance_scale=0.0,  # the turbo checkpoint is distilled to run without classifier-free guidance
        num_inference_steps=4,  # and with very few steps; the non-turbo model needs ~28 steps and CFG ~3.5
    ).images[0]
    return generated_image

# Set up the Gradio interface
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Reference Image (ControlNet Conditioning)")
    ],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 + ControlNet",
    description="Generates an image from a text prompt, guided by a reference image through an SD3.5-compatible ControlNet."
)

# Launch the Gradio interface
interface.launch()
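
# If a public link is needed (e.g. when running locally rather than in a hosted Space),
# Gradio's share option can be used instead: interface.launch(share=True)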