Spaces:
Build error
Build error
import streamlit as st | |
import os | |
import random | |
import subprocess | |
import io | |
import numpy as np | |
from PIL import Image | |
import torch | |
from diffusers import StableDiffusionPipeline, UNet2DConditionModel | |
from torchvision import transforms | |
# If you're using Zero123++: | |
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler | |
# ------------------------------------------------------------------------------ | |
# 0. GLOBAL CONFIG & UTILS | |
# ------------------------------------------------------------------------------ | |
# Base Stable Diffusion checkpoint, plus the location of the fine-tuned UNet
# that replaces the base UNet (loaded from its "unet" subfolder).
# (In a HF Space, you might store them in a local folder or load from HF repos.)
BASE_MODEL_PATH = "runwayml/stable-diffusion-v1-5"
FINE_TUNED_PATH = "my_finetuned_unet"  # e.g., local folder or HF Hub ID
# Zero123++ model identifier handed to DiffusionPipeline.from_pretrained
# (presumably a path to a local clone works here as well -- TODO confirm):
ZERO123_MODEL_ID = "sudo-ai/zero123plus-v1.2"
# Example safety checker dummy, as used in your snippet: | |
def dummy_safety_checker(images, clip_input):
    """No-op stand-in for the Stable Diffusion safety checker.

    Args:
        images: Batch of generated images, passed through untouched.
        clip_input: Ignored; accepted only to match the checker call signature.

    Returns:
        A ``(images, flags)`` tuple where ``flags`` holds one ``False`` per
        image, i.e. nothing is ever reported as NSFW.
    """
    # Recent diffusers releases iterate over the flags (one entry per image),
    # so a bare ``False`` raises "TypeError: 'bool' object is not iterable".
    return images, [False] * len(images)
# Make sure to remove or comment out any "!pip install ..." lines and rely | |
# on your requirements.txt in the environment. | |
# ------------------------------------------------------------------------------ | |
# 1. LOAD MODELS & PIPELINES | |
# ------------------------------------------------------------------------------ | |
def load_sd_pipeline():
    """Build the Stable Diffusion pipeline around the fine-tuned UNet.

    The fine-tuned UNet is read from the ``unet`` subfolder of
    ``FINE_TUNED_PATH`` and handed to the pipeline constructor, so the base
    checkpoint's own UNet is never loaded or moved to the GPU only to be
    thrown away.

    Returns:
        A ``StableDiffusionPipeline`` in float16 on the ``cuda`` device, with
        the safety checker replaced by ``dummy_safety_checker``.
    """
    unet = UNet2DConditionModel.from_pretrained(
        FINE_TUNED_PATH,
        subfolder="unet",
        torch_dtype=torch.float16,
    )
    # Components passed as keyword arguments override the ones stored in the
    # base checkpoint.
    pipe = StableDiffusionPipeline.from_pretrained(
        BASE_MODEL_PATH,
        unet=unet,
        torch_dtype=torch.float16,
    )
    # Moves every registered component (including the swapped-in UNet) to GPU.
    pipe.to("cuda")
    pipe.safety_checker = dummy_safety_checker
    return pipe
def load_zero123_pipeline():
    """Create the Zero123++ pipeline and move it to the GPU.

    The pipeline is loaded in float16 from ``ZERO123_MODEL_ID`` using the
    ``sudo-ai/zero123plus-pipeline`` custom pipeline, and its scheduler is
    swapped for an ``EulerAncestralDiscreteScheduler`` built from the existing
    scheduler config with ``timestep_spacing='trailing'``.
    """
    zero123 = DiffusionPipeline.from_pretrained(
        ZERO123_MODEL_ID,
        torch_dtype=torch.float16,
        custom_pipeline="sudo-ai/zero123plus-pipeline",
    )
    # Reuse the shipped scheduler settings; only the timestep spacing changes.
    base_config = zero123.scheduler.config
    zero123.scheduler = EulerAncestralDiscreteScheduler.from_config(
        base_config,
        timestep_spacing="trailing",
    )
    zero123.to("cuda")
    return zero123
