import streamlit as st
import torch
import subprocess
import os
from PIL import Image
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
# ... any other imports you need ...

# ------------------------------------------------------------------------------
# 1. Load your models ONCE in global scope (so they don't reload on every run).
# ------------------------------------------------------------------------------

@st.cache_resource
def load_sd_pipeline(base_model_path: str, fine_tuned_path: str):
    # Disable the NSFW safety checker for this demo: return the images
    # unchanged along with a per-image "not NSFW" flag. (A plain False here
    # breaks newer diffusers versions that iterate over the flags.)
    def dummy_safety_checker(images, clip_input):
        return images, [False] * len(images)

    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16
    )
    pipe.to("cuda")

    # Load the fine-tuned UNet
    unet = UNet2DConditionModel.from_pretrained(
        fine_tuned_path,
        subfolder="unet",
        torch_dtype=torch.float16
    ).to("cuda")

    pipe.unet = unet
    pipe.safety_checker = dummy_safety_checker

    return pipe

# Similarly, if you want to load Zero123++ or other pipelines:
@st.cache_resource
def load_zero123_pipeline():
    from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        custom_pipeline="sudo-ai/zero123plus-pipeline",
        torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing="trailing"
    )
    pipeline.to("cuda")
    return pipeline


# Example placeholders for the SyncDreamer command or internal functions:
def run_syncdreamer(input_path: str, output_dir: str = "syncdreamer_output"):
    """Runs SyncDreamer on input_path and places results into output_dir."""
    st.info("Running SyncDreamer... (placeholder)")
    # This is where your actual command would go:
    # subprocess.run([...], check=True)
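    #
    # A hedged sketch of that command, kept commented out so this stays a
    # placeholder. The script name, checkpoint path, and flags below are
    # assumptions; check your SyncDreamer checkout's README for the real CLI:
    #
    # subprocess.run(
    #     [
    #         "python", "generate.py",
    #         "--ckpt", "ckpt/syncdreamer-pretrain.ckpt",  # hypothetical checkpoint path
    #         "--input", input_path,
    #         "--output", output_dir,
    #     ],
    #     check=True,
    # )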
    os.makedirs(output_dir, exist_ok=True)
    # (In a real scenario, you'd handle .jpg to .png conversion, etc.)
    st.success(f"SyncDreamer completed. Results in: {output_dir}")


# Helper function for Zero123++ pipeline
def make_square_min_dim(image: Image.Image, min_side: int = 320) -> Image.Image:
    """Upscale so both sides are at least min_side, then pad to a white square."""
    w, h = image.size
    scale = max(min_side / w, min_side / h, 1.0)
    new_w, new_h = int(w * scale), int(h * scale)
    image = image.resize((new_w, new_h), Image.LANCZOS)

    side = max(new_w, new_h)
    new_img = Image.new(mode="RGB", size=(side, side), color=(255, 255, 255))
    offset_x = (side - new_w) // 2
    offset_y = (side - new_h) // 2
    new_img.paste(image, (offset_x, offset_y))
    return new_img
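
# Quick, illustrative sanity check for make_square_min_dim (an assumed example,
# safe to delete):
#
#   img = Image.new("RGB", (200, 100))            # both sides below 320
#   out = make_square_min_dim(img, min_side=320)
#   assert out.size == (640, 640)                 # scale 3.2 -> 640x320, padded to a square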


# ------------------------------------------------------------------------------
# 2. Streamlit application.
# ------------------------------------------------------------------------------

def main():
    st.title("Funko Generator Demo")

    # Load the (cached) pipelines up front:
    base_model_path = "runwayml/stable-diffusion-v1-5"
    fine_tuned_path = "/content/drive/MyDrive/CC_Project/checkpoint-3000"  # adapt if needed
    sd_pipe = load_sd_pipeline(base_model_path, fine_tuned_path)

    zero123_pipe = load_zero123_pipeline()  # For multi-view generation

    # Session state holds the latest generated image and the original prompt:
    if "latest_image" not in st.session_state:
        st.session_state["latest_image"] = None
    if "original_prompt" not in st.session_state:
        st.session_state["original_prompt"] = ""

    # --------------------------------------------------------------------------
    # A) Prompt input & initial generation
    # --------------------------------------------------------------------------
    st.subheader("1. Enter your Funko prompt")

    # Show examples in the UI
    with st.expander("Examples of valid prompts"):
        st.write("""
        - A standing plain human Funko in a blue shirt and blue pants with round black eyes with glasses with a belt.
        - A sitting angry animal Funko with squint black eyes.
        - A standing happy robot Funko in a brown shirt and grey pants with squint black eyes with cane and monocle.
        - ...
        """)

    user_prompt = st.text_area("Type your Funko prompt here:", 
                               value="A standing plain human Funko in a blue shirt and blue pants with round black eyes with glasses.")
    generate_button = st.button("Generate Initial Funko")

    if generate_button:
        st.session_state["original_prompt"] = user_prompt
        with st.spinner("Generating image..."):
            with torch.autocast("cuda"):
                image = sd_pipe(user_prompt, num_inference_steps=50).images[0]
            st.session_state["latest_image"] = image
        
        st.success("Image generated!")

    if st.session_state["latest_image"] is not None:
        st.image(st.session_state["latest_image"], caption="Latest Generated Image", use_column_width=True)

    # --------------------------------------------------------------------------
    # B) Change the Funko (attributes)
    # --------------------------------------------------------------------------
    st.subheader("2. Modify Funko Attributes")
    st.write("Select new attributes below. If you choose 'none', that attribute will be ignored/omitted in the prompt.")

    # Possible attribute values; choosing 'none' omits or defaults that attribute.
    characters = ['none', 'animal', 'human', 'robot']
    eyes_shape = ['none', 'anime', 'black', 'closed', 'round', 'square', 'squint']
    eyes_color = ['none', 'black', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    eyewear = ['none', 'eyepatch', 'glasses', 'goggles', 'helmet', 'mask', 'sunglasses']
    hair_color = ['none', 'black', 'blonde', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    emotion = ['none', 'angry', 'happy', 'plain', 'sad']
    shirt_color = ['none', 'black', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    pants_color = ['none', 'black', 'blue', 'brown', 'green', 'grey', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    accessories = ['none', 'bag', 'ball', 'belt', 'bird', 'book', 'cape', 'guitar', 'hat', 'helmet', 'sword', 'wand', 'wings']
    pose = ['none', 'sitting', 'standing']

    # Create selection widgets:
    chosen_char = st.selectbox("Character:", characters)
    chosen_eyes_shape = st.selectbox("Eyes Shape:", eyes_shape)
    chosen_eyes_color = st.selectbox("Eyes Color:", eyes_color)
    chosen_eyewear = st.selectbox("Eyewear:", eyewear)
    chosen_hair_color = st.selectbox("Hair Color:", hair_color)
    chosen_emotion = st.selectbox("Emotion:", emotion)
    chosen_shirt_color = st.selectbox("Shirt Color:", shirt_color)
    chosen_pants_color = st.selectbox("Pants Color:", pants_color)
    chosen_accessories = st.selectbox("Accessories:", accessories)
    chosen_pose = st.selectbox("Pose:", pose)

    # Build a modified prompt from the selections. For simplicity we construct
    # a fresh prompt from the chosen attributes, falling back to a default for
    # each attribute left at 'none'. A more sophisticated version could parse
    # the original prompt and replace only the changed attributes.
    def modify_prompt(base_prompt: str):
        # base_prompt is currently unused; kept for a future parse-and-replace version.
        new_prompt_segments = []

        # Pose
        if chosen_pose != 'none':
            new_prompt_segments.append(f"A {chosen_pose}")
        else:
            new_prompt_segments.append("A standing")  # default fallback

        # Emotion + Character
        if chosen_emotion != 'none':
            new_prompt_segments.append(chosen_emotion)
        else:
            new_prompt_segments.append("plain")  # fallback

        if chosen_char != 'none':
            new_prompt_segments.append(chosen_char + " Funko")
        else:
            new_prompt_segments.append("human Funko")

        # Shirt color
        if chosen_shirt_color != 'none':
            new_prompt_segments.append(f"in a {chosen_shirt_color} shirt")
        else:
            new_prompt_segments.append("in a blue shirt")

        # Pants color
        if chosen_pants_color != 'none':
            new_prompt_segments.append(f"and {chosen_pants_color} pants")
        else:
            new_prompt_segments.append("and blue pants")

        # Eyes (shape + color, with defaults)
        eye_text = [
            chosen_eyes_shape if chosen_eyes_shape != 'none' else "round",
            chosen_eyes_color if chosen_eyes_color != 'none' else "black",
            "eyes",
        ]
        new_prompt_segments.append("with " + " ".join(eye_text))

        # Eyewear
        if chosen_eyewear != 'none':
            new_prompt_segments.append(f"with {chosen_eyewear}")

        # Hair
        if chosen_hair_color != 'none':
            new_prompt_segments.append(f"with {chosen_hair_color} hair")

        # Accessories
        if chosen_accessories != 'none':
            new_prompt_segments.append(f"with a {chosen_accessories}")

        return " ".join(new_prompt_segments) + "."

    if st.button("Generate Modified Funko"):
        if not st.session_state["original_prompt"]:
            st.warning("Please generate an initial Funko (step 1) before modifying it.")
        else:
            new_prompt = modify_prompt(st.session_state["original_prompt"])
            st.write(f"**New Prompt**: {new_prompt}")

            with st.spinner("Generating modified image..."):
                with torch.autocast("cuda"):
                    image = sd_pipe(new_prompt, num_inference_steps=50).images[0]
                st.session_state["latest_image"] = image

            st.image(st.session_state["latest_image"], caption="Modified Image", use_column_width=True)

    # --------------------------------------------------------------------------
    # C) Animate the Funko with SyncDreamer
    # --------------------------------------------------------------------------
    st.subheader("3. Animate the Funko (SyncDreamer)")
    st.write("Click the button to run SyncDreamer on the last generated image. (Demo)")

    if st.button("Animate with SyncDreamer"):
        if st.session_state["latest_image"] is None:
            st.warning("No image found. Please generate a Funko first.")
        else:
            # Save latest image locally so SyncDreamer can process it
            input_path = "latest_funko.png"
            st.session_state["latest_image"].save(input_path)
            run_syncdreamer(input_path, output_dir="syncdreamer_output")

            # Optionally display a placeholder or actual frames/GIF
            # ...
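            #
            # Hedged sketch, commented out to keep this a placeholder: if
            # SyncDreamer wrote per-view PNG frames into output_dir (the file
            # layout is an assumption), they could be stitched into a
            # turntable GIF with PIL:
            #
            # frames = [Image.open(os.path.join("syncdreamer_output", f))
            #           for f in sorted(os.listdir("syncdreamer_output"))
            #           if f.endswith(".png")]
            # if frames:
            #     frames[0].save("funko_turntable.gif", save_all=True,
            #                    append_images=frames[1:], duration=100, loop=0)
            #     st.image("funko_turntable.gif", caption="SyncDreamer turntable")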
            st.success("SyncDreamer animation completed (placeholder).")

    # --------------------------------------------------------------------------
    # D) Multi-View 3D Funko (Zero123++)
    # --------------------------------------------------------------------------
    st.subheader("4. Generate Multi-View 3D Funko (Zero123++)")

    if st.button("Generate Multi-View 3D"):
        if st.session_state["latest_image"] is None:
            st.warning("No image found. Please generate a Funko first.")
        else:
            # Save the last image as input for Zero123
            input_path = "funko_for_zero123.png"
            st.session_state["latest_image"].save(input_path)

            # Make sure image is at least 320x320 and square
            original_img = Image.open(input_path).convert("RGB")
            cond = make_square_min_dim(original_img, min_side=320)

            # Inference
            st.info("Running Zero123++ pipeline... Please wait.")
            with torch.autocast("cuda"):
                result_grid = zero123_pipe(cond, num_inference_steps=50).images[0]

            result_grid.save("zero123_grid.png")
            st.image(result_grid, caption="Zero123++ Multi-View Grid (640x960)")

            # Optionally crop and display sub-views
            # Here we crop 6 sub-images of 320x320 from the 640x960 grid:
            coords = [
                (0,   0,   320, 320),
                (320, 0,   640, 320),
                (0,   320, 320, 640),
                (320, 320, 640, 640),
                (0,   640, 320, 960),
                (320, 640, 640, 960),
            ]
            st.write("### Generated Views:")
            for i, (x1, y1, x2, y2) in enumerate(coords):
                sub_img = result_grid.crop((x1, y1, x2, y2))
                sub_path = f"zero123_view_{i}.png"
                sub_img.save(sub_path)
                st.image(sub_path, width=256)

    # --------------------------------------------------------------------------
    # E) Integrate a New Background
    # --------------------------------------------------------------------------
    st.subheader("5. Apply a New Background to Each View")

    st.write("Upload a background image, then apply it to each previously generated view.")
    bg_file = st.file_uploader("Upload Background Image", type=["png", "jpg", "jpeg"])
    if bg_file is not None:
        st.image(bg_file, caption="Selected Background", width=200)

    if st.button("Apply Background to Multi-View"):
        if bg_file is None:
            st.warning("No background uploaded.")
        else:
            # Save background to disk:
            bg_path = "background.png"
            with open(bg_path, "wb") as f:
                f.write(bg_file.getvalue())  # getvalue() works even if the buffer was read for the preview

            # In a real implementation, you would composite each Funko view onto
            # the background with threshold-based masking, as in the hedged
            # sketch below. For this demo we only report success.
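            #
            # Hedged compositing sketch, commented out (assumes the renders have
            # near-white backgrounds; the 240 threshold is a guess to tune):
            #
            # import numpy as np
            # bg = Image.open(bg_path).convert("RGB").resize((320, 320))
            # for i in range(6):
            #     view = Image.open(f"zero123_view_{i}.png").convert("RGB")
            #     arr = np.array(view)
            #     mask = (arr > 240).all(axis=-1)  # near-white pixels = background
            #     out = np.where(mask[..., None], np.array(bg), arr).astype("uint8")
            #     Image.fromarray(out).save(f"composited_view_{i}.png")
            #     st.image(f"composited_view_{i}.png", width=256)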
            st.success("Background compositing placeholder done. Check your images in the output folder.")

    st.write("End of the Demo. Adjust code as needed for your pipeline paths and logic.")


if __name__ == "__main__":
    main()