Spaces:
Paused
Paused
File size: 1,754 Bytes
683afc3 121ee3d c1497a6 0737dc8 74c4e79 121ee3d 4fbc46c c1497a6 683afc3 121ee3d 683afc3 121ee3d 52d3f89 74c4e79 0737dc8 74c4e79 8d2ed6a 121ee3d c1497a6 0737dc8 121ee3d 52d3f89 683afc3 8d2ed6a 683afc3 8d2ed6a 683afc3 8d2ed6a 683afc3 8d2ed6a 121ee3d 683afc3 8d2ed6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from huggingface_hub import login
import os
import spaces
# Authenticate with the Hugging Face Hub only when a token is configured.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered in a headless environment such as a Space.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)

# Model IDs for Stable Diffusion 1.5 and ControlNet.
# NOTE: the original "runwayml/stable-diffusion-v1-5" repository was removed
# from the Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the
# maintained mirror of the same weights.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # Compatible with ControlNet
# NOTE(review): this is an *inpainting* ControlNet, but the app feeds it an
# unmasked reference image as the control input; confirm this checkpoint is
# the intended one for "style reference" conditioning.
controlnet_id = "lllyasviel/control_v11p_sd15_inpaint"

# Load the ControlNet weights and wrap them in an SD 1.5 ControlNet pipeline,
# both in float16.
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id, controlnet=controlnet, torch_dtype=torch.float16
)
# Replace the default scheduler with UniPC, reusing the existing scheduler config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Move all pipeline components to the GPU.
pipe = pipe.to("cuda")
@spaces.GPU
def generate_image(prompt, reference_image):
    """Generate an image from a text prompt, conditioned on a reference image.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        reference_image: PIL image from the ``gr.Image(type="pil")`` input,
            or ``None`` when the user has not uploaded an image.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no reference image was provided.
    """
    # Gradio passes None for an empty image input. Without this guard the
    # .convert() call below fails with an opaque AttributeError instead of
    # showing the user a readable message.
    if reference_image is None:
        raise gr.Error("Please upload a reference image.")

    # Normalise the control image to RGB at 512x512 (aspect ratio is not
    # preserved by this resize).
    control_image = reference_image.convert("RGB").resize((512, 512))

    # Run the ControlNet pipeline with fixed sampling settings.
    result = pipe(
        prompt=prompt,
        image=control_image,
        controlnet_conditioning_scale=1.0,
        guidance_scale=7.5,
        num_inference_steps=50,
    )
    return result.images[0]
# Build the Gradio UI: a text prompt and a reference image go in, and a single
# generated image comes out.
_prompt_input = gr.Textbox(label="Prompt")
_reference_input = gr.Image(type="pil", label="Reference Image (Style)")

interface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_input, _reference_input],
    outputs="image",
    title="Image Generation with Stable Diffusion 1.5 and ControlNet",
    description=(
        "Generates an image based on a text prompt and a reference image "
        "using Stable Diffusion 1.5 with ControlNet."
    ),
)

# Launch the Gradio app.
interface.launch()
|