# ------------------------------------------------------------------------------ | |
# 2. HELPER FUNCTIONS | |
# ------------------------------------------------------------------------------ | |
def generate_funko_image(pipe, prompt: str, steps: int = 50):
    """Run ``pipe`` on ``prompt`` and return the first generated image.

    Args:
        pipe: Callable pipeline; its result must expose an ``images`` sequence.
        prompt: Text prompt forwarded to the pipeline.
        steps: Value passed as ``num_inference_steps``.
    """
    # The whole call runs under CUDA autocast.
    with torch.autocast("cuda"):
        output = pipe(prompt, num_inference_steps=steps)
        first_image = output.images[0]
    return first_image
def run_syncdreamer(input_path: str, output_dir: str):
    """Stand-in for a real SyncDreamer command-line run.

    Nothing is generated here: the function reports progress in the Streamlit
    UI and makes sure ``output_dir`` exists. ``input_path`` is currently
    unused. A real implementation would build a command along these lines and
    execute it:

        syncdreamer_cmd = [
            "python", "generate.py",
            "--ckpt", "ckpt/syncdreamer-pretrain.ckpt",
            "--input", input_path,
            "--output", output_dir,
            "--sample_num", "4",
            "--cfg_scale", "2.0",
            ...
        ]
        subprocess.run(syncdreamer_cmd, check=True)
    """
    st.info("Running SyncDreamer... (this is a placeholder call)")
    # Only the output folder is prepared; the actual subprocess call sketched
    # in the docstring belongs here.
    os.makedirs(output_dir, exist_ok=True)
    done_message = f"SyncDreamer completed. Output in: {output_dir}"
    st.success(done_message)
