# Gradio Space: Stable Diffusion + ControlNet image generation.
import gradio as gr
import torch
from diffusers import (
StableDiffusionControlNetPipeline,
ControlNetModel,
UNet2DConditionModel,
AutoencoderKL,
UniPCMultistepScheduler,
)
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from huggingface_hub import login
import os
# --- Hugging Face authentication --------------------------------------------
# Log in with a token from the environment. Skip login entirely when no token
# is set, so the app can still start with the public model checkpoints
# (login(token=None) would otherwise fail or try an interactive prompt).
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)

# --- Model selection ---------------------------------------------------------
model_id = "runwayml/stable-diffusion-v1-5"  # Known compatible model with ControlNet
controlnet_id = "lllyasviel/sd-controlnet-canny"  # ControlNet model for edge detection

# --- Pipeline construction ---------------------------------------------------
# Load ControlNet weights in fp16 to roughly halve memory use.
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

# UniPC is a faster multistep scheduler; reuse the current scheduler's config.
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)

# Offload submodules to CPU while idle to reduce GPU memory pressure.
pipeline.enable_model_cpu_offload()
# Gradio interface function
def generate_image(prompt, reference_image, controlnet_scale=1.0,
                   guidance_scale=7.5, num_inference_steps=50):
    """Generate an image from a text prompt conditioned on a reference image.

    Parameters
    ----------
    prompt : str
        Text description of the desired image.
    reference_image : PIL.Image.Image
        Conditioning image; converted to RGB and resized to 512x512.
        NOTE(review): the loaded ControlNet is the Canny variant, which
        normally expects an edge map rather than a raw photo — confirm
        whether an edge-detection preprocessing step was intended.
    controlnet_scale : float, optional
        Strength of the ControlNet conditioning (default 1.0).
    guidance_scale : float, optional
        Classifier-free guidance strength (default 7.5).
    num_inference_steps : int, optional
        Number of denoising steps (default 50).

    Returns
    -------
    PIL.Image.Image
        The first generated image from the pipeline output.
    """
    # SD v1.5 + ControlNet operate at 512x512; normalize mode and size first.
    reference_image = reference_image.convert("RGB").resize((512, 512))
    # Run the ControlNet-conditioned diffusion pipeline.
    result = pipeline(
        prompt=prompt,
        image=reference_image,
        controlnet_conditioning_scale=controlnet_scale,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    return result.images[0]
# Set up Gradio interface
# --- Gradio UI ---------------------------------------------------------------
# Two inputs (free-text prompt + style reference image), one image output.
prompt_box = gr.Textbox(label="Prompt")
style_image = gr.Image(type="pil", label="Reference Image (Style)")

interface = gr.Interface(
    fn=generate_image,
    inputs=[prompt_box, style_image],
    outputs="image",
    title="Image Generation with ControlNet (Reference-Only Style Transfer)",
    description="Generates an image based on a text prompt and style reference image using Stable Diffusion and ControlNet (reference-only mode).",
)

# Start the web app (blocks until the server is stopped).
interface.launch()