import os
import requests
import torch
import gradio as gr
import spaces
from PIL import Image
from huggingface_hub import login
from torchvision import transforms
from diffusers.utils import load_image
from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
# ----------------------------
# Step 1: Download IP Adapter if not exists
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"
if not os.path.exists(file_path):
    print("File not found, downloading...")
    response = requests.get(url, stream=True)
    response.raise_for_status()  # abort early if the download request was rejected
    with open(file_path, "wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)
    print("Download completed!")
# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
login(token=token)
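# Note: stabilityai/stable-diffusion-3.5-large is a gated model on the Hub, so the
# token must belong to an account that has accepted its license.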
# ----------------------------
# Step 3: Model Paths
# ----------------------------
model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
image_encoder_path = "google/siglip-so400m-patch14-384"
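# The SigLIP encoder above supplies the image features for the IP-Adapter; its
# 384x384 input resolution is what preprocess_image below resizes reference images to.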
# ----------------------------
# Step 4: Load Transformer and Pipeline
# ----------------------------
transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.float16
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.float16
).to("cuda")
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
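# init_ipadapter comes from the custom pipeline shipped with the
# InstantX/SD3.5-Large-IP-Adapter release (pipeline_stable_diffusion_3_ipa.py),
# not from stock diffusers; nb_token=64 is presumably the number of image tokens
# the adapter's resampler feeds into the transformer, matching the released checkpoint.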
# ----------------------------
# Step 5: Preprocess Reference Image
# ----------------------------
def preprocess_image(image_path):
    """Ensure the input image is a valid PIL Image and resize it."""
    image = Image.open(image_path).convert("RGB")
    # Ensure the image is resized into a square
    size = max(image.size)  # Get the largest dimension
    image = image.resize((size, size), Image.BILINEAR)
    preprocess = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float16)
    ])
    return preprocess(image).unsqueeze(0).to("cuda")
# ----------------------------
# Step 6: Gradio Function
# ----------------------------
# @spaces.GPU allocates a GPU for the duration of each call when running on ZeroGPU hardware.
@spaces.GPU
def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
    """Generate an image using Stable Diffusion 3.5 Large with IP-Adapter."""
    try:
        # gr.File may return either a filepath string (Gradio 4.x) or a tempfile
        # wrapper exposing .name (Gradio 3.x); accept both.
        ref_path = ref_img if isinstance(ref_img, str) else ref_img.name
        ref_img_tensor = preprocess_image(ref_path)
    except Exception as e:
        raise ValueError(f"Error loading reference image: {e}") from e
    # Run the pipeline
    with torch.no_grad():
        image = pipe(
            width=1024,
            height=1024,
            prompt=prompt,
            negative_prompt="lowres, low quality, worst quality",
            num_inference_steps=24,
            guidance_scale=guidance_scale,
            generator=torch.Generator("cuda").manual_seed(42),
            clip_image=ref_img_tensor,
            ipadapter_scale=ipadapter_scale,
        ).images[0]
    return image
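# Note: clip_image and ipadapter_scale are keyword arguments added by the custom
# IP-Adapter pipeline; they are not part of the stock diffusers StableDiffusion3Pipeline.
# The generator seed is fixed at 42, so identical inputs reproduce identical outputs.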
# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
ref_img = gr.File(label="Upload Reference Image")
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt"
)
ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt"
)
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter."
)
# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
interface.launch()