def make_square_min_dim(image: Image.Image, min_side: int = 320) -> Image.Image:
    """Upscale ``image`` to at least ``min_side`` per side and pad it to a square.

    The image is never shrunk (the scale factor is at least 1.0). After
    resizing with Lanczos resampling it is centred on a white RGB canvas whose
    side equals the larger of the two resized dimensions.

    Args:
        image: Source image.
        min_side: Minimum width and height of the resized image, in pixels.

    Returns:
        A new square RGB image with a white background.
    """
    w, h = image.size
    scale = max(min_side / w, min_side / h, 1.0)
    # int() truncates, and w * (min_side / w) can land a hair below min_side
    # in floating point; clamp so the documented minimum always holds.
    new_w = max(min_side, int(w * scale))
    new_h = max(min_side, int(h * scale))
    image = image.resize((new_w, new_h), Image.LANCZOS)

    side = max(new_w, new_h)
    new_img = Image.new(mode="RGB", size=(side, side), color=(255, 255, 255))
    offset = ((side - new_w) // 2, (side - new_h) // 2)
    # Use the alpha channel as paste mask so transparent pixels end up white
    # instead of exposing whatever RGB values sit underneath them.
    alpha_mask = image if image.mode == "RGBA" else None
    new_img.paste(image, offset, mask=alpha_mask)
    return new_img
def run_zero123(pipeline, input_image: Image.Image, steps: int = 50):
    """Run the Zero123++ pipeline on ``input_image`` and return its first output.

    The input is first squared and padded via ``make_square_min_dim`` with
    ``min_side=320``; the pipeline call itself runs under CUDA autocast.
    The returned image is the multi-view grid (640x960 per the original
    author's note).
    """
    conditioning = make_square_min_dim(input_image, min_side=320)
    with torch.autocast("cuda"):
        output = pipeline(conditioning, num_inference_steps=steps)
        grid = output.images[0]
    return grid
def crop_zero123_grid(grid_img: "Image.Image", rows: int = 3, cols: int = 2):
    """Split a multi-view grid image into its individual tiles.

    With the defaults (3 rows x 2 columns) a 640x960 grid yields six 320x320
    sub-images, exactly as before. The tile size is derived from the grid's
    actual dimensions, so grids of other resolutions are cropped correctly
    instead of being cut at fixed 320-pixel offsets.

    Args:
        grid_img: Grid image exposing ``size`` and ``crop``.
        rows: Number of tile rows in the grid.
        cols: Number of tile columns in the grid.

    Returns:
        List of cropped tiles in row-major order (left to right, top to bottom).

    Raises:
        ValueError: If ``rows`` or ``cols`` is smaller than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError("rows and cols must both be at least 1")

    grid_w, grid_h = grid_img.size
    tile_w = grid_w // cols
    tile_h = grid_h // rows

    sub_images = []
    for row in range(rows):
        top = row * tile_h
        for col in range(cols):
            left = col * tile_w
            sub_images.append(
                grid_img.crop((left, top, left + tile_w, top + tile_h))
            )
    return sub_images
# Example background compositing if desired: | |
def create_mask(image, bg_color=(255,255,255), threshold=30):
    """Build an 8-bit mask separating foreground from a flat background colour.

    A pixel counts as foreground (255) when its largest per-channel absolute
    difference from ``bg_color`` is greater than ``threshold``; every other
    pixel is background (0). The channel axis is expected at index 2.
    """
    pixels = np.array(image)
    background = np.array(bg_color)
    # Largest deviation from the background colour across the channel axis.
    channel_distance = np.abs(pixels - background).max(axis=2)
    foreground = np.where(channel_distance > threshold, 255, 0).astype(np.uint8)
    return Image.fromarray(foreground, mode="L")
def composite_foreground_background(fg, bg, bg_color=(255,255,255), threshold=30):
    """Place the foreground of ``fg`` on top of ``bg``.

    Both images are converted to RGBA and ``bg`` is resized to the size of
    ``fg``. Pixels of ``fg`` that ``create_mask`` marks as foreground (relative
    to ``bg_color`` and ``threshold``) are kept; all others come from ``bg``.
    """
    foreground = fg.convert("RGBA")
    backdrop = bg.convert("RGBA").resize(foreground.size)
    keep_mask = create_mask(
        foreground.convert("RGB"),
        bg_color=bg_color,
        threshold=threshold,
    )
    return Image.composite(foreground, backdrop, keep_mask)
def get_bg_color(image):
    """Guess the flat background colour of ``image`` from its top-left pixel.

    Heuristic: if the mean of the corner pixel's colour channels is above 240
    the background is treated as white, otherwise as light grey.

    Only colour channels are averaged: a trailing alpha channel of an RGBA
    pixel is ignored (it previously inflated the mean, so opaque grey was
    reported as white), and single-value grayscale pixels are accepted instead
    of raising ``TypeError``.

    Returns:
        ``(255, 255, 255)`` or ``(200, 200, 200)``.
    """
    corner_pixel = image.getpixel((0, 0))
    if isinstance(corner_pixel, (int, float)):
        # Single-band images return a bare number rather than a tuple.
        channels = (corner_pixel,)
    elif len(corner_pixel) >= 3:
        channels = corner_pixel[:3]
    else:
        # Two-band pixels (luminance + alpha): keep the luminance only.
        channels = corner_pixel[:1]

    brightness = sum(channels) / len(channels)
    if brightness > 240:
        return (255, 255, 255)
    return (200, 200, 200)
# ------------------------------------------------------------------------------ | |
# 3. STREAMLIT UI | |
# ------------------------------------------------------------------------------ | |
def main():
    """Render the Streamlit UI and wire its buttons to the pipelines.

    Sections: (1) initial prompt and generation, (2) attribute-based
    regeneration, (3) placeholder SyncDreamer run, (4) Zero123++ multi-view
    grid, (5) compositing an uploaded background behind each view.

    Session state keys:
        latest_image: Most recently generated Funko image, or None.
        original_prompt: Prompt used for the initial generation ("" if none).
        zero123_grid: Grid produced by the multi-view step for the current
            Funko, or None.
    """
    st.title("Funko Generator (SD + SyncDreamer + Zero123)")

    # Streamlit re-executes the script on every widget interaction, so plain
    # calls would reload both models each time. st.cache_resource keeps one
    # loaded instance per loader function across reruns.
    sd_pipe = st.cache_resource(load_sd_pipeline)()
    zero123_pipe = st.cache_resource(load_zero123_pipeline)()

    # Session state to store images
    if "latest_image" not in st.session_state:
        st.session_state["latest_image"] = None
    if "original_prompt" not in st.session_state:
        st.session_state["original_prompt"] = ""
    if "zero123_grid" not in st.session_state:
        st.session_state["zero123_grid"] = None

    # ---------------------------
    # A) Prompt Input
    # ---------------------------
    st.subheader("1. Enter your initial Funko prompt")
    with st.expander("Prompt Examples"):
        st.write(
            "\n"
            "- A standing plain human Funko in a blue shirt and blue pants with round black eyes with glasses with a belt.\n"
            "- A sitting angry animal Funko with squint black eyes.\n"
            "- A standing happy robot Funko in a brown shirt and grey pants with squint black eyes with cane and monocle.\n"
            "- ...\n"
        )
    user_prompt = st.text_area(
        "Type your Funko prompt here:",
        value="A standing plain human Funko in a blue shirt and blue pants with round black eyes with glasses.",
    )
    generate_initial = st.button("Generate Initial Funko")
    if generate_initial:
        st.session_state["original_prompt"] = user_prompt
        with st.spinner("Generating initial Funko..."):
            out_img = generate_funko_image(sd_pipe, user_prompt, steps=50)
        st.session_state["latest_image"] = out_img
        # Any earlier multi-view grid belongs to the previous Funko.
        st.session_state["zero123_grid"] = None
        st.success("Image generated!")
    if st.session_state["latest_image"] is not None:
        st.image(st.session_state["latest_image"], caption="Latest Funko Image", use_column_width=True)

    # ---------------------------
    # B) Modify Funko Attributes
    # ---------------------------
    st.subheader("2. Modify the Funko (attributes)")
    st.write("Pick new attributes. If you choose 'none', we won't override that attribute.")
    characters = ['none', 'animal', 'human', 'robot']
    eyes_shapes = ['none', 'anime', 'black', 'closed', 'round', 'square', 'squint']
    eyes_colors = ['none', 'black', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    eyewears = ['none', 'eyepatch', 'glasses', 'goggles', 'helmet', 'mask', 'sunglasses']
    hair_colors = ['none', 'black', 'blonde', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    emotions = ['none', 'angry', 'happy', 'plain', 'sad']
    shirt_colors = ['none', 'black', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    pants_colors = ['none', 'black', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    accessories = ['none', 'bag', 'ball', 'belt', 'bird', 'book', 'cape', 'guitar', 'hat', 'helmet', 'sword', 'wand', 'wings']
    poses = ['none', 'sitting', 'standing']
    chosen_char = st.selectbox("Character", characters)
    chosen_eyes_shape = st.selectbox("Eyes Shape", eyes_shapes)
    chosen_eyes_color = st.selectbox("Eyes Color", eyes_colors)
    chosen_eyewear = st.selectbox("Eyewear", eyewears)
    chosen_hair_color = st.selectbox("Hair Color", hair_colors)
    chosen_emotion = st.selectbox("Emotion", emotions)
    chosen_shirt_color = st.selectbox("Shirt Color", shirt_colors)
    chosen_pants_color = st.selectbox("Pants Color", pants_colors)
    chosen_accessory = st.selectbox("Accessories", accessories)
    chosen_pose = st.selectbox("Pose", poses)

    if st.button("Generate Modified Funko"):
        if st.session_state["original_prompt"] == "":
            st.warning("Please generate an initial Funko first.")
        else:
            new_prompt = _build_modified_prompt(
                pose=chosen_pose,
                emotion=chosen_emotion,
                character=chosen_char,
                shirt_color=chosen_shirt_color,
                pants_color=chosen_pants_color,
                eyes_shape=chosen_eyes_shape,
                eyes_color=chosen_eyes_color,
                eyewear=chosen_eyewear,
                hair_color=chosen_hair_color,
                accessory=chosen_accessory,
            )
            st.write("**New Prompt**:", new_prompt)
            with st.spinner("Generating modified image..."):
                out_img = generate_funko_image(sd_pipe, new_prompt, steps=50)
            st.session_state["latest_image"] = out_img
            # Any earlier multi-view grid belongs to the previous Funko.
            st.session_state["zero123_grid"] = None
            st.image(st.session_state["latest_image"], caption="Modified Funko", use_column_width=True)

    # ---------------------------
    # C) Animate with SyncDreamer
    # ---------------------------
    st.subheader("3. Animate the Funko with SyncDreamer")
    st.write("Click to run SyncDreamer on the last generated image (placeholder).")
    if st.button("Animate Funko"):
        if st.session_state["latest_image"] is None:
            st.warning("No image to animate. Generate a Funko first.")
        else:
            # Save the current image
            input_path = "latest_funko.png"
            st.session_state["latest_image"].save(input_path)
            output_dir = "syncdreamer_output"
            run_syncdreamer(input_path, output_dir=output_dir)
            st.success("SyncDreamer run complete (demo). Check output directory for results.")

    # ---------------------------
    # D) Multi-View with Zero123++
    # ---------------------------
    st.subheader("4. Generate Multi-View Funko (Zero123++)")
    if st.button("Generate Multi-View 3D"):
        if st.session_state["latest_image"] is None:
            st.warning("No image to process. Generate a Funko first.")
        else:
            # Save for Zero123
            zero123_input_path = "funko_for_zero123.png"
            st.session_state["latest_image"].save(zero123_input_path)
            with st.spinner("Running Zero123++..."):
                full_image = run_zero123(zero123_pipe, st.session_state["latest_image"], steps=50)
            # Keep the grid for the compositing step. Previously that step
            # looked for a "zero123_grid.png" file that nothing ever wrote,
            # so it could never run.
            st.session_state["zero123_grid"] = full_image
            # Display the 640x960 grid
            st.image(full_image, caption="Zero123++ Grid (640x960)", use_column_width=True)
            # Crop sub-images
            sub_images = crop_zero123_grid(full_image)
            st.write("Six sub-views:")
            for i, s_img in enumerate(sub_images):
                st.image(s_img, width=256, caption=f"View {i+1}")

    # ---------------------------
    # E) Background Compositing
    # ---------------------------
    st.subheader("5. Apply Background to Each View")
    bg_file = st.file_uploader("Upload a background image (PNG/JPG)", type=["png","jpg","jpeg"])
    if bg_file is not None:
        st.image(bg_file, caption="Your Background", width=200)
    if st.button("Composite Background onto Views"):
        stored_grid = st.session_state["zero123_grid"]
        if bg_file is None:
            st.warning("No background uploaded.")
        elif st.session_state["latest_image"] is None:
            st.warning("No Funko image found. Generate or do multi-view first.")
        elif stored_grid is None:
            st.warning("No Zero123++ grid found. Please run Zero123++ step first.")
        else:
            # Rewind in case an earlier read of the upload advanced the stream.
            bg_file.seek(0)
            bg = Image.open(bg_file).convert("RGBA")
            sub_images = crop_zero123_grid(stored_grid.convert("RGB"))
            # Composite each sub-image
            st.write("Applying background to each sub-view...")
            for i, fg_img in enumerate(sub_images):
                # Detect background color from Funko sub-view
                bg_color = get_bg_color(fg_img)
                comp = composite_foreground_background(fg_img, bg, bg_color=bg_color, threshold=30)
                st.image(comp, width=256, caption=f"Composite View {i+1}")

    st.write("---")
    st.write("End of the demo. Adapt paths and code to your environment as needed.")


def _build_modified_prompt(
    *,
    pose,
    emotion,
    character,
    shirt_color,
    pants_color,
    eyes_shape,
    eyes_color,
    eyewear,
    hair_color,
    accessory,
):
    """Assemble a Funko prompt from the attribute selections.

    A value of ``'none'`` means "not chosen": pose, emotion, character, shirt,
    pants and eye attributes then fall back to a fixed default, while eyewear,
    hair colour and accessory are simply left out of the prompt.
    """
    def pick(choice, fallback):
        return fallback if choice == "none" else choice

    tokens = [
        f"A {pick(pose, 'standing')}",
        pick(emotion, "plain"),
        f"{pick(character, 'human')} Funko",
        f"in a {pick(shirt_color, 'blue')} shirt",
        f"and {pick(pants_color, 'blue')} pants",
        f"with {pick(eyes_shape, 'round')} {pick(eyes_color, 'black')} eyes",
    ]
    if eyewear != "none":
        tokens.append(f"with {eyewear}")
    if hair_color != "none":
        tokens.append(f"with {hair_color} hair")
    if accessory != "none":
        tokens.append(f"with a {accessory}")
    return " ".join(tokens) + "."
# ------------------------------------------------------------------------------ | |
# 4. ENTRY POINT | |
# ------------------------------------------------------------------------------ | |
# Build the UI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()