import os

import cv2
import numpy as np
import torch
from PIL import Image

import gradio as gr
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
)
from huggingface_hub import login

# Log in to Hugging Face with a token from the environment (skipped if unset)
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)

# Model and ControlNet IDs
model_id = "runwayml/stable-diffusion-v1-5"  # Known compatible model with ControlNet
controlnet_id = "lllyasviel/sd-controlnet-canny"  # ControlNet model for edge detection

# Load the ControlNet model and the Stable Diffusion pipeline
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16
)

# Optional: swap in the faster UniPC scheduler
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)

# Enable CPU offloading for memory optimization
pipeline.enable_model_cpu_offload()
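# Note (assumption about the deployment target): enable_model_cpu_offload()
# requires the `accelerate` package and a CUDA GPU. On CPU-only hardware,
# drop this call and load the pipeline with torch_dtype=torch.float32 instead.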


# Gradio interface function
def generate_image(prompt, reference_image):
    # Resize the reference image to the 512x512 resolution SD v1.5 expects
    reference_image = reference_image.convert("RGB").resize((512, 512))

    # The Canny ControlNet is conditioned on an edge map, not the raw image:
    # detect edges, then replicate the single channel into a 3-channel image
    edges = cv2.Canny(np.array(reference_image), 100, 200)
    canny_image = Image.fromarray(np.stack([edges] * 3, axis=-1))

    # Generate an image guided by the prompt and constrained by the edge map
    generated_image = pipeline(
        prompt=prompt,
        image=canny_image,
        controlnet_conditioning_scale=1.0,
        guidance_scale=7.5,
        num_inference_steps=50,
    ).images[0]
    return generated_image


# Set up Gradio interface
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Reference Image (Structure)")
    ],
    outputs="image",
    title="Image Generation with ControlNet (Canny Edge Conditioning)",
    description="Generates an image from a text prompt while following the edge structure of a reference image, using Stable Diffusion v1.5 with a Canny ControlNet."
)

# Launch the Gradio interface
interface.launch()
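# To expose the demo on a temporary public URL when running locally,
# launch with share=True instead: interface.launch(share=True)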