Commit · 05d00b7
1 Parent(s): ac36933
init

This view is limited to 50 files because the commit contains too many changes.
- DockerFile → Dockerfile +2 -0
- README.md +4 -22
- app.py +10 -74
- app_3d.py +0 -21
- app_canny.py +0 -83
- app_matnet.py +0 -83
- app_sd.py +0 -154
- app_texnet.py +0 -259
- cv_utils.py +0 -17
- depth_estimator.py +0 -25
- examples/bunny/frame_0001.png +0 -3
- examples/bunny/mesh.obj +0 -0
- examples/bunny/uv_normal.png +0 -3
- examples/fighter/frame_0001.png +0 -3
- examples/fighter/mesh.obj +0 -0
- examples/fighter/uv_normal.png +0 -3
- examples/highheel/frame_0001.png +0 -3
- examples/highheel/mesh.obj +0 -0
- examples/highheel/uv_normal.png +0 -3
- examples/monkey/frame_0001.png +0 -3
- examples/monkey/mesh.obj +0 -0
- examples/monkey/uv_normal.png +0 -3
- examples/tank/frame_0001.png +0 -3
- examples/tank/mesh.obj +0 -3
- examples/tank/uv_normal.png +0 -3
- examples/tshirt/frame_0001.png +0 -3
- examples/tshirt/mesh.obj +0 -3
- examples/tshirt/uv_normal.png +0 -3
- image_segmentor.py +0 -33
- install.sh +0 -18
- model.py +0 -959
- pre-requirements.txt +0 -9
- preprocessor.py +0 -120
- push_dataset.py +0 -9
- requirements.txt +0 -9
- rgb2x/generate_blend.py +0 -142
- rgb2x/gradio_demo_rgb2x.py +0 -157
- rgb2x/load_image.py +0 -119
- rgb2x/pipeline_rgb2x.py +0 -821
- run.sh +5 -0
- settings.py +0 -23
- text2tex/lib/__init__.py +0 -0
- text2tex/lib/camera_helper.py +0 -231
- text2tex/lib/constants.py +0 -648
- text2tex/lib/diffusion_helper.py +0 -189
- text2tex/lib/io_helper.py +0 -78
- text2tex/lib/mesh_helper.py +0 -148
- text2tex/lib/projection_helper.py +0 -464
- text2tex/lib/render_helper.py +0 -108
- text2tex/lib/shading_helper.py +0 -45
DockerFile → Dockerfile
RENAMED
@@ -8,8 +8,10 @@ RUN conda env create -f /code/environment.yml
 
 # Set up a new user named "user" with user ID 1000
 RUN useradd -m -u 1000 user
+
 # Switch to the "user" user
 USER user
+
 # Set home to the user's home directory
 ENV HOME=/home/user \
     PYTHONPATH=$HOME/app \

README.md
CHANGED
@@ -1,28 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Gradio Conda Template
+emoji: 🐨
+colorFrom: blue
+colorTo: indigo
 sdk: docker
 pinned: false
-license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-
-
-## setup locally
-conda create -n matgen python=3.11
-conda activate matgen
-pip install diffusers["torch"] transformers accelerate xformers
-pip install gradio
-pip install controlnet-aux
-
-## local authen
-huggingface-cli login
-
-## on using Huggingface ZeroGPU
-need to import spaces and the corresponding decorator
-https://huggingface.co/docs/hub/spaces-zerogpu
-
-also, check the usage of controlnet over zerogpu here: https://huggingface.co/spaces/radames/Enhance-This-HiDiffusion-SDXL/blob/main/app.py

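The ZeroGPU notes removed above only say to "import spaces and the corresponding decorator". As a minimal, hedged sketch of that pattern (the pipeline, model id, and function name below are placeholders for illustration, not part of this commit), the decorator wraps the GPU-bound entry point so ZeroGPU attaches a GPU only while it runs:

import spaces  # available inside Hugging Face Spaces; see https://huggingface.co/docs/hub/spaces-zerogpu
import torch
from diffusers import StableDiffusionPipeline

# Placeholder model id; loading happens at startup, GPU is attached only inside the decorated call.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

@spaces.GPU  # ZeroGPU allocates a GPU only for the duration of generate()
def generate(prompt: str):
    return pipe(prompt).images[0]
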
app.py
CHANGED
@@ -1,80 +1,16 @@
-#!/usr/bin/env python
-
 import gradio as gr
-import torch
-
-import sys
-pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
-version_str="".join([
-    f"py3{sys.version_info.minor}_cu",
-    torch.version.cuda.replace(".",""),
-    f"_pyt{pyt_version_str}"
-])
-print(f"Using version: {version_str}") # used to locate pytorch3d version in the requirements.txt for huggingface
-
-
-from app_canny import create_demo as create_demo_canny
-from app_texnet import create_demo as create_demo_texnet
-
-from model import Model
-from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
 
-DESCRIPTION = "# Material Authoring Demo v0.3"
 
-
-
+def update(name):
+    return f"Welcome to Gradio, {name}!"
 
-# model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")
-model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="texnet")
 
 with gr.Blocks() as demo:
-    gr.Markdown(
-        gr.
-
-
-
-    )
-
-
-    with gr.Tab("Texnet+Matnet"):
-        create_demo_texnet(model.process_texnet)
-
-    with gr.Accordion(label="Base model", open=False):
-        with gr.Row():
-            with gr.Column(scale=5):
-                current_base_model = gr.Text(label="Current base model")
-            with gr.Column(scale=1):
-                check_base_model_button = gr.Button("Check current base model")
-        with gr.Row():
-            with gr.Column(scale=5):
-                new_base_model_id = gr.Text(
-                    label="New base model",
-                    max_lines=1,
-                    placeholder="stable-diffusion-v1-5/stable-diffusion-v1-5",
-                    info="The base model must be compatible with Stable Diffusion v1.5.",
-                    interactive=ALLOW_CHANGING_BASE_MODEL,
-                )
-            with gr.Column(scale=1):
-                change_base_model_button = gr.Button("Change base model", interactive=ALLOW_CHANGING_BASE_MODEL)
-        if not ALLOW_CHANGING_BASE_MODEL:
-            gr.Markdown(
-                """The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space."""
-            )
-
-    check_base_model_button.click(
-        fn=lambda: model.base_model_id,
-        outputs=current_base_model,
-        queue=False,
-        api_name="check_base_model",
-    )
-    gr.on(
-        triggers=[new_base_model_id.submit, change_base_model_button.click],
-        fn=model.set_base_model,
-        inputs=new_base_model_id,
-        outputs=current_base_model,
-        api_name=False,
-        concurrency_id="main",
-    )
-
-if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    gr.Markdown("Start typing below and then click **Run** to see the output.")
+    with gr.Row():
+        inp = gr.Textbox(placeholder="What is your name?")
+        out = gr.Textbox()
+    btn = gr.Button("Run")
+    btn.click(fn=update, inputs=inp, outputs=out)
+
+demo.launch()

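The replacement app.py ends with a bare demo.launch(). Since the Space runs with sdk: docker (see the README diff above), the launch call usually has to bind the address and port the container exposes; a one-line sketch of such a launch call, where the host, port, and queue size are assumptions rather than values from this commit:

demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)
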
app_3d.py
DELETED
@@ -1,21 +0,0 @@
-import gradio as gr
-import os
-
-def load_mesh(mesh_file_name):
-    return mesh_file_name
-
-demo = gr.Interface(
-    fn=load_mesh,
-    inputs=gr.Model3D(),
-    outputs=gr.Model3D(
-        clear_color=(255.0, 0.0, 0.0, 0.0), label="3D Model", display_mode="wireframe"),
-    examples=[
-        [os.path.join(os.path.dirname(__file__), "examples/bunny/mesh.obj")],
-        [os.path.join(os.path.dirname(__file__), "examples/monkey/mesh.obj")],
-        [os.path.join(os.path.dirname(__file__), "examples/Bunny.obj")],
-    ],
-    cache_examples=True
-)
-
-if __name__ == "__main__":
-    demo.launch()

app_canny.py
DELETED
@@ -1,83 +0,0 @@
-#!/usr/bin/env python
-
-import gradio as gr
-
-from settings import (
-    DEFAULT_IMAGE_RESOLUTION,
-    DEFAULT_NUM_IMAGES,
-    MAX_IMAGE_RESOLUTION,
-    MAX_NUM_IMAGES,
-    MAX_SEED,
-)
-from utils import randomize_seed_fn
-
-
-def create_demo(process):
-    with gr.Blocks() as demo:
-        with gr.Row():
-            with gr.Column():
-                image = gr.Image()
-                prompt = gr.Textbox(label="Prompt", submit_btn=True)
-                with gr.Accordion("Advanced options", open=False):
-                    num_samples = gr.Slider(
-                        label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
-                    )
-                    image_resolution = gr.Slider(
-                        label="Image resolution",
-                        minimum=256,
-                        maximum=MAX_IMAGE_RESOLUTION,
-                        value=DEFAULT_IMAGE_RESOLUTION,
-                        step=256,
-                    )
-                    canny_low_threshold = gr.Slider(
-                        label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
-                    )
-                    canny_high_threshold = gr.Slider(
-                        label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
-                    )
-                    num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
-                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
-                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                    a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
-                    n_prompt = gr.Textbox(
-                        label="Negative prompt",
-                        value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
-                    )
-            with gr.Column():
-                result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
-        inputs = [
-            image,
-            prompt,
-            a_prompt,
-            n_prompt,
-            num_samples,
-            image_resolution,
-            num_steps,
-            guidance_scale,
-            seed,
-            canny_low_threshold,
-            canny_high_threshold,
-        ]
-        prompt.submit(
-            fn=randomize_seed_fn,
-            inputs=[seed, randomize_seed],
-            outputs=seed,
-            queue=False,
-            api_name=False,
-        ).then(
-            fn=process,
-            inputs=inputs,
-            outputs=result,
-            api_name="canny",
-            concurrency_id="main",
-        )
-    return demo
-
-
-if __name__ == "__main__":
-    from model import Model
-
-    model = Model(task_name="Canny")
-    demo = create_demo(model.process_canny)
-    demo.queue().launch()

app_matnet.py
DELETED
@@ -1,83 +0,0 @@
(The 83 deleted lines of app_matnet.py are identical, line for line, to the deleted app_canny.py shown above, including the api_name="canny" event registration and the Model(task_name="Canny") / model.process_canny usage in the __main__ block.)

app_sd.py
DELETED
@@ -1,154 +0,0 @@
-import gradio as gr
-import numpy as np
-import random
-
-import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
-import torch
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-@spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024, # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024, # Replace with defaults that work for your model
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0, # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2, # Replace with defaults that work for your model
-                )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
-    )
-
-if __name__ == "__main__":
-    demo.launch()

app_texnet.py
DELETED
@@ -1,259 +0,0 @@
-#!/usr/bin/env python
-
-import os
-import shutil
-import tempfile
-import gradio as gr
-from PIL import Image
-import numpy as np
-
-from settings import (
-    DEFAULT_IMAGE_RESOLUTION,
-    DEFAULT_NUM_IMAGES,
-    MAX_IMAGE_RESOLUTION,
-    MAX_NUM_IMAGES,
-    MAX_SEED,
-)
-from utils import randomize_seed_fn
-
-# ---- helper to build a quick textured copy of the mesh ---------------
-def apply_texture(src_mesh:str, texture:str, tag:str)->str:
-    """
-    Writes a copy of `src_mesh` and tiny .mtl that points to `texture`.
-    Returns the new OBJ/GLB path for viewing.
-    """
-    tmp_dir = tempfile.mkdtemp()
-    mesh_copy = os.path.join(tmp_dir, f"{tag}.obj")
-    mtl_name = f"{tag}.mtl"
-
-    # copy geometry
-    shutil.copy(src_mesh, mesh_copy)
-
-    # write minimal MTL
-    with open(os.path.join(tmp_dir, mtl_name), "w") as f:
-        f.write(f"newmtl material_0\nmap_Kd {os.path.basename(texture)}\n")
-
-    # ensure texture lives next to OBJ
-    shutil.copy(texture, os.path.join(tmp_dir, os.path.basename(texture)))
-
-    # patch OBJ to reference our new MTL
-    with open(mesh_copy, "r+") as f:
-        lines = f.readlines()
-        if not lines[0].startswith("mtllib"):
-            lines.insert(0, f"mtllib {mtl_name}\n")
-            f.seek(0); f.writelines(lines)
-
-    return mesh_copy
-
-def image_to_temp_path(img_like, tag, out_dir=None):
-    """
-    Convert various image-like objects (str, PIL.Image, list, tuple) to temp PNG path.
-    Returns the path to the saved image file.
-    """
-    # Handle tuple or list input
-    if isinstance(img_like, (list, tuple)):
-        if len(img_like) == 0:
-            raise ValueError("Empty image list/tuple.")
-        img_like = img_like[0]
-
-    # If it's already a file path
-    if isinstance(img_like, str):
-        return img_like
-
-    # If it's a PIL Image
-    if isinstance(img_like, Image.Image):
-        temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
-        os.makedirs(os.path.dirname(temp_path), exist_ok=True)
-        img_like.save(temp_path)
-        return temp_path
-
-    # if it's numpy array
-    if isinstance(img_like, np.ndarray):
-        temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
-        os.makedirs(os.path.dirname(temp_path), exist_ok=True)
-        img_like = Image.fromarray(img_like)
-        img_like.save(temp_path)
-        return temp_path
-
-    raise ValueError(f"Expected PIL.Image, str, list, or tuple — got {type(img_like)}")
-
-def show_mesh(which, mesh, inp, coarse, fine):
-    """Switch the displayed texture based on dropdown change."""
-    print()
-    tex_map = {
-        "Input": image_to_temp_path(inp, "input"),
-        "Coarse": coarse[0] if isinstance(coarse, tuple) else coarse,
-        "Fine": fine[0] if isinstance(fine, tuple) else fine,
-    }
-    texture_path = tex_map[which]
-    return apply_texture(mesh, texture_path, which.lower())
-# ----------------------------------------------------------------------
-
-
-def create_demo(process):
-    with gr.Blocks() as demo:
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("## Select preset from the example list, and modify the prompt accordingly")
-                with gr.Row():
-                    name = gr.Textbox(label="Name", interactive=False, visible=False)
-                    representative = gr.Image(label="Geometry", interactive=False)
-                    image = gr.Image(label="UV Normal", interactive=False)
-                prompt = gr.Textbox(label="Prompt", submit_btn=True)
-                with gr.Accordion("Advanced options", open=False):
-                    num_samples = gr.Slider(
-                        label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
-                    )
-                    image_resolution = gr.Slider(
-                        label="Image resolution",
-                        minimum=256,
-                        maximum=MAX_IMAGE_RESOLUTION,
-                        value=DEFAULT_IMAGE_RESOLUTION,
-                        step=256,
-                    )
-                    num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=10, step=1)
-                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
-                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                    a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
-                    n_prompt = gr.Textbox(
-                        label="Negative prompt",
-                        value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
-                    )
-            with gr.Column():
-                # 2x2 grid of images for the output textures
-                gr.Markdown("### Output BRDF")
-                with gr.Row():
-                    base_color = gr.Gallery(label="Base Color", show_label=True, columns=1, object_fit="scale-down")
-                    normal = gr.Gallery(label="Displacement Map", show_label=True, columns=1, object_fit="scale-down")
-                with gr.Row():
-                    roughness = gr.Gallery(label="Roughness Map", show_label=True, columns=1, object_fit="scale-down")
-                    metallic = gr.Gallery(label="Metallic Map", show_label=True, columns=1, object_fit="scale-down")
-
-                gr.Markdown("### Download Packed Blender Files for 3D Visualization")
-                out_blender_path = gr.File(label="Generated Blender File", file_types=[".blend"])
-
-        inputs = [
-            name, # Name of the object
-            representative, # Geometry mesh
-            image,
-            prompt,
-            a_prompt,
-            n_prompt,
-            num_samples,
-            image_resolution,
-            num_steps,
-            guidance_scale,
-            seed,
-        ]
-
-        # first call → run diffusion / texture network
-        prompt.submit(
-            fn=randomize_seed_fn,
-            inputs=[seed, randomize_seed],
-            outputs=seed,
-            queue=False,
-            api_name=False,
-        ).then(
-            fn=process,
-            inputs=inputs,
-            outputs=[base_color, normal, roughness, metallic, out_blender_path],
-            api_name="canny",
-            concurrency_id="main",
-        )
-
-        gr.Examples(
-            fn=process,
-            inputs=inputs,
-            outputs=[base_color, normal, roughness, metallic],
-            examples=[
-                [
-                    "bunny",
-                    "examples/bunny/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
-                    "examples/bunny/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
-                    "feather",
-                    a_prompt.value,
-                    n_prompt.value,
-                    num_samples.value,
-                    image_resolution.value,
-                    num_steps.value,
-                    guidance_scale.value,
-                    seed.value,
-                ],
-                [
-                    "monkey",
-                    "examples/monkey/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "examples/monkey/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "wood",
-                    a_prompt.value,
-                    n_prompt.value,
-                    num_samples.value,
-                    image_resolution.value,
-                    num_steps.value,
-                    guidance_scale.value,
-                    seed.value,
-                ],
-                [
-                    "tshirt",
-                    "examples/tshirt/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "examples/tshirt/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "wood",
-                    a_prompt.value,
-                    n_prompt.value,
-                    num_samples.value,
-                    image_resolution.value,
-                    num_steps.value,
-                    guidance_scale.value,
-                    seed.value,
-                ],
-                # [
-                #     "highheel",
-                #     "examples/highheel/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                #     "examples/highheel/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                #     "wood",
-                #     a_prompt.value,
-                #     n_prompt.value,
-                #     num_samples.value,
-                #     image_resolution.value,
-                #     num_steps.value,
-                #     guidance_scale.value,
-                #     seed.value,
-                # ],
-                [
-                    "tank",
-                    "examples/tank/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "examples/tank/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "wood",
-                    a_prompt.value,
-                    n_prompt.value,
-                    num_samples.value,
-                    image_resolution.value,
-                    num_steps.value,
-                    guidance_scale.value,
-                    seed.value,
-                ],
-                [
-                    "fighter",
-                    "examples/fighter/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "examples/fighter/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
-                    "wood",
-                    a_prompt.value,
-                    n_prompt.value,
-                    num_samples.value,
-                    image_resolution.value,
-                    num_steps.value,
-                    guidance_scale.value,
-                    seed.value,
-                ],
-            ],
-        )
-
-    return demo
-
-
-if __name__ == "__main__":
-    from model import Model
-
-    model = Model(task_name="Texnet")
-    demo = create_demo(model.process_texnet)
-    demo.queue().launch()

cv_utils.py
DELETED
@@ -1,17 +0,0 @@
-import cv2
-import numpy as np
-
-
-def resize_image(input_image, resolution, interpolation=None):
-    H, W, C = input_image.shape
-    H = float(H)
-    W = float(W)
-    k = float(resolution) / max(H, W)
-    H *= k
-    W *= k
-    H = int(np.round(H / 64.0)) * 64
-    W = int(np.round(W / 64.0)) * 64
-    if interpolation is None:
-        interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
-    img = cv2.resize(input_image, (W, H), interpolation=interpolation)
-    return img

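The deleted resize_image helper scales the longer side toward the requested resolution and then rounds both dimensions to multiples of 64 (a stride commonly expected by the SD-family models used elsewhere in this repo). A quick usage sketch, assuming the helper above is still importable as cv_utils.resize_image:

import numpy as np
from cv_utils import resize_image  # the deleted helper shown above

img = np.zeros((480, 720, 3), dtype=np.uint8)   # H=480, W=720
out = resize_image(img, resolution=512)
print(out.shape)  # (320, 512, 3): longer side scaled to 512, both sides rounded to multiples of 64
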
depth_estimator.py
DELETED
@@ -1,25 +0,0 @@
-import numpy as np
-import PIL.Image
-from controlnet_aux.util import HWC3
-from transformers import pipeline
-
-from cv_utils import resize_image
-
-
-class DepthEstimator:
-    def __init__(self):
-        self.model = pipeline("depth-estimation")
-
-    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
-        detect_resolution = kwargs.pop("detect_resolution", 512)
-        image_resolution = kwargs.pop("image_resolution", 512)
-        image = np.array(image)
-        image = HWC3(image)
-        image = resize_image(image, resolution=detect_resolution)
-        image = PIL.Image.fromarray(image)
-        image = self.model(image)
-        image = image["depth"]
-        image = np.array(image)
-        image = HWC3(image)
-        image = resize_image(image, resolution=image_resolution)
-        return PIL.Image.fromarray(image)

examples/bunny/frame_0001.png
DELETED (Git LFS object)

examples/bunny/mesh.obj
DELETED (diff too large to render)

examples/bunny/uv_normal.png
DELETED (Git LFS object)

examples/fighter/frame_0001.png
DELETED (Git LFS object)

examples/fighter/mesh.obj
DELETED (diff too large to render)

examples/fighter/uv_normal.png
DELETED (Git LFS object)

examples/highheel/frame_0001.png
DELETED (Git LFS object)

examples/highheel/mesh.obj
DELETED (diff too large to render)

examples/highheel/uv_normal.png
DELETED (Git LFS object)

examples/monkey/frame_0001.png
DELETED (Git LFS object)

examples/monkey/mesh.obj
DELETED (diff too large to render)

examples/monkey/uv_normal.png
DELETED (Git LFS object)

examples/tank/frame_0001.png
DELETED (Git LFS object)

examples/tank/mesh.obj
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:301633de1a7757f78a6f67abb6e61bcc8e6a01f5a54a8582d1943ad0ad943211
-size 6942253

examples/tank/uv_normal.png
DELETED (Git LFS object)

examples/tshirt/frame_0001.png
DELETED (Git LFS object)

examples/tshirt/mesh.obj
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b7c6c9bdec8d646a1980e5b987a1182c92af84cc945ef49c1735d4337185d3e5
-size 39275876

examples/tshirt/uv_normal.png
DELETED (Git LFS object)

image_segmentor.py
DELETED
@@ -1,33 +0,0 @@
-import cv2
-import numpy as np
-import PIL.Image
-import torch
-from controlnet_aux.util import HWC3, ade_palette
-from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
-
-from cv_utils import resize_image
-
-
-class ImageSegmentor:
-    def __init__(self):
-        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
-        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
-
-    @torch.inference_mode()
-    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
-        detect_resolution = kwargs.pop("detect_resolution", 512)
-        image_resolution = kwargs.pop("image_resolution", 512)
-        image = HWC3(image)
-        image = resize_image(image, resolution=detect_resolution)
-        image = PIL.Image.fromarray(image)
-
-        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
-        outputs = self.image_segmentor(pixel_values)
-        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
-        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
-        for label, color in enumerate(ade_palette()):
-            color_seg[seg == label, :] = color
-        color_seg = color_seg.astype(np.uint8)
-
-        color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
-        return PIL.Image.fromarray(color_seg)

install.sh
DELETED
@@ -1,18 +0,0 @@
-#!/bin/bash
-eval "$(conda shell.bash hook)"
-# conda activate base
-# conda remove -n matgen-plus --all
-
-conda create -n matgen-plus python=3.11
-conda activate matgen-plus
-
-pip install diffusers["torch"] transformers accelerate xformers
-pip install gradio
-pip install controlnet-aux
-
-# text2tex
-conda install pytorch3d -c pytorch -c conda-forge
-conda install -c conda-forge open-clip-torch pytorch-lightning
-pip install trimesh xatlas scikit-learn opencv-python omegaconf
-
-python app.py

model.py
DELETED
@@ -1,959 +0,0 @@
-import gc
-
-# get socket and check if the name is vgldgx01
-import socket
-if socket.gethostname() != "vgldgx01":
-    import spaces #[uncomment to use ZeroGPU]
-
-import numpy as np
-import PIL.Image
-import torch
-from controlnet_aux.util import HWC3
-from diffusers import (
-    ControlNetModel,
-    DiffusionPipeline,
-    StableDiffusionControlNetPipeline,
-    StableDiffusionImg2ImgPipeline,
-    UniPCMultistepScheduler,
-    DDIMScheduler, #rgb2x
-)
-import torchvision
-from torchvision import transforms
-from cv_utils import resize_image
-from preprocessor import Preprocessor
-from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
-from tqdm.auto import tqdm
-import subprocess
-
-from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
-from app_texnet import image_to_temp_path
-import os
-import time
-import tempfile
-from text2tex.scripts.generate_texture import text2tex_call, init_args
-from glob import glob
-
-CONTROLNET_MODEL_IDS = {
-    # "Openpose": "lllyasviel/control_v11p_sd15_openpose",
-    # "Canny": "lllyasviel/control_v11p_sd15_canny",
-    # "MLSD": "lllyasviel/control_v11p_sd15_mlsd",
-    # "scribble": "lllyasviel/control_v11p_sd15_scribble",
-    # "softedge": "lllyasviel/control_v11p_sd15_softedge",
-    # "segmentation": "lllyasviel/control_v11p_sd15_seg",
-    # "depth": "lllyasviel/control_v11f1p_sd15_depth",
-    # "NormalBae": "lllyasviel/control_v11p_sd15_normalbae",
-    # "lineart": "lllyasviel/control_v11p_sd15_lineart",
-    # "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
-    # "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
-    # "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
-    # "inpaint": "lllyasviel/control_v11e_sd15_inpaint",
-    # "texnet": "/home/jyang/projects/ObjectReal/logs/train_texnet_deploy/checkpoint-55000/controlnet" # load and call
-    "texnet": "jingyangcarl/texnet",
-}
-
-
-def download_all_controlnet_weights() -> None:
-    for model_id in CONTROLNET_MODEL_IDS.values():
-        ControlNetModel.from_pretrained(model_id)
-
-
-class Model:
-    def __init__(
-        self, base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", task_name: str = "Canny"
-    ) -> None:
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.base_model_id = ""
-        self.task_name = ""
-        self.pipe = self.load_pipe(base_model_id, task_name)
-        self.pipe_base = StableDiffusionImg2ImgPipeline.from_pretrained(
-            'runwayml/stable-diffusion-v1-5', safety_checker=None, torch_dtype=torch.float16
-        ).to(self.device)
-        self.preprocessor = Preprocessor()
-
-        # set up pipe_rgb2x
-        self.pipe_rgb2x = StableDiffusionAOVMatEstPipeline.from_pretrained(
-            "zheng95z/rgb-to-x",
-            torch_dtype=torch.float16,
-        ).to(self.device)
-        self.pipe_rgb2x.scheduler = DDIMScheduler.from_config(
-            self.pipe_rgb2x.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
-        )
-        self.pipe_rgb2x.set_progress_bar_config(disable=True)
-
-        # setup blender
-        self.blender_path = '/tmp/blender-3.2.2-linux-x64/blender'
-        if not os.path.exists(self.blender_path):
-            print("Downloading Blender...")
-            subprocess.run(["wget", "https://download.blender.org/release/Blender3.2/blender-3.2.2-linux-x64.tar.xz", "-O", "/tmp/blender-3.2.2-linux-x64.tar.xz"], check=True)
-            subprocess.run(["tar", "-xf", "/tmp/blender-3.2.2-linux-x64.tar.xz", "-C", "/tmp"], check=True)
-            print("Blender downloaded and extracted.")
-
-    def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
-        if (
-            base_model_id == self.base_model_id
-            and task_name == self.task_name
-            and hasattr(self, "pipe")
-            and self.pipe is not None
-        ):
-            return self.pipe
-        model_id = CONTROLNET_MODEL_IDS[task_name]
-        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
-        to_upload = False
-        if to_upload:
-            # confirm before uploading
-            confirm = input(f"Do you want to upload {model_id} to the hub? (y/n): ")
-            if confirm.lower() == "y":
-                controlnet.push_to_hub("jingyangcarl/texnet")
-            else:
-                print("Upload cancelled.")
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
-        )
-        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-        pipe.to(self.device)
-        if self.device.type == "cuda":
-            import os
-            if os.environ.get("SPACES_ZERO_GPU", "0") == "1":
-                # when running on ZeroGPU, enable CPU offload
-                # pipe.enable_xformers_memory_efficient_attention() doens't work
-                # pipe.enable_model_cpu_offload()
-                pass
-            else:
-                pipe.enable_xformers_memory_efficient_attention()
-        torch.cuda.empty_cache()
-        gc.collect()
-        self.base_model_id = base_model_id
-        self.task_name = task_name
-        return pipe
-
-    def set_base_model(self, base_model_id: str) -> str:
-        if not base_model_id or base_model_id == self.base_model_id:
-            return self.base_model_id
-        del self.pipe
-        torch.cuda.empty_cache()
-        gc.collect()
-        try:
-            self.pipe = self.load_pipe(base_model_id, self.task_name)
-        except Exception:  # noqa: BLE001
-            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
-        return self.base_model_id
-
-    def load_controlnet_weight(self, task_name: str) -> None:
-        if task_name == self.task_name:
-            return
-        if self.pipe is not None and hasattr(self.pipe, "controlnet"):
-            del self.pipe.controlnet
-        torch.cuda.empty_cache()
-        gc.collect()
-        model_id = CONTROLNET_MODEL_IDS[task_name]
-        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
-        controlnet.to(self.device)
-        torch.cuda.empty_cache()
-        gc.collect()
-        self.pipe.controlnet = controlnet
-        self.task_name = task_name
-
-    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
-        return additional_prompt if not prompt else f"{prompt}, {additional_prompt}"
-
-    # @spaces.GPU #[uncomment to use ZeroGPU]
-    @torch.autocast("cuda")
-    def run_pipe(
-        self,
-        prompt: str,
-        negative_prompt: str,
-        control_image: PIL.Image.Image,
-        num_images: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-    ) -> list[PIL.Image.Image]:
-        generator = torch.Generator().manual_seed(seed)
-        # self.pipe.to(self.device)
-        return self.pipe(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            guidance_scale=guidance_scale,
-            num_images_per_prompt=num_images,
-            num_inference_steps=num_steps,
-            generator=generator,
-            image=control_image,
-        ).images
-
-    # @spaces.GPU #[uncomment to use ZeroGPU]
-    @torch.inference_mode()
-    def process_texnet(
-        self,
-        obj_name: str,
-        represented_image: np.ndarray | None, # not used
-        image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        low_threshold: int,
-        high_threshold: int,
-    ) -> list[PIL.Image.Image]:
-        if image is None:
-            raise ValueError
-        if image_resolution > MAX_IMAGE_RESOLUTION:
-            raise ValueError
-        if num_images > MAX_NUM_IMAGES:
-            raise ValueError
-
-        prompt_nospace = prompt.replace(' ', '_')
-
-        # self.preprocessor.load("texnet")
-        # control_image = self.preprocessor(
-        #     image=image, low_threshold=low_threshold, high_threshold=high_threshold, image_resolution=image_resolution, output_type="pil"
-        # )
-
-        # self.load_controlnet_weight("texnet")
-        # tex_coarse = self.run_pipe(
-        #     prompt=self.get_prompt(prompt, additional_prompt),
-        #     negative_prompt=negative_prompt,
-        #     control_image=control_image,
-        #     num_images=num_images,
-        #     num_steps=num_steps,
-        #     guidance_scale=guidance_scale,
-        #     seed=seed,
-        # )
-
-        # # use img2img pipeline
-        # self.pipe_backup = self.pipe
-        # self.pipe = self.pipe_base
-
-        # # refine
-        tex_fine = []
-        mesh_fine = []
-        # for result_coarse in tex_coarse:
-        #     # clean up GPU cache
-        #     torch.cuda.empty_cache()
-        #     gc.collect()
-
-        #     # masking
-        #     mask = (np.array(control_image).sum(axis=-1) == 0)[...,None]
-        #     image_masked = PIL.Image.fromarray(np.where(mask, control_image, result_coarse))
-        #     image_blurry = transforms.GaussianBlur(kernel_size=5, sigma=1)(image_masked)
-        #     result_fine = self.run_pipe(
-        #         # prompt=prompt,
-        #         prompt=self.get_prompt(prompt, additional_prompt),
-        #         negative_prompt=negative_prompt,
-        #         control_image=image_blurry,
-        #         num_images=1,
-        #         num_steps=num_steps,
-        #         guidance_scale=guidance_scale,
-        #         seed=seed,
-        #     )[0]
-        #     result_fine = PIL.Image.fromarray(np.where(mask, control_image, result_fine))
-        #     tex_fine.append(result_fine)
-
-        temp_out_path = tempfile.mkdtemp()
-        temp_out_path = 'output'
-
-        # put text2tex here,
-        args = init_args()
-        args.input_dir = f'examples/{obj_name}/'
-        args.output_dir = os.path.join(temp_out_path, f'{obj_name}/{prompt_nospace}')
-        args.obj_name = obj_name
-        args.obj_file = 'mesh.obj'
-        args.prompt = f'{prompt} {obj_name}'
-        args.add_view_to_prompt = True
-        args.ddim_steps = 5
-        # args.ddim_steps = 50
-        args.new_strength = 1.0
-        args.update_strength = 0.3
-        args.view_threshold = 0.1
-        args.blend = 0
-        args.dist = 1
-        args.num_viewpoints = 2
-        # args.num_viewpoints = 36
-        args.viewpoint_mode = 'predefined'
-        args.use_principle = True
-        args.update_steps = 2
-        # args.update_steps = 20
-        args.update_mode = 'heuristic'
-        args.seed = 42
-        args.post_process = True
-        args.device = '2080'
-        args.uv_size = 1000
-        args.image_size = 512
-        # args.image_size = 768
-        args.use_objaverse = True # assume the mesh is normalized with y-axis as up
-        output_dir = text2tex_call(args)
-
-        # get the texture and mesh with underscore '_post', which is the id of the last mesh, should be good for the visual
-        post_idx = glob(os.path.join(output_dir, 'update', 'mesh', "*_post.png"))[0].split('/')[-1].split('_')[0]
-
-        tex_fine.append(PIL.Image.open(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.png")).convert("RGB"))
-        mesh_fine.append(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.obj"))
-        torch.cuda.empty_cache()
-
-        # restore the original pipe
-        # self.pipe = self.pipe_backup
-
-        # use rgb2x for now for generating the texture
-        def rgb2x(
-            pipeline,
-            photo,
-            inference_step = 50,
-            num_samples = 1,
-        ):
-            generator = torch.Generator(device="cuda").manual_seed(seed)
-
-            # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
-            old_height = photo.shape[1]
-            old_width = photo.shape[2]
-            new_height = old_height
-            new_width = old_width
-            radio = old_height / old_width
-            max_side = 1000
-            if old_height > old_width:
-                new_height = max_side
-                new_width = int(new_height / radio)
-            else:
-                new_width = max_side
-                new_height = int(new_width * radio)
-
-            if new_width % 8 != 0 or new_height % 8 != 0:
-                new_width = new_width // 8 * 8
-                new_height = new_height // 8 * 8
-
-            photo = torchvision.transforms.Resize((new_height, new_width))(photo)
-
-            required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
-            prompts = {
-                "albedo": "Albedo (diffuse basecolor)",
-                "normal": "Camera-space Normal",
-                "roughness": "Roughness",
-                "metallic": "Metallicness",
-                "irradiance": "Irradiance (diffuse lighting)",
-            }
-
-            return_list = []
-            for i in tqdm(range(num_samples), desc="Running Pipeline", leave=False):
-                for aov_name in required_aovs:
-                    prompt = prompts[aov_name]
-                    generated_image = pipeline(
-                        prompt=prompt,
-                        photo=photo,
-                        num_inference_steps=inference_step,
-                        height=new_height,
-                        width=new_width,
-                        generator=generator,
-                        required_aovs=[aov_name],
-                    ).images[0][0]
-
-                    generated_image = torchvision.transforms.Resize(
-                        (old_height, old_width)
-                    )(generated_image)
-
-                    # generated_image = (generated_image, f"Generated {aov_name} {i}")
-                    # generated_image = (generated_image, f"{aov_name}")
-                    return_list.append(generated_image)
-
-            return photo, return_list, prompts
-
-        # Load rgb2x pipeline
-        _, preds, prompts = rgb2x(self.pipe_rgb2x, torchvision.transforms.PILToTensor()(tex_fine[0]).to(self.pipe.device), inference_step=num_steps, num_samples=num_images)
-
-        intrinsic_dir = os.path.join(output_dir, 'intrinsic')
-        use_text2tex = True
-        if use_text2tex:
-            base_color_path = image_to_temp_path(tex_fine[0], "base_color", out_dir=intrinsic_dir)
-            normal_map_path = image_to_temp_path(preds[0], "normal_map", out_dir=intrinsic_dir)
-            roughness_path = image_to_temp_path(preds[1], "roughness", out_dir=intrinsic_dir)
-            metallic_path = image_to_temp_path(preds[2], "metallic", out_dir=intrinsic_dir)
-        else:
-            base_color_path = image_to_temp_path(tex_fine[0].rotate(90), "base_color", out_dir=intrinsic_dir)
-            normal_map_path = image_to_temp_path(preds[0].rotate(90), "normal_map", out_dir=intrinsic_dir)
-            roughness_path = image_to_temp_path(preds[1].rotate(90), "roughness", out_dir=intrinsic_dir)
-            metallic_path = image_to_temp_path(preds[2].rotate(90), "metallic", out_dir=intrinsic_dir)
-        current_timecode = time.strftime("%Y%m%d_%H%M%S")
-        # output_blend_path = os.path.join(os.getcwd(), "output", f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
-        output_blend_path = os.path.join(tempfile.mkdtemp(), f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
-        os.makedirs(os.path.dirname(output_blend_path), exist_ok=True)
-
-        def run_blend_generation(
-            blender_path,
-            generate_script_path,
-            obj_path,
-            base_color_path,
-            normal_map_path,
-            roughness_path,
-            metallic_path,
-            output_blend
-        ):
-            cmd = [
-                blender_path, "--background", "--python", generate_script_path, "--",
-                obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend
-            ]
-            subprocess.run(cmd, check=True)
-
-        # check if the blender_path exists, if not download
-        run_blend_generation(
-            blender_path=self.blender_path,
-            generate_script_path="rgb2x/generate_blend.py",
-            # obj_path=f"examples/{obj_name}/mesh.obj", # replace with actual mesh path
-            obj_path=mesh_fine[0], # replace with actual mesh path
-            base_color_path=base_color_path,
-            normal_map_path=normal_map_path,
-            roughness_path=roughness_path,
-            metallic_path=metallic_path,
-            output_blend=output_blend_path # replace with desired output path
-        )
-
-        # gallary
-        return [*tex_fine], [preds[1]], [preds[2]], [preds[3]], [output_blend_path]
-
-    # @spaces.GPU #[uncomment to use ZeroGPU]
-    @torch.inference_mode()
-    def process_canny(
-        self,
-        image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        low_threshold: int,
-        high_threshold: int,
-    ) -> list[PIL.Image.Image]:
-        if image is None:
-            raise ValueError
-        if image_resolution > MAX_IMAGE_RESOLUTION:
-            raise ValueError
-        if num_images > MAX_NUM_IMAGES:
-            raise ValueError
-
-        self.preprocessor.load("Canny")
-        control_image = self.preprocessor(
-            image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
-        )
-
-        self.load_controlnet_weight("Canny")
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [control_image, *results]
-
-    @torch.inference_mode()
-    def process_mlsd(
-        self,
-        image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        preprocess_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        value_threshold: float,
-        distance_threshold: float,
-    ) -> list[PIL.Image.Image]:
-        if image is None:
-            raise ValueError
-        if image_resolution > MAX_IMAGE_RESOLUTION:
-            raise ValueError
-        if num_images > MAX_NUM_IMAGES:
-            raise ValueError
-
-        self.preprocessor.load("MLSD")
-        control_image = self.preprocessor(
-            image=image,
-            image_resolution=image_resolution,
-            detect_resolution=preprocess_resolution,
-            thr_v=value_threshold,
-            thr_d=distance_threshold,
-        )
-        self.load_controlnet_weight("MLSD")
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [control_image, *results]
-
-    @torch.inference_mode()
-    def process_scribble(
-        self,
-        image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        preprocess_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        preprocessor_name: str,
-    ) -> list[PIL.Image.Image]:
-        if image is None:
-            raise ValueError
-        if image_resolution > MAX_IMAGE_RESOLUTION:
-            raise ValueError
-        if num_images > MAX_NUM_IMAGES:
-            raise ValueError
-
-        if preprocessor_name == "None":
-            image = HWC3(image)
-            image = resize_image(image, resolution=image_resolution)
-            control_image = PIL.Image.fromarray(image)
-        elif preprocessor_name == "HED":
-            self.preprocessor.load(preprocessor_name)
-            control_image = self.preprocessor(
-                image=image,
-                image_resolution=image_resolution,
-                detect_resolution=preprocess_resolution,
-                scribble=False,
-            )
-        elif preprocessor_name == "PidiNet":
-            self.preprocessor.load(preprocessor_name)
-            control_image = self.preprocessor(
-                image=image,
-                image_resolution=image_resolution,
-                detect_resolution=preprocess_resolution,
-                safe=False,
-            )
-        self.load_controlnet_weight("scribble")
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [control_image, *results]
-
-    @torch.inference_mode()
-    def process_scribble_interactive(
-        self,
-        image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-    ) -> list[PIL.Image.Image]:
-        if image_and_mask is None:
-            raise ValueError
-        if image_resolution > MAX_IMAGE_RESOLUTION:
-            raise ValueError
-        if num_images > MAX_NUM_IMAGES:
-            raise ValueError
-
-        image = 255 - image_and_mask["composite"] # type: ignore
-        image = HWC3(image)
-        image = resize_image(image, resolution=image_resolution)
-        control_image = PIL.Image.fromarray(image)
-
-        self.load_controlnet_weight("scribble")
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [control_image, *results]
-
-    @torch.inference_mode()
-    def process_softedge(
-        self,
-        image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        preprocess_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        preprocessor_name: str,
-    ) -> list[PIL.Image.Image]:
-        if image is None:
-            raise ValueError
|
604 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
605 |
-
raise ValueError
|
606 |
-
if num_images > MAX_NUM_IMAGES:
|
607 |
-
raise ValueError
|
608 |
-
|
609 |
-
if preprocessor_name == "None":
|
610 |
-
image = HWC3(image)
|
611 |
-
image = resize_image(image, resolution=image_resolution)
|
612 |
-
control_image = PIL.Image.fromarray(image)
|
613 |
-
elif preprocessor_name in ["HED", "HED safe"]:
|
614 |
-
safe = "safe" in preprocessor_name
|
615 |
-
self.preprocessor.load("HED")
|
616 |
-
control_image = self.preprocessor(
|
617 |
-
image=image,
|
618 |
-
image_resolution=image_resolution,
|
619 |
-
detect_resolution=preprocess_resolution,
|
620 |
-
scribble=safe,
|
621 |
-
)
|
622 |
-
elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
|
623 |
-
safe = "safe" in preprocessor_name
|
624 |
-
self.preprocessor.load("PidiNet")
|
625 |
-
control_image = self.preprocessor(
|
626 |
-
image=image,
|
627 |
-
image_resolution=image_resolution,
|
628 |
-
detect_resolution=preprocess_resolution,
|
629 |
-
safe=safe,
|
630 |
-
)
|
631 |
-
else:
|
632 |
-
raise ValueError
|
633 |
-
self.load_controlnet_weight("softedge")
|
634 |
-
results = self.run_pipe(
|
635 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
636 |
-
negative_prompt=negative_prompt,
|
637 |
-
control_image=control_image,
|
638 |
-
num_images=num_images,
|
639 |
-
num_steps=num_steps,
|
640 |
-
guidance_scale=guidance_scale,
|
641 |
-
seed=seed,
|
642 |
-
)
|
643 |
-
return [control_image, *results]
|
644 |
-
|
645 |
-
@torch.inference_mode()
|
646 |
-
def process_openpose(
|
647 |
-
self,
|
648 |
-
image: np.ndarray,
|
649 |
-
prompt: str,
|
650 |
-
additional_prompt: str,
|
651 |
-
negative_prompt: str,
|
652 |
-
num_images: int,
|
653 |
-
image_resolution: int,
|
654 |
-
preprocess_resolution: int,
|
655 |
-
num_steps: int,
|
656 |
-
guidance_scale: float,
|
657 |
-
seed: int,
|
658 |
-
preprocessor_name: str,
|
659 |
-
) -> list[PIL.Image.Image]:
|
660 |
-
if image is None:
|
661 |
-
raise ValueError
|
662 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
663 |
-
raise ValueError
|
664 |
-
if num_images > MAX_NUM_IMAGES:
|
665 |
-
raise ValueError
|
666 |
-
|
667 |
-
if preprocessor_name == "None":
|
668 |
-
image = HWC3(image)
|
669 |
-
image = resize_image(image, resolution=image_resolution)
|
670 |
-
control_image = PIL.Image.fromarray(image)
|
671 |
-
else:
|
672 |
-
self.preprocessor.load("Openpose")
|
673 |
-
control_image = self.preprocessor(
|
674 |
-
image=image,
|
675 |
-
image_resolution=image_resolution,
|
676 |
-
detect_resolution=preprocess_resolution,
|
677 |
-
hand_and_face=True,
|
678 |
-
)
|
679 |
-
self.load_controlnet_weight("Openpose")
|
680 |
-
results = self.run_pipe(
|
681 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
682 |
-
negative_prompt=negative_prompt,
|
683 |
-
control_image=control_image,
|
684 |
-
num_images=num_images,
|
685 |
-
num_steps=num_steps,
|
686 |
-
guidance_scale=guidance_scale,
|
687 |
-
seed=seed,
|
688 |
-
)
|
689 |
-
return [control_image, *results]
|
690 |
-
|
691 |
-
@torch.inference_mode()
|
692 |
-
def process_segmentation(
|
693 |
-
self,
|
694 |
-
image: np.ndarray,
|
695 |
-
prompt: str,
|
696 |
-
additional_prompt: str,
|
697 |
-
negative_prompt: str,
|
698 |
-
num_images: int,
|
699 |
-
image_resolution: int,
|
700 |
-
preprocess_resolution: int,
|
701 |
-
num_steps: int,
|
702 |
-
guidance_scale: float,
|
703 |
-
seed: int,
|
704 |
-
preprocessor_name: str,
|
705 |
-
) -> list[PIL.Image.Image]:
|
706 |
-
if image is None:
|
707 |
-
raise ValueError
|
708 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
709 |
-
raise ValueError
|
710 |
-
if num_images > MAX_NUM_IMAGES:
|
711 |
-
raise ValueError
|
712 |
-
|
713 |
-
if preprocessor_name == "None":
|
714 |
-
image = HWC3(image)
|
715 |
-
image = resize_image(image, resolution=image_resolution)
|
716 |
-
control_image = PIL.Image.fromarray(image)
|
717 |
-
else:
|
718 |
-
self.preprocessor.load(preprocessor_name)
|
719 |
-
control_image = self.preprocessor(
|
720 |
-
image=image,
|
721 |
-
image_resolution=image_resolution,
|
722 |
-
detect_resolution=preprocess_resolution,
|
723 |
-
)
|
724 |
-
self.load_controlnet_weight("segmentation")
|
725 |
-
results = self.run_pipe(
|
726 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
727 |
-
negative_prompt=negative_prompt,
|
728 |
-
control_image=control_image,
|
729 |
-
num_images=num_images,
|
730 |
-
num_steps=num_steps,
|
731 |
-
guidance_scale=guidance_scale,
|
732 |
-
seed=seed,
|
733 |
-
)
|
734 |
-
return [control_image, *results]
|
735 |
-
|
736 |
-
@torch.inference_mode()
|
737 |
-
def process_depth(
|
738 |
-
self,
|
739 |
-
image: np.ndarray,
|
740 |
-
prompt: str,
|
741 |
-
additional_prompt: str,
|
742 |
-
negative_prompt: str,
|
743 |
-
num_images: int,
|
744 |
-
image_resolution: int,
|
745 |
-
preprocess_resolution: int,
|
746 |
-
num_steps: int,
|
747 |
-
guidance_scale: float,
|
748 |
-
seed: int,
|
749 |
-
preprocessor_name: str,
|
750 |
-
) -> list[PIL.Image.Image]:
|
751 |
-
if image is None:
|
752 |
-
raise ValueError
|
753 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
754 |
-
raise ValueError
|
755 |
-
if num_images > MAX_NUM_IMAGES:
|
756 |
-
raise ValueError
|
757 |
-
|
758 |
-
if preprocessor_name == "None":
|
759 |
-
image = HWC3(image)
|
760 |
-
image = resize_image(image, resolution=image_resolution)
|
761 |
-
control_image = PIL.Image.fromarray(image)
|
762 |
-
else:
|
763 |
-
self.preprocessor.load(preprocessor_name)
|
764 |
-
control_image = self.preprocessor(
|
765 |
-
image=image,
|
766 |
-
image_resolution=image_resolution,
|
767 |
-
detect_resolution=preprocess_resolution,
|
768 |
-
)
|
769 |
-
self.load_controlnet_weight("depth")
|
770 |
-
results = self.run_pipe(
|
771 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
772 |
-
negative_prompt=negative_prompt,
|
773 |
-
control_image=control_image,
|
774 |
-
num_images=num_images,
|
775 |
-
num_steps=num_steps,
|
776 |
-
guidance_scale=guidance_scale,
|
777 |
-
seed=seed,
|
778 |
-
)
|
779 |
-
return [control_image, *results]
|
780 |
-
|
781 |
-
@torch.inference_mode()
|
782 |
-
def process_normal(
|
783 |
-
self,
|
784 |
-
image: np.ndarray,
|
785 |
-
prompt: str,
|
786 |
-
additional_prompt: str,
|
787 |
-
negative_prompt: str,
|
788 |
-
num_images: int,
|
789 |
-
image_resolution: int,
|
790 |
-
preprocess_resolution: int,
|
791 |
-
num_steps: int,
|
792 |
-
guidance_scale: float,
|
793 |
-
seed: int,
|
794 |
-
preprocessor_name: str,
|
795 |
-
) -> list[PIL.Image.Image]:
|
796 |
-
if image is None:
|
797 |
-
raise ValueError
|
798 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
799 |
-
raise ValueError
|
800 |
-
if num_images > MAX_NUM_IMAGES:
|
801 |
-
raise ValueError
|
802 |
-
|
803 |
-
if preprocessor_name == "None":
|
804 |
-
image = HWC3(image)
|
805 |
-
image = resize_image(image, resolution=image_resolution)
|
806 |
-
control_image = PIL.Image.fromarray(image)
|
807 |
-
else:
|
808 |
-
self.preprocessor.load("NormalBae")
|
809 |
-
control_image = self.preprocessor(
|
810 |
-
image=image,
|
811 |
-
image_resolution=image_resolution,
|
812 |
-
detect_resolution=preprocess_resolution,
|
813 |
-
)
|
814 |
-
self.load_controlnet_weight("NormalBae")
|
815 |
-
results = self.run_pipe(
|
816 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
817 |
-
negative_prompt=negative_prompt,
|
818 |
-
control_image=control_image,
|
819 |
-
num_images=num_images,
|
820 |
-
num_steps=num_steps,
|
821 |
-
guidance_scale=guidance_scale,
|
822 |
-
seed=seed,
|
823 |
-
)
|
824 |
-
return [control_image, *results]
|
825 |
-
|
826 |
-
@torch.inference_mode()
|
827 |
-
def process_lineart(
|
828 |
-
self,
|
829 |
-
image: np.ndarray,
|
830 |
-
prompt: str,
|
831 |
-
additional_prompt: str,
|
832 |
-
negative_prompt: str,
|
833 |
-
num_images: int,
|
834 |
-
image_resolution: int,
|
835 |
-
preprocess_resolution: int,
|
836 |
-
num_steps: int,
|
837 |
-
guidance_scale: float,
|
838 |
-
seed: int,
|
839 |
-
preprocessor_name: str,
|
840 |
-
) -> list[PIL.Image.Image]:
|
841 |
-
if image is None:
|
842 |
-
raise ValueError
|
843 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
844 |
-
raise ValueError
|
845 |
-
if num_images > MAX_NUM_IMAGES:
|
846 |
-
raise ValueError
|
847 |
-
|
848 |
-
if preprocessor_name in ["None", "None (anime)"]:
|
849 |
-
image = HWC3(image)
|
850 |
-
image = resize_image(image, resolution=image_resolution)
|
851 |
-
control_image = PIL.Image.fromarray(image)
|
852 |
-
elif preprocessor_name in ["Lineart", "Lineart coarse"]:
|
853 |
-
coarse = "coarse" in preprocessor_name
|
854 |
-
self.preprocessor.load("Lineart")
|
855 |
-
control_image = self.preprocessor(
|
856 |
-
image=image,
|
857 |
-
image_resolution=image_resolution,
|
858 |
-
detect_resolution=preprocess_resolution,
|
859 |
-
coarse=coarse,
|
860 |
-
)
|
861 |
-
elif preprocessor_name == "Lineart (anime)":
|
862 |
-
self.preprocessor.load("LineartAnime")
|
863 |
-
control_image = self.preprocessor(
|
864 |
-
image=image,
|
865 |
-
image_resolution=image_resolution,
|
866 |
-
detect_resolution=preprocess_resolution,
|
867 |
-
)
|
868 |
-
if "anime" in preprocessor_name:
|
869 |
-
self.load_controlnet_weight("lineart_anime")
|
870 |
-
else:
|
871 |
-
self.load_controlnet_weight("lineart")
|
872 |
-
results = self.run_pipe(
|
873 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
874 |
-
negative_prompt=negative_prompt,
|
875 |
-
control_image=control_image,
|
876 |
-
num_images=num_images,
|
877 |
-
num_steps=num_steps,
|
878 |
-
guidance_scale=guidance_scale,
|
879 |
-
seed=seed,
|
880 |
-
)
|
881 |
-
return [control_image, *results]
|
882 |
-
|
883 |
-
@torch.inference_mode()
|
884 |
-
def process_shuffle(
|
885 |
-
self,
|
886 |
-
image: np.ndarray,
|
887 |
-
prompt: str,
|
888 |
-
additional_prompt: str,
|
889 |
-
negative_prompt: str,
|
890 |
-
num_images: int,
|
891 |
-
image_resolution: int,
|
892 |
-
num_steps: int,
|
893 |
-
guidance_scale: float,
|
894 |
-
seed: int,
|
895 |
-
preprocessor_name: str,
|
896 |
-
) -> list[PIL.Image.Image]:
|
897 |
-
if image is None:
|
898 |
-
raise ValueError
|
899 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
900 |
-
raise ValueError
|
901 |
-
if num_images > MAX_NUM_IMAGES:
|
902 |
-
raise ValueError
|
903 |
-
|
904 |
-
if preprocessor_name == "None":
|
905 |
-
image = HWC3(image)
|
906 |
-
image = resize_image(image, resolution=image_resolution)
|
907 |
-
control_image = PIL.Image.fromarray(image)
|
908 |
-
else:
|
909 |
-
self.preprocessor.load(preprocessor_name)
|
910 |
-
control_image = self.preprocessor(
|
911 |
-
image=image,
|
912 |
-
image_resolution=image_resolution,
|
913 |
-
)
|
914 |
-
self.load_controlnet_weight("shuffle")
|
915 |
-
results = self.run_pipe(
|
916 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
917 |
-
negative_prompt=negative_prompt,
|
918 |
-
control_image=control_image,
|
919 |
-
num_images=num_images,
|
920 |
-
num_steps=num_steps,
|
921 |
-
guidance_scale=guidance_scale,
|
922 |
-
seed=seed,
|
923 |
-
)
|
924 |
-
return [control_image, *results]
|
925 |
-
|
926 |
-
@torch.inference_mode()
|
927 |
-
def process_ip2p(
|
928 |
-
self,
|
929 |
-
image: np.ndarray,
|
930 |
-
prompt: str,
|
931 |
-
additional_prompt: str,
|
932 |
-
negative_prompt: str,
|
933 |
-
num_images: int,
|
934 |
-
image_resolution: int,
|
935 |
-
num_steps: int,
|
936 |
-
guidance_scale: float,
|
937 |
-
seed: int,
|
938 |
-
) -> list[PIL.Image.Image]:
|
939 |
-
if image is None:
|
940 |
-
raise ValueError
|
941 |
-
if image_resolution > MAX_IMAGE_RESOLUTION:
|
942 |
-
raise ValueError
|
943 |
-
if num_images > MAX_NUM_IMAGES:
|
944 |
-
raise ValueError
|
945 |
-
|
946 |
-
image = HWC3(image)
|
947 |
-
image = resize_image(image, resolution=image_resolution)
|
948 |
-
control_image = PIL.Image.fromarray(image)
|
949 |
-
self.load_controlnet_weight("ip2p")
|
950 |
-
results = self.run_pipe(
|
951 |
-
prompt=self.get_prompt(prompt, additional_prompt),
|
952 |
-
negative_prompt=negative_prompt,
|
953 |
-
control_image=control_image,
|
954 |
-
num_images=num_images,
|
955 |
-
num_steps=num_steps,
|
956 |
-
guidance_scale=guidance_scale,
|
957 |
-
seed=seed,
|
958 |
-
)
|
959 |
-
return [control_image, *results]
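Note: every `process_*` handler removed above follows the same flow — validate inputs, build a control image with `self.preprocessor`, load the matching ControlNet weight, then call `self.run_pipe`. A minimal driver for the deleted `process_canny` signature is sketched below; the `model` instance and the example values are assumptions for illustration, not part of this commit.

    import numpy as np
    import PIL.Image

    # Hypothetical usage of the deleted handler; `model` is an instance of the
    # class defined earlier in model.py (its constructor is not in this hunk).
    image = np.asarray(PIL.Image.open("examples/bunny/frame_0001.png").convert("RGB"))
    outputs = model.process_canny(
        image=image,
        prompt="a ceramic bunny",
        additional_prompt="best quality, extremely detailed",
        negative_prompt="lowres, bad anatomy, worst quality",
        num_images=1,
        image_resolution=512,
        num_steps=20,
        guidance_scale=9.0,
        seed=0,
        low_threshold=100,
        high_threshold=200,
    )
    # outputs[0] is the Canny control image, outputs[1:] are the generated images.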
pre-requirements.txt
DELETED
@@ -1,9 +0,0 @@
-accelerate
-diffusers
-invisible_watermark
-torch
-torchvision
-transformers
-xformers
-controlnet-aux # for controlnet
-spaces # no need to specify here
preprocessor.py
DELETED
@@ -1,120 +0,0 @@
-import gc
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from collections.abc import Callable
-
-import numpy as np
-import PIL.Image
-import torch
-from controlnet_aux import (
-    CannyDetector,
-    ContentShuffleDetector,
-    HEDdetector,
-    LineartAnimeDetector,
-    LineartDetector,
-    MidasDetector,
-    MLSDdetector,
-    NormalBaeDetector,
-    OpenposeDetector,
-    PidiNetDetector,
-)
-from controlnet_aux.util import HWC3
-
-from cv_utils import resize_image
-from depth_estimator import DepthEstimator
-from image_segmentor import ImageSegmentor
-
-
-class Preprocessor:
-    MODEL_ID = "lllyasviel/Annotators"
-
-    def __init__(self) -> None:
-        self.model: Callable = None  # type: ignore
-        self.name = ""
-
-    def load(self, name: str) -> None:  # noqa: C901, PLR0912
-        if name == self.name:
-            return
-        if name == "HED":
-            self.model = HEDdetector.from_pretrained(self.MODEL_ID)
-        elif name == "Midas":
-            self.model = MidasDetector.from_pretrained(self.MODEL_ID)
-        elif name == "MLSD":
-            self.model = MLSDdetector.from_pretrained(self.MODEL_ID)
-        elif name == "Openpose":
-            self.model = OpenposeDetector.from_pretrained(self.MODEL_ID)
-        elif name == "PidiNet":
-            self.model = PidiNetDetector.from_pretrained(self.MODEL_ID)
-        elif name == "NormalBae":
-            self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID)
-        elif name == "Lineart":
-            self.model = LineartDetector.from_pretrained(self.MODEL_ID)
-        elif name == "LineartAnime":
-            self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
-        elif name == "Canny":
-            self.model = CannyDetector()
-        elif name == "ContentShuffle":
-            self.model = ContentShuffleDetector()
-        elif name == "DPT":
-            self.model = DepthEstimator()
-        elif name == "UPerNet":
-            self.model = ImageSegmentor()
-        elif name == 'texnet':
-            self.model = TexnetPreprocessor()
-        else:
-            raise ValueError
-        torch.cuda.empty_cache()
-        gc.collect()
-        self.name = name
-
-    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:  # noqa: ANN003
-        if self.name == "Canny":
-            if "detect_resolution" in kwargs:
-                detect_resolution = kwargs.pop("detect_resolution")
-                image = np.array(image)
-                image = HWC3(image)
-                image = resize_image(image, resolution=detect_resolution)
-            image = self.model(image, **kwargs)
-            return PIL.Image.fromarray(image)
-        if self.name == "Midas":
-            detect_resolution = kwargs.pop("detect_resolution", 512)
-            image_resolution = kwargs.pop("image_resolution", 512)
-            image = np.array(image)
-            image = HWC3(image)
-            image = resize_image(image, resolution=detect_resolution)
-            image = self.model(image, **kwargs)
-            image = HWC3(image)
-            image = resize_image(image, resolution=image_resolution)
-            return PIL.Image.fromarray(image)
-        return self.model(image, **kwargs)
-
-
-# https://github.com/huggingface/controlnet_aux/blob/master/src/controlnet_aux/canny/__init__.py
-class TexnetPreprocessor:
-    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, image_resolution=512, output_type=None, **kwargs):
-        if "img" in kwargs:
-            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
-            input_image = kwargs.pop("img")
-
-        if input_image is None:
-            raise ValueError("input_image must be defined.")
-
-        if not isinstance(input_image, np.ndarray):
-            input_image = np.array(input_image, dtype=np.uint8)
-            output_type = output_type or "pil"
-        else:
-            output_type = output_type or "np"
-
-        input_image = HWC3(input_image)
-        input_image = resize_image(input_image, image_resolution)
-        H, W, C = input_image.shape
-
-        # detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
-        output_image = input_image.copy()
-
-        if output_type == "pil":
-            # detected_map = Image.fromarray(detected_map)
-            output_image = PIL.Image.fromarray(output_image)
-
-        return output_image
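Note: the deleted model.py drives this class as `self.preprocessor.load(...)` followed by a call. A standalone sketch against the module as it existed before this commit (the example image path is an assumption):

    import PIL.Image
    from preprocessor import Preprocessor

    pre = Preprocessor()
    pre.load("Canny")  # swaps in a CannyDetector and frees the previously loaded model
    control = pre(
        PIL.Image.open("examples/bunny/frame_0001.png").convert("RGB"),
        low_threshold=100,
        high_threshold=200,
        detect_resolution=512,
    )  # returns a PIL.Image edge map used as the ControlNet condition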
push_dataset.py
DELETED
@@ -1,9 +0,0 @@
-from huggingface_hub import HfApi
-api = HfApi()
-
-api.upload_folder(
-    folder_path="./examples",
-    repo_id="jingyangcarl/matgen",
-    repo_type="space",
-    path_in_repo="examples", # Upload to a specific folder
-)
requirements.txt
DELETED
@@ -1,9 +0,0 @@
-torch
-torchvision
-pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@stable
-trimesh
-xatlas
-scikit-learn
-opencv-python
-matplotlib
-omegaconf
rgb2x/generate_blend.py
DELETED
@@ -1,142 +0,0 @@
-import bpy
-import sys
-import os
-
-def create_tex_node(nodes, img_path, label, color_space, location):
-    img = bpy.data.images.load(img_path)
-    tex = nodes.new(type='ShaderNodeTexImage')
-    tex.image = img
-    tex.label = label
-    tex.location = location
-    tex.image.colorspace_settings.name = color_space
-    return tex
-
-def setup_environment_lighting(hdri_path):
-    if not bpy.data.worlds:
-        bpy.data.worlds.new(name="World")
-    if bpy.context.scene.world is None:
-        bpy.context.scene.world = bpy.data.worlds[0]
-    world = bpy.context.scene.world
-
-    world.use_nodes = True
-    nodes = world.node_tree.nodes
-    links = world.node_tree.links
-    nodes.clear()
-
-    env_tex = nodes.new(type="ShaderNodeTexEnvironment")
-    env_tex.image = bpy.data.images.load(hdri_path)
-    env_tex.location = (-300, 0)
-
-    bg = nodes.new(type="ShaderNodeBackground")
-    bg.location = (0, 0)
-
-    output = nodes.new(type="ShaderNodeOutputWorld")
-    output.location = (300, 0)
-
-    links.new(env_tex.outputs["Color"], bg.inputs["Color"])
-    links.new(bg.outputs["Background"], output.inputs["Surface"])
-
-def setup_gpu_rendering():
-    bpy.context.scene.render.engine = 'CYCLES'
-    prefs = bpy.context.preferences
-    cprefs = prefs.addons['cycles'].preferences
-
-    # Choose backend depending on GPU type: 'CUDA', 'OPTIX', 'HIP', 'METAL'
-    cprefs.compute_device_type = 'CUDA'
-    bpy.context.scene.cycles.device = 'GPU'
-
-def generate_blend(obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend):
-    # Reset scene
-    bpy.ops.wm.read_factory_settings(use_empty=True)
-
-    # Import OBJ
-    bpy.ops.import_scene.obj(filepath=obj_path)
-    obj = bpy.context.selected_objects[0]
-
-    # Create material
-    mat = bpy.data.materials.new(name="BRDF_Material")
-    mat.use_nodes = True
-    nodes = mat.node_tree.nodes
-    links = mat.node_tree.links
-    nodes.clear()
-
-    output = nodes.new(type='ShaderNodeOutputMaterial')
-    output.location = (400, 0)
-
-    principled = nodes.new(type='ShaderNodeBsdfPrincipled')
-    principled.location = (100, 0)
-    links.new(principled.outputs['BSDF'], output.inputs['Surface'])
-
-    # Base Color
-    base_color = create_tex_node(nodes, base_color_path, "Base Color", 'sRGB', (-600, 200))
-    links.new(base_color.outputs['Color'], principled.inputs['Base Color'])
-
-    # Roughness
-    rough = create_tex_node(nodes, roughness_path, "Roughness", 'Non-Color', (-600, 0))
-    links.new(rough.outputs['Color'], principled.inputs['Roughness'])
-
-    # Metallic
-    metal = create_tex_node(nodes, metallic_path, "Metallic", 'Non-Color', (-600, -200))
-    links.new(metal.outputs['Color'], principled.inputs['Metallic'])
-
-    # Normal Map
-    normal_tex = create_tex_node(nodes, normal_map_path, "Normal Map", 'Non-Color', (-800, -400))
-    normal_map = nodes.new(type='ShaderNodeNormalMap')
-    normal_map.location = (-400, -400)
-    links.new(normal_tex.outputs['Color'], normal_map.inputs['Color'])
-    links.new(normal_map.outputs['Normal'], principled.inputs['Normal'])
-
-    # Assign material
-    if obj.data.materials:
-        obj.data.materials[0] = mat
-    else:
-        obj.data.materials.append(mat)
-
-    # Global Illumination using Blender's default forest HDRI
-    blender_data_path = bpy.utils.resource_path('LOCAL')
-    forest_hdri_path = os.path.join(blender_data_path, "datafiles", "studiolights", "world", "forest.exr")
-    print(f"Using HDRI: {forest_hdri_path}")
-    setup_environment_lighting(forest_hdri_path)
-
-    # GPU rendering setup
-    setup_gpu_rendering()
-
-    # Pack textures into .blend
-    bpy.ops.file.pack_all()
-
-    # Set the 3D View to Rendered mode and focus on object
-    for area in bpy.context.screen.areas:
-        if area.type == 'VIEW_3D':
-            for space in area.spaces:
-                if space.type == 'VIEW_3D':
-                    space.shading.type = 'RENDERED' # Set viewport shading to Rendered
-            for region in area.regions:
-                if region.type == 'WINDOW':
-                    override = {'area': area, 'region': region, 'scene': bpy.context.scene}
-                    bpy.ops.view3d.view_all(override, center=True)
-
-        elif area.type == 'NODE_EDITOR':
-            for space in area.spaces:
-                if space.type == 'NODE_EDITOR':
-                    space.tree_type = 'ShaderNodeTree' # Switch to Shader Editor
-                    space.shader_type = 'OBJECT'
-
-    # Optional: Switch active workspace to Shading (if it exists)
-    for workspace in bpy.data.workspaces:
-        if workspace.name == 'Shading':
-            bpy.context.window.workspace = workspace
-            break
-
-    # Save the .blend file
-    bpy.ops.wm.save_as_mainfile(filepath=output_blend)
-    print(f"✅ Saved .blend file with BRDF, HDRI, GPU: {output_blend}")
-
-if __name__ == "__main__":
-    argv = sys.argv
-    argv = argv[argv.index("--") + 1:] # Only use args after "--"
-
-    if len(argv) != 6:
-        print("Usage:\n blender --background --python generate_blend.py -- obj base_color normal roughness metallic output.blend")
-        sys.exit(1)
-
-    generate_blend(*argv)
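Note: the deleted model.py calls a `run_blend_generation(...)` helper with a `blender_path` and this script, but that helper's body is not part of this diff. Based on the usage string above, it presumably shells out to Blender roughly as in this hypothetical sketch:

    import subprocess

    def run_blend_generation(blender_path, generate_script_path, obj_path, base_color_path,
                             normal_map_path, roughness_path, metallic_path, output_blend):
        # Blender forwards only the CLI arguments after "--" to the Python script.
        subprocess.run(
            [
                blender_path, "--background", "--python", generate_script_path, "--",
                obj_path, base_color_path, normal_map_path, roughness_path,
                metallic_path, output_blend,
            ],
            check=True,
        )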
rgb2x/gradio_demo_rgb2x.py
DELETED
@@ -1,157 +0,0 @@
-import os
-
-os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-
-import gradio as gr
-import torch
-import torchvision
-from diffusers import DDIMScheduler
-from load_image import load_exr_image, load_ldr_image
-from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
-
-current_directory = os.path.dirname(os.path.abspath(__file__))
-
-
-def get_rgb2x_demo():
-    # Load pipeline
-    pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
-        "zheng95z/rgb-to-x",
-        torch_dtype=torch.float16,
-        cache_dir=os.path.join(current_directory, "model_cache"),
-    ).to("cuda")
-    pipe.scheduler = DDIMScheduler.from_config(
-        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
-    )
-    pipe.set_progress_bar_config(disable=True)
-    pipe.to("cuda")
-
-    # Augmentation
-    def callback(
-        photo,
-        seed,
-        inference_step,
-        num_samples,
-    ):
-        generator = torch.Generator(device="cuda").manual_seed(seed)
-
-        if photo.name.endswith(".exr"):
-            photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
-        elif (
-            photo.name.endswith(".png")
-            or photo.name.endswith(".jpg")
-            or photo.name.endswith(".jpeg")
-        ):
-            photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
-
-        # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
-        old_height = photo.shape[1]
-        old_width = photo.shape[2]
-        new_height = old_height
-        new_width = old_width
-        radio = old_height / old_width
-        max_side = 1000
-        if old_height > old_width:
-            new_height = max_side
-            new_width = int(new_height / radio)
-        else:
-            new_width = max_side
-            new_height = int(new_width * radio)
-
-        if new_width % 8 != 0 or new_height % 8 != 0:
-            new_width = new_width // 8 * 8
-            new_height = new_height // 8 * 8
-
-        photo = torchvision.transforms.Resize((new_height, new_width))(photo)
-
-        required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
-        prompts = {
-            "albedo": "Albedo (diffuse basecolor)",
-            "normal": "Camera-space Normal",
-            "roughness": "Roughness",
-            "metallic": "Metallicness",
-            "irradiance": "Irradiance (diffuse lighting)",
-        }
-
-        return_list = []
-        for i in range(num_samples):
-            for aov_name in required_aovs:
-                prompt = prompts[aov_name]
-                generated_image = pipe(
-                    prompt=prompt,
-                    photo=photo,
-                    num_inference_steps=inference_step,
-                    height=new_height,
-                    width=new_width,
-                    generator=generator,
-                    required_aovs=[aov_name],
-                ).images[0][0]
-
-                generated_image = torchvision.transforms.Resize(
-                    (old_height, old_width)
-                )(generated_image)
-
-                generated_image = (generated_image, f"Generated {aov_name} {i}")
-                return_list.append(generated_image)
-
-        return return_list
-
-    block = gr.Blocks()
-    with block:
-        with gr.Row():
-            gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
-        with gr.Row():
-            # Input side
-            with gr.Column():
-                gr.Markdown("### Given Image")
-                photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
-
-                gr.Markdown("### Parameters")
-                run_button = gr.Button(value="Run")
-                with gr.Accordion("Advanced options", open=False):
-                    seed = gr.Slider(
-                        label="Seed",
-                        minimum=-1,
-                        maximum=2147483647,
-                        step=1,
-                        randomize=True,
-                    )
-                    inference_step = gr.Slider(
-                        label="Inference Step",
-                        minimum=1,
-                        maximum=100,
-                        step=1,
-                        value=50,
-                    )
-                    num_samples = gr.Slider(
-                        label="Samples",
-                        minimum=1,
-                        maximum=100,
-                        step=1,
-                        value=1,
-                    )
-
-            # Output side
-            with gr.Column():
-                gr.Markdown("### Output Gallery")
-                result_gallery = gr.Gallery(
-                    label="Output",
-                    show_label=False,
-                    elem_id="gallery",
-                    columns=2,
-                )
-
-        inputs = [
-            photo,
-            seed,
-            inference_step,
-            num_samples,
-        ]
-        run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)
-
-    return block
-
-
-if __name__ == "__main__":
-    demo = get_rgb2x_demo()
-    demo.queue(max_size=1)
-    demo.launch()
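Note: `get_rgb2x_demo()` returns a `gr.Blocks`, so besides the standalone launch above it can also be mounted inside a larger app. A minimal sketch (the tab layout is an assumption, not part of this commit):

    import gradio as gr
    from rgb2x.gradio_demo_rgb2x import get_rgb2x_demo

    # Hypothetical composition: expose the RGB->X demo as one tab of a bigger Space.
    app = gr.TabbedInterface([get_rgb2x_demo()], ["RGB -> X"])
    app.queue(max_size=1)
    app.launch()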
rgb2x/load_image.py
DELETED
@@ -1,119 +0,0 @@
-import os
-
-import cv2
-import torch
-
-os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-import numpy as np
-
-
-def convert_rgb_2_XYZ(rgb):
-    # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
-    # rgb: (h, w, 3)
-    # XYZ: (h, w, 3)
-    XYZ = torch.ones_like(rgb)
-    XYZ[:, :, 0] = (
-        0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
-    )
-    XYZ[:, :, 1] = (
-        0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
-    )
-    XYZ[:, :, 2] = (
-        0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
-    )
-    return XYZ
-
-
-def convert_XYZ_2_Yxy(XYZ):
-    # XYZ: (h, w, 3)
-    # Yxy: (h, w, 3)
-    Yxy = torch.ones_like(XYZ)
-    Yxy[:, :, 0] = XYZ[:, :, 1]
-    sum = torch.sum(XYZ, dim=2)
-    inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
-    Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
-    Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
-    return Yxy
-
-
-def convert_rgb_2_Yxy(rgb):
-    # rgb: (h, w, 3)
-    # Yxy: (h, w, 3)
-    return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))
-
-
-def convert_XYZ_2_rgb(XYZ):
-    # XYZ: (h, w, 3)
-    # rgb: (h, w, 3)
-    rgb = torch.ones_like(XYZ)
-    rgb[:, :, 0] = (
-        3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
-    )
-    rgb[:, :, 1] = (
-        -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
-    )
-    rgb[:, :, 2] = (
-        0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
-    )
-    return rgb
-
-
-def convert_Yxy_2_XYZ(Yxy):
-    # Yxy: (h, w, 3)
-    # XYZ: (h, w, 3)
-    XYZ = torch.ones_like(Yxy)
-    XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
-    XYZ[:, :, 1] = Yxy[:, :, 0]
-    XYZ[:, :, 2] = (
-        (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
-        / torch.clamp(Yxy[:, :, 2], min=1e-4)
-        * Yxy[:, :, 0]
-    )
-    return XYZ
-
-
-def convert_Yxy_2_rgb(Yxy):
-    # Yxy: (h, w, 3)
-    # rgb: (h, w, 3)
-    return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))
-
-
-def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
-    # Load png or jpg image
-    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
-    image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c)
-    image[~torch.isfinite(image)] = 0
-    if from_srgb:
-        # Convert from sRGB to linear RGB
-        image = image**2.2
-    if clamp:
-        image = torch.clamp(image, min=0.0, max=1.0)
-    if normalize:
-        # Normalize to [-1, 1]
-        image = image * 2.0 - 1.0
-        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
-    return image.permute(2, 0, 1) # returns (c, h, w)
-
-
-def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
-    image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
-    image = torch.from_numpy(image.astype("float32")) # (h, w, c)
-    image[~torch.isfinite(image)] = 0
-    if tonemaping:
-        # Exposure adjuestment
-        image_Yxy = convert_rgb_2_Yxy(image)
-        lum = (
-            image[:, :, 0:1] * 0.2125
-            + image[:, :, 1:2] * 0.7154
-            + image[:, :, 2:3] * 0.0721
-        )
-        lum = torch.log(torch.clamp(lum, min=1e-6))
-        lum_mean = torch.exp(torch.mean(lum))
-        lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
-        image_Yxy[:, :, 0:1] = lp
-        image = convert_Yxy_2_rgb(image_Yxy)
-    if clamp:
-        image = torch.clamp(image, min=0.0, max=1.0)
-    if normalize:
-        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
-    return image.permute(2, 0, 1) # returns (c, h, w)
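Note: these loaders are what gradio_demo_rgb2x.py above feeds to the pipeline. A quick standalone sketch (file paths are assumptions):

    from rgb2x.load_image import load_exr_image, load_ldr_image

    # LDR input: sRGB PNG/JPG, converted to approximately linear RGB via gamma 2.2, shape (3, H, W)
    photo = load_ldr_image("input.png", from_srgb=True).to("cuda")
    # HDR input: EXR with log-average exposure adjustment ("tonemaping"), then clamped to [0, 1]
    hdr = load_exr_image("input.exr", tonemaping=True, clamp=True).to("cuda")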
rgb2x/pipeline_rgb2x.py
DELETED
@@ -1,821 +0,0 @@
-import inspect
-from dataclasses import dataclass
-from typing import Callable, List, Optional, Union
-
-import numpy as np
-import PIL
-import torch
-from diffusers.configuration_utils import register_to_config
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import (
-    LoraLoaderMixin,
-    TextualInversionLoaderMixin,
-)
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
-    rescale_noise_cfg,
-)
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
-    CONFIG_NAME,
-    BaseOutput,
-    deprecate,
-    logging,
-)
-from diffusers.utils.torch_utils import randn_tensor
-from transformers import CLIPTextModel, CLIPTokenizer
-
-logger = logging.get_logger(__name__)
-
-
-class VaeImageProcrssorAOV(VaeImageProcessor):
-    """
-    Image processor for VAE AOV.
-
-    Args:
-        do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
-        vae_scale_factor (`int`, *optional*, defaults to `8`):
-            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-        resample (`str`, *optional*, defaults to `lanczos`):
-            Resampling filter to use when resizing the image.
-        do_normalize (`bool`, *optional*, defaults to `True`):
-            Whether to normalize the image to [-1,1].
-    """
-
-    config_name = CONFIG_NAME
-
-    @register_to_config
-    def __init__(
-        self,
-        do_resize: bool = True,
-        vae_scale_factor: int = 8,
-        resample: str = "lanczos",
-        do_normalize: bool = True,
-    ):
-        super().__init__()
-
-    def postprocess(
-        self,
-        image: torch.FloatTensor,
-        output_type: str = "pil",
-        do_denormalize: Optional[List[bool]] = None,
-        do_gamma_correction: bool = True,
-    ):
-        if not isinstance(image, torch.Tensor):
-            raise ValueError(
-                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
-            )
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate(
-                "Unsupported output_type",
-                "1.0.0",
-                deprecation_message,
-                standard_warn=False,
-            )
-            output_type = "np"
-
-        if output_type == "latent":
-            return image
-
-        if do_denormalize is None:
-            do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-        image = torch.stack(
-            [
-                self.denormalize(image[i]) if do_denormalize[i] else image[i]
-                for i in range(image.shape[0])
-            ]
-        )
-
-        # Gamma correction
-        if do_gamma_correction:
-            image = torch.pow(image, 1.0 / 2.2)
-
-        if output_type == "pt":
-            return image
-
-        image = self.pt_to_numpy(image)
-
-        if output_type == "np":
-            return image
-
-        if output_type == "pil":
-            return self.numpy_to_pil(image)
-
-    def preprocess_normal(
-        self,
-        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-    ) -> torch.Tensor:
-        image = torch.stack([image], axis=0)
-        return image
-
-
-@dataclass
-class StableDiffusionAOVPipelineOutput(BaseOutput):
-    """
-    Output class for Stable Diffusion AOV pipelines.
-
-    Args:
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
-            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
-            num_channels)`.
-        nsfw_content_detected (`List[bool]`)
-            List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
-            `None` if safety checking could not be performed.
-    """
-
-    images: Union[List[PIL.Image.Image], np.ndarray]
-
-
-class StableDiffusionAOVMatEstPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
-):
-    r"""
-    Pipeline for AOVs.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
-    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-    The pipeline also inherits the following loading methods:
-        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
-
-    Args:
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
-        text_encoder ([`~transformers.CLIPTextModel`]):
-            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
-        tokenizer ([`~transformers.CLIPTokenizer`]):
-            A `CLIPTokenizer` to tokenize text.
-        unet ([`UNet2DConditionModel`]):
-            A `UNet2DConditionModel` to denoise the encoded image latents.
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
-    """
-
-    def __init__(
-        self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
-        scheduler: KarrasDiffusionSchedulers,
-    ):
-        super().__init__()
-
-        self.register_modules(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-        )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcrssorAOV(
-            vae_scale_factor=self.vae_scale_factor
-        )
-        self.register_to_config()
-
-    def _encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-    ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
-            negative_ prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-        """
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            # textual inversion: procecss multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(
-                prompt, padding="longest", return_tensors="pt"
-            ).input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[
-                -1
-            ] and not torch.equal(text_input_ids, untruncated_ids):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-                )
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-
-            if (
-                hasattr(self.text_encoder.config, "use_attention_mask")
-                and self.text_encoder.config.use_attention_mask
-            ):
-                attention_mask = text_inputs.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            prompt_embeds = self.text_encoder(
-                text_input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            prompt_embeds = prompt_embeds[0]
-
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(
-            bs_embed * num_images_per_prompt, seq_len, -1
-        )
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            # textual inversion: procecss multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-
-            if (
-                hasattr(self.text_encoder.config, "use_attention_mask")
-                and self.text_encoder.config.use_attention_mask
-            ):
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            negative_prompt_embeds = negative_prompt_embeds.to(
-                dtype=self.text_encoder.dtype, device=device
-            )
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(
-                1, num_images_per_prompt, 1
-            )
-            negative_prompt_embeds = negative_prompt_embeds.view(
-                batch_size * num_images_per_prompt, seq_len, -1
-            )
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
-            prompt_embeds = torch.cat(
-                [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
-            )
-
-        return prompt_embeds
-
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
-    def check_inputs(
-        self,
-        prompt,
-        callback_steps,
-        negative_prompt=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-    ):
-        if (callback_steps is None) or (
-            callback_steps is not None
-            and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prompt_embeds is None:
-            raise ValueError(
-                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-            )
-        elif prompt is not None and (
-            not isinstance(prompt, str) and not isinstance(prompt, list)
-        ):
-            raise ValueError(
-                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
-            )
-
-        if negative_prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
-
-        if prompt_embeds is not None and negative_prompt_embeds is not None:
-            if prompt_embeds.shape != negative_prompt_embeds.shape:
-                raise ValueError(
-                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
-                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
-                    f" {negative_prompt_embeds.shape}."
-                )
-
-    def prepare_latents(
-        self,
-        batch_size,
-        num_channels_latents,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
-    ):
-        shape = (
-            batch_size,
-            num_channels_latents,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch
|
442 |
-
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
443 |
-
)
|
444 |
-
|
445 |
-
if latents is None:
|
446 |
-
latents = randn_tensor(
|
447 |
-
shape, generator=generator, device=device, dtype=dtype
|
448 |
-
)
|
449 |
-
else:
|
450 |
-
latents = latents.to(device)
|
451 |
-
|
452 |
-
# scale the initial noise by the standard deviation required by the scheduler
|
453 |
-
latents = latents * self.scheduler.init_noise_sigma
|
454 |
-
return latents
|
455 |
-
|
456 |
-
def prepare_image_latents(
|
457 |
-
self,
|
458 |
-
image,
|
459 |
-
batch_size,
|
460 |
-
num_images_per_prompt,
|
461 |
-
dtype,
|
462 |
-
device,
|
463 |
-
do_classifier_free_guidance,
|
464 |
-
generator=None,
|
465 |
-
):
|
466 |
-
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
467 |
-
raise ValueError(
|
468 |
-
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
469 |
-
)
|
470 |
-
|
471 |
-
image = image.to(device=device, dtype=dtype)
|
472 |
-
|
473 |
-
batch_size = batch_size * num_images_per_prompt
|
474 |
-
|
475 |
-
if image.shape[1] == 4:
|
476 |
-
image_latents = image
|
477 |
-
else:
|
478 |
-
if isinstance(generator, list) and len(generator) != batch_size:
|
479 |
-
raise ValueError(
|
480 |
-
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
481 |
-
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
482 |
-
)
|
483 |
-
|
484 |
-
if isinstance(generator, list):
|
485 |
-
image_latents = [
|
486 |
-
self.vae.encode(image[i : i + 1]).latent_dist.mode()
|
487 |
-
for i in range(batch_size)
|
488 |
-
]
|
489 |
-
image_latents = torch.cat(image_latents, dim=0)
|
490 |
-
else:
|
491 |
-
image_latents = self.vae.encode(image).latent_dist.mode()
|
492 |
-
|
493 |
-
if (
|
494 |
-
batch_size > image_latents.shape[0]
|
495 |
-
and batch_size % image_latents.shape[0] == 0
|
496 |
-
):
|
497 |
-
# expand image_latents for batch_size
|
498 |
-
deprecation_message = (
|
499 |
-
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
|
500 |
-
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
501 |
-
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
502 |
-
" your script to pass as many initial images as text prompts to suppress this warning."
|
503 |
-
)
|
504 |
-
deprecate(
|
505 |
-
"len(prompt) != len(image)",
|
506 |
-
"1.0.0",
|
507 |
-
deprecation_message,
|
508 |
-
standard_warn=False,
|
509 |
-
)
|
510 |
-
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
511 |
-
image_latents = torch.cat(
|
512 |
-
[image_latents] * additional_image_per_prompt, dim=0
|
513 |
-
)
|
514 |
-
elif (
|
515 |
-
batch_size > image_latents.shape[0]
|
516 |
-
and batch_size % image_latents.shape[0] != 0
|
517 |
-
):
|
518 |
-
raise ValueError(
|
519 |
-
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
520 |
-
)
|
521 |
-
else:
|
522 |
-
image_latents = torch.cat([image_latents], dim=0)
|
523 |
-
|
524 |
-
if do_classifier_free_guidance:
|
525 |
-
uncond_image_latents = torch.zeros_like(image_latents)
|
526 |
-
image_latents = torch.cat(
|
527 |
-
[image_latents, image_latents, uncond_image_latents], dim=0
|
528 |
-
)
|
529 |
-
|
530 |
-
return image_latents
|
531 |
-
|
532 |
-
@torch.no_grad()
|
533 |
-
def __call__(
|
534 |
-
self,
|
535 |
-
prompt: Union[str, List[str]] = None,
|
536 |
-
photo: Union[
|
537 |
-
torch.FloatTensor,
|
538 |
-
PIL.Image.Image,
|
539 |
-
np.ndarray,
|
540 |
-
List[torch.FloatTensor],
|
541 |
-
List[PIL.Image.Image],
|
542 |
-
List[np.ndarray],
|
543 |
-
] = None,
|
544 |
-
height: Optional[int] = None,
|
545 |
-
width: Optional[int] = None,
|
546 |
-
num_inference_steps: int = 100,
|
547 |
-
required_aovs: List[str] = ["albedo"],
|
548 |
-
negative_prompt: Optional[Union[str, List[str]]] = None,
|
549 |
-
num_images_per_prompt: Optional[int] = 1,
|
550 |
-
use_default_scaling_factor: Optional[bool] = False,
|
551 |
-
guidance_scale: float = 0.0,
|
552 |
-
image_guidance_scale: float = 0.0,
|
553 |
-
guidance_rescale: float = 0.0,
|
554 |
-
eta: float = 0.0,
|
555 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
556 |
-
latents: Optional[torch.FloatTensor] = None,
|
557 |
-
prompt_embeds: Optional[torch.FloatTensor] = None,
|
558 |
-
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
559 |
-
output_type: Optional[str] = "pil",
|
560 |
-
return_dict: bool = True,
|
561 |
-
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
562 |
-
callback_steps: int = 1,
|
563 |
-
):
|
564 |
-
r"""
|
565 |
-
The call function to the pipeline for generation.
|
566 |
-
|
567 |
-
Args:
|
568 |
-
prompt (`str` or `List[str]`, *optional*):
|
569 |
-
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
570 |
-
image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
571 |
-
`Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
|
572 |
-
image latents as `image`, but if passing latents directly it is not encoded again.
|
573 |
-
num_inference_steps (`int`, *optional*, defaults to 100):
|
574 |
-
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
575 |
-
expense of slower inference.
|
576 |
-
guidance_scale (`float`, *optional*, defaults to 7.5):
|
577 |
-
A higher guidance scale value encourages the model to generate images closely linked to the text
|
578 |
-
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
579 |
-
image_guidance_scale (`float`, *optional*, defaults to 1.5):
|
580 |
-
Push the generated image towards the inital `image`. Image guidance scale is enabled by setting
|
581 |
-
`image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
|
582 |
-
linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
|
583 |
-
value of at least `1`.
|
584 |
-
negative_prompt (`str` or `List[str]`, *optional*):
|
585 |
-
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
586 |
-
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
587 |
-
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
588 |
-
The number of images to generate per prompt.
|
589 |
-
eta (`float`, *optional*, defaults to 0.0):
|
590 |
-
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
591 |
-
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
592 |
-
generator (`torch.Generator`, *optional*):
|
593 |
-
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
594 |
-
generation deterministic.
|
595 |
-
latents (`torch.FloatTensor`, *optional*):
|
596 |
-
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
597 |
-
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
598 |
-
tensor is generated by sampling using the supplied random `generator`.
|
599 |
-
prompt_embeds (`torch.FloatTensor`, *optional*):
|
600 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
601 |
-
provided, text embeddings are generated from the `prompt` input argument.
|
602 |
-
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
603 |
-
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
604 |
-
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
605 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
606 |
-
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
607 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
608 |
-
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
609 |
-
plain tuple.
|
610 |
-
callback (`Callable`, *optional*):
|
611 |
-
A function that calls every `callback_steps` steps during inference. The function is called with the
|
612 |
-
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
613 |
-
callback_steps (`int`, *optional*, defaults to 1):
|
614 |
-
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
615 |
-
every step.
|
616 |
-
|
617 |
-
Examples:
|
618 |
-
|
619 |
-
```py
|
620 |
-
>>> import PIL
|
621 |
-
>>> import requests
|
622 |
-
>>> import torch
|
623 |
-
>>> from io import BytesIO
|
624 |
-
|
625 |
-
>>> from diffusers import StableDiffusionInstructPix2PixPipeline
|
626 |
-
|
627 |
-
|
628 |
-
>>> def download_image(url):
|
629 |
-
... response = requests.get(url)
|
630 |
-
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
631 |
-
|
632 |
-
|
633 |
-
>>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
|
634 |
-
|
635 |
-
>>> image = download_image(img_url).resize((512, 512))
|
636 |
-
|
637 |
-
>>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
638 |
-
... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
|
639 |
-
... )
|
640 |
-
>>> pipe = pipe.to("cuda")
|
641 |
-
|
642 |
-
>>> prompt = "make the mountains snowy"
|
643 |
-
>>> image = pipe(prompt=prompt, image=image).images[0]
|
644 |
-
```
|
645 |
-
|
646 |
-
Returns:
|
647 |
-
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
648 |
-
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
649 |
-
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
650 |
-
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
651 |
-
"not-safe-for-work" (nsfw) content.
|
652 |
-
"""
|
653 |
-
# 0. Check inputs
|
654 |
-
self.check_inputs(
|
655 |
-
prompt,
|
656 |
-
callback_steps,
|
657 |
-
negative_prompt,
|
658 |
-
prompt_embeds,
|
659 |
-
negative_prompt_embeds,
|
660 |
-
)
|
661 |
-
|
662 |
-
# 1. Define call parameters
|
663 |
-
if prompt is not None and isinstance(prompt, str):
|
664 |
-
batch_size = 1
|
665 |
-
elif prompt is not None and isinstance(prompt, list):
|
666 |
-
batch_size = len(prompt)
|
667 |
-
else:
|
668 |
-
batch_size = prompt_embeds.shape[0]
|
669 |
-
|
670 |
-
device = self._execution_device
|
671 |
-
do_classifier_free_guidance = (
|
672 |
-
guidance_scale > 1.0 and image_guidance_scale >= 1.0
|
673 |
-
)
|
674 |
-
# check if scheduler is in sigmas space
|
675 |
-
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
|
676 |
-
|
677 |
-
# 2. Encode input prompt
|
678 |
-
prompt_embeds = self._encode_prompt(
|
679 |
-
prompt,
|
680 |
-
device,
|
681 |
-
num_images_per_prompt,
|
682 |
-
do_classifier_free_guidance,
|
683 |
-
negative_prompt,
|
684 |
-
prompt_embeds=prompt_embeds,
|
685 |
-
negative_prompt_embeds=negative_prompt_embeds,
|
686 |
-
)
|
687 |
-
|
688 |
-
# 3. Preprocess image
|
689 |
-
# Normalize image to [-1,1]
|
690 |
-
preprocessed_photo = self.image_processor.preprocess(photo)
|
691 |
-
|
692 |
-
# 4. set timesteps
|
693 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
694 |
-
timesteps = self.scheduler.timesteps
|
695 |
-
|
696 |
-
# 5. Prepare Image latents
|
697 |
-
image_latents = self.prepare_image_latents(
|
698 |
-
preprocessed_photo,
|
699 |
-
batch_size,
|
700 |
-
num_images_per_prompt,
|
701 |
-
prompt_embeds.dtype,
|
702 |
-
device,
|
703 |
-
do_classifier_free_guidance,
|
704 |
-
generator,
|
705 |
-
)
|
706 |
-
image_latents = image_latents * self.vae.config.scaling_factor
|
707 |
-
|
708 |
-
height, width = image_latents.shape[-2:]
|
709 |
-
height = height * self.vae_scale_factor
|
710 |
-
width = width * self.vae_scale_factor
|
711 |
-
|
712 |
-
# 6. Prepare latent variables
|
713 |
-
num_channels_latents = self.unet.config.out_channels
|
714 |
-
latents = self.prepare_latents(
|
715 |
-
batch_size * num_images_per_prompt,
|
716 |
-
num_channels_latents,
|
717 |
-
height,
|
718 |
-
width,
|
719 |
-
prompt_embeds.dtype,
|
720 |
-
device,
|
721 |
-
generator,
|
722 |
-
latents,
|
723 |
-
)
|
724 |
-
|
725 |
-
# 7. Check that shapes of latents and image match the UNet channels
|
726 |
-
num_channels_image = image_latents.shape[1]
|
727 |
-
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
|
728 |
-
raise ValueError(
|
729 |
-
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
730 |
-
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
731 |
-
f" `num_channels_image`: {num_channels_image} "
|
732 |
-
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
733 |
-
" `pipeline.unet` or your `image` input."
|
734 |
-
)
|
735 |
-
|
736 |
-
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
737 |
-
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
738 |
-
|
739 |
-
# 9. Denoising loop
|
740 |
-
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
741 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
742 |
-
for i, t in enumerate(timesteps):
|
743 |
-
# Expand the latents if we are doing classifier free guidance.
|
744 |
-
# The latents are expanded 3 times because for pix2pix the guidance\
|
745 |
-
# is applied for both the text and the input image.
|
746 |
-
latent_model_input = (
|
747 |
-
torch.cat([latents] * 3) if do_classifier_free_guidance else latents
|
748 |
-
)
|
749 |
-
|
750 |
-
# concat latents, image_latents in the channel dimension
|
751 |
-
scaled_latent_model_input = self.scheduler.scale_model_input(
|
752 |
-
latent_model_input, t
|
753 |
-
)
|
754 |
-
scaled_latent_model_input = torch.cat(
|
755 |
-
[scaled_latent_model_input, image_latents], dim=1
|
756 |
-
)
|
757 |
-
|
758 |
-
# predict the noise residual
|
759 |
-
noise_pred = self.unet(
|
760 |
-
scaled_latent_model_input,
|
761 |
-
t,
|
762 |
-
encoder_hidden_states=prompt_embeds,
|
763 |
-
return_dict=False,
|
764 |
-
)[0]
|
765 |
-
|
766 |
-
# perform guidance
|
767 |
-
if do_classifier_free_guidance:
|
768 |
-
(
|
769 |
-
noise_pred_text,
|
770 |
-
noise_pred_image,
|
771 |
-
noise_pred_uncond,
|
772 |
-
) = noise_pred.chunk(3)
|
773 |
-
noise_pred = (
|
774 |
-
noise_pred_uncond
|
775 |
-
+ guidance_scale * (noise_pred_text - noise_pred_image)
|
776 |
-
+ image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
777 |
-
)
|
778 |
-
|
779 |
-
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
780 |
-
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
781 |
-
noise_pred = rescale_noise_cfg(
|
782 |
-
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
783 |
-
)
|
784 |
-
|
785 |
-
# compute the previous noisy sample x_t -> x_t-1
|
786 |
-
latents = self.scheduler.step(
|
787 |
-
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
788 |
-
)[0]
|
789 |
-
|
790 |
-
# call the callback, if provided
|
791 |
-
if i == len(timesteps) - 1 or (
|
792 |
-
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
793 |
-
):
|
794 |
-
progress_bar.update()
|
795 |
-
if callback is not None and i % callback_steps == 0:
|
796 |
-
callback(i, t, latents)
|
797 |
-
|
798 |
-
aov_latents = latents / self.vae.config.scaling_factor
|
799 |
-
aov = self.vae.decode(aov_latents, return_dict=False)[0]
|
800 |
-
do_denormalize = [True] * aov.shape[0]
|
801 |
-
aov_name = required_aovs[0]
|
802 |
-
if aov_name == "albedo" or aov_name == "irradiance":
|
803 |
-
do_gamma_correction = True
|
804 |
-
else:
|
805 |
-
do_gamma_correction = False
|
806 |
-
|
807 |
-
if aov_name == "roughness" or aov_name == "metallic":
|
808 |
-
aov = aov[:, 0:1].repeat(1, 3, 1, 1)
|
809 |
-
|
810 |
-
aov = self.image_processor.postprocess(
|
811 |
-
aov,
|
812 |
-
output_type=output_type,
|
813 |
-
do_denormalize=do_denormalize,
|
814 |
-
do_gamma_correction=do_gamma_correction,
|
815 |
-
)
|
816 |
-
aovs = [aov]
|
817 |
-
|
818 |
-
# Offload last model to CPU
|
819 |
-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
820 |
-
self.final_offload_hook.offload()
|
821 |
-
return StableDiffusionAOVPipelineOutput(images=aovs)
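For orientation, the `__call__` above conditions on a photo (encoded through the VAE and concatenated to the latents channel-wise) and decodes a single intrinsic channel selected by `required_aovs`. A minimal usage sketch, assuming `pipe` is an already-loaded instance of this pipeline on CUDA; the prompt wording, file path, and seed below are illustrative and not taken from this repo:

import torch
from PIL import Image

# Assumption: `pipe` is an instance of the AOV pipeline defined in this file,
# loaded from a compatible checkpoint and moved to "cuda" beforehand (hypothetical setup).
photo = Image.open("input_photo.png").convert("RGB")  # placeholder path

with torch.no_grad():
    output = pipe(
        prompt="Albedo (diffuse basecolor)",               # illustrative prompt
        photo=photo,                                       # conditioning image, encoded via the VAE
        required_aovs=["albedo"],                          # first entry drives gamma/channel handling above
        num_inference_steps=50,
        guidance_scale=1.5,                                # CFG is active only when > 1.0 ...
        image_guidance_scale=1.5,                          # ... and image_guidance_scale >= 1.0
        generator=torch.Generator("cuda").manual_seed(0),
    )

# `images` holds one entry: the post-processed image list for required_aovs[0].
albedo = output.images[0][0]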
run.sh
ADDED
@@ -0,0 +1,5 @@
#!/bin/bash
CONDA_ENV=$(head -1 /code/environment.yml | cut -d" " -f2)
eval "$(conda shell.bash hook)"
conda activate $CONDA_ENV
python app.py
settings.py
DELETED
@@ -1,23 +0,0 @@
import os

import numpy as np

DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")

MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))
DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "1")))
MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "2048"))
DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "1024")))

ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"

MAX_SEED = np.iinfo(np.int32).max

# Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

# setup CUDA
# disable the following when deployting to hugging face
# if os.getenv("CUDA_VISIBLE_DEVICES") is None:
#     os.environ["CUDA_VISIBLE_DEVICES"] = "7"
#     os.environ["GRADIO_SERVER_PORT"] = "7864"
text2tex/lib/__init__.py
DELETED
File without changes
text2tex/lib/camera_helper.py
DELETED
@@ -1,231 +0,0 @@
import torch

import numpy as np

from sklearn.metrics.pairwise import cosine_similarity

from pytorch3d.renderer import (
    PerspectiveCameras,
    look_at_view_transform
)

# customized
import sys
sys.path.append(".")

from lib.constants import VIEWPOINTS

# ---------------- UTILS ----------------------

def degree_to_radian(d):
    return d * np.pi / 180

def radian_to_degree(r):
    return 180 * r / np.pi

def xyz_to_polar(xyz):
    """ assume y-axis is the up axis """

    x, y, z = xyz

    theta = 180 * np.arccos(z) / np.pi
    phi = 180 * np.arccos(y) / np.pi

    return theta, phi

def polar_to_xyz(theta, phi, dist):
    """ assume y-axis is the up axis """

    theta = degree_to_radian(theta)
    phi = degree_to_radian(phi)

    x = np.sin(phi) * np.sin(theta) * dist
    y = np.cos(phi) * dist
    z = np.sin(phi) * np.cos(theta) * dist

    return [x, y, z]


# ---------------- VIEWPOINTS ----------------------


def filter_viewpoints(pre_viewpoints: dict, viewpoints: dict):
    """ return the binary mask of viewpoints to be filtered """

    filter_mask = [0 for _ in viewpoints.keys()]
    for i, v in viewpoints.items():
        x_v, y_v, z_v = polar_to_xyz(v["azim"], 90 - v["elev"], v["dist"])

        for _, pv in pre_viewpoints.items():
            x_pv, y_pv, z_pv = polar_to_xyz(pv["azim"], 90 - pv["elev"], pv["dist"])
            sim = cosine_similarity(
                np.array([[x_v, y_v, z_v]]),
                np.array([[x_pv, y_pv, z_pv]])
            )[0, 0]

            if sim > 0.9:
                filter_mask[i] = 1

    return filter_mask


def init_viewpoints(mode, sample_space, init_dist, init_elev, principle_directions,
                    use_principle=True, use_shapenet=False, use_objaverse=False):

    if mode == "predefined":
        (
            dist_list,
            elev_list,
            azim_list,
            sector_list
        ) = init_predefined_viewpoints(sample_space, init_dist, init_elev)
    elif mode == "hemisphere":
        (
            dist_list,
            elev_list,
            azim_list,
            sector_list
        ) = init_hemisphere_viewpoints(sample_space, init_dist)
    else:
        raise NotImplementedError()

    # punishments for views -> in case always selecting the same view
    view_punishments = [1 for _ in range(len(dist_list))]

    if use_principle:
        (
            dist_list,
            elev_list,
            azim_list,
            sector_list,
            view_punishments
        ) = init_principle_viewpoints(
            principle_directions,
            dist_list,
            elev_list,
            azim_list,
            sector_list,
            view_punishments,
            use_shapenet,
            use_objaverse
        )

    return dist_list, elev_list, azim_list, sector_list, view_punishments


def init_principle_viewpoints(
    principle_directions,
    dist_list,
    elev_list,
    azim_list,
    sector_list,
    view_punishments,
    use_shapenet=False,
    use_objaverse=False
):
    if use_shapenet:
        key = "shapenet"

        pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
        pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
        pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]

        num_principle = 10
        pre_dist_list = [dist_list[0] for _ in range(num_principle)]
        pre_view_punishments = [0 for _ in range(num_principle)]
    elif use_objaverse:
        key = "objaverse"

        pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
        pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
        pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]

        num_principle = 10
        pre_dist_list = [dist_list[0] for _ in range(num_principle)]
        pre_view_punishments = [0 for _ in range(num_principle)]
    else:
        num_principle = 6
        pre_elev_list = [v for v in VIEWPOINTS[num_principle]["elev"]]
        pre_azim_list = [v for v in VIEWPOINTS[num_principle]["azim"]]
        pre_sector_list = [v for v in VIEWPOINTS[num_principle]["sector"]]
        pre_dist_list = [dist_list[0] for _ in range(num_principle)]
        pre_view_punishments = [0 for _ in range(num_principle)]

    dist_list = pre_dist_list + dist_list
    elev_list = pre_elev_list + elev_list
    azim_list = pre_azim_list + azim_list
    sector_list = pre_sector_list + sector_list
    view_punishments = pre_view_punishments + view_punishments

    return dist_list, elev_list, azim_list, sector_list, view_punishments


def init_predefined_viewpoints(sample_space, init_dist, init_elev):

    viewpoints = VIEWPOINTS[sample_space]

    assert sample_space == len(viewpoints["sector"])

    dist_list = [init_dist for _ in range(sample_space)]  # always the same dist
    elev_list = [viewpoints["elev"][i] for i in range(sample_space)]
    azim_list = [viewpoints["azim"][i] for i in range(sample_space)]
    sector_list = [viewpoints["sector"][i] for i in range(sample_space)]

    return dist_list, elev_list, azim_list, sector_list


def init_hemisphere_viewpoints(sample_space, init_dist):
    """
    y is up-axis
    """

    num_points = 2 * sample_space
    ga = np.pi * (3. - np.sqrt(5.))  # golden angle in radians

    flags = []
    elev_list = []  # degree
    azim_list = []  # degree

    for i in range(num_points):
        y = 1 - (i / float(num_points - 1)) * 2  # y goes from 1 to -1

        # only take the north hemisphere
        if y >= 0:
            flags.append(True)
        else:
            flags.append(False)

        theta = ga * i  # golden angle increment

        elev_list.append(radian_to_degree(np.arcsin(y)))
        azim_list.append(radian_to_degree(theta))

        radius = np.sqrt(1 - y * y)  # radius at y
        x = np.cos(theta) * radius
        z = np.sin(theta) * radius

    elev_list = [elev_list[i] for i in range(len(elev_list)) if flags[i]]
    azim_list = [azim_list[i] for i in range(len(azim_list)) if flags[i]]

    dist_list = [init_dist for _ in elev_list]
    sector_list = ["good" for _ in elev_list]  # HACK don't define sector names for now

    return dist_list, elev_list, azim_list, sector_list


# ---------------- CAMERAS ----------------------


def init_camera(dist, elev, azim, image_size, device):
    R, T = look_at_view_transform(dist, elev, azim)
    image_size = torch.tensor([image_size, image_size]).unsqueeze(0)
    cameras = PerspectiveCameras(R=R, T=T, device=device, image_size=image_size)

    return cameras
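The "hemisphere" mode above places viewpoints with a golden-angle (Fibonacci) spiral and keeps only the upper half (y >= 0). A self-contained sketch of the same sampling idea in plain numpy; the function name is mine, not from the repo:

import numpy as np

def fibonacci_hemisphere(n_views, dist=1.0):
    """Golden-angle spiral on the upper (y >= 0) hemisphere, returned as dist/elev/azim in degrees."""
    num_points = 2 * n_views                      # sample a full sphere, keep roughly the top half
    ga = np.pi * (3.0 - np.sqrt(5.0))             # golden angle in radians
    elev, azim = [], []
    for i in range(num_points):
        y = 1 - (i / float(num_points - 1)) * 2   # y goes from 1 to -1
        if y < 0:
            continue                              # discard the southern hemisphere
        elev.append(np.degrees(np.arcsin(y)))     # elevation from the equator
        azim.append(np.degrees(ga * i))           # azimuth grows by the golden angle each step
    dists = [dist] * len(elev)
    return dists, elev, azim

dists, elevs, azims = fibonacci_hemisphere(10, dist=2.5)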
text2tex/lib/constants.py
DELETED
@@ -1,648 +0,0 @@
PALETTE = {
    0: [255, 255, 255],  # white - background
    1: [204, 50, 50],    # red - old
    2: [231, 180, 22],   # yellow - update
    3: [45, 201, 55]     # green - new
}

QUAD_WEIGHTS = {
    0: 0,    # background
    1: 0.1,  # old
    2: 0.5,  # update
    3: 1     # new
}

VIEWPOINTS = {
    1: {
        "azim": [0],
        "elev": [0],
        "sector": ["front"]
    },
    2: {
        "azim": [0, 30],
        "elev": [0, 0],
        "sector": ["front", "front"]
    },
    4: {
        "azim": [45, 315, 135, 225],
        "elev": [0, 0, 0, 0],
        "sector": ["front right", "front left", "back right", "back left"]
    },
    6: {
        "azim": [0, 90, 270, 0, 180, 0],
        "elev": [0, 0, 0, 90, 0, -90],
        "sector": ["front", "right", "left", "top", "back", "bottom"]
    },
    "shapenet": {
        "azim": [270, 315, 225, 0, 180, 45, 135, 90, 270, 270],
        "elev": [15, 15, 15, 15, 15, 15, 15, 15, 90, -90],
        "sector": ["front", "front right", "front left", "right", "left",
                   "back right", "back left", "back", "top", "bottom"]
    },
    "objaverse": {
        "azim": [0, 45, 315, 90, 270, 135, 225, 180, 0, 0],
        "elev": [15, 15, 15, 15, 15, 15, 15, 15, 90, -90],
        "sector": ["front", "front right", "front left", "right", "left",
                   "back right", "back left", "back", "top", "bottom"]
    },
    12: {
        "azim": [45, 315, 135, 225,
                 0, 45, 315, 90, 270, 135, 225, 180],
        "elev": [0, 0, 0, 0,
                 45, 45, 45, 45, 45, 45, 45, 45],
        "sector": ["front right", "front left", "back right", "back left",
                   "front", "front right", "front left", "right", "left",
                   "back right", "back left", "back"]
    },
    20: {
        "azim": [45, 315, 135, 225,
                 0, 45, 315, 90, 270, 135, 225, 180,
                 0, 45, 315, 90, 270, 135, 225, 180],
        "elev": [0, 0, 0, 0,
                 30, 30, 30, 30, 30, 30, 30, 30,
                 60, 60, 60, 60, 60, 60, 60, 60],
        "sector": ["front right", "front left", "back right", "back left",
                   "front", "front right", "front left", "right", "left",
                   "back right", "back left", "back",
                   "front", "front right", "front left", "right", "left",
                   "back right", "back left", "back"]
    },
    36: {
        "azim": [45, 315, 135, 225,
                 0, 45, 315, 90, 270, 135, 225, 180,
                 0, 45, 315, 90, 270, 135, 225, 180,
                 22.5, 337.5, 67.5, 292.5, 112.5, 247.5, 157.5, 202.5,
                 22.5, 337.5, 67.5, 292.5, 112.5, 247.5, 157.5, 202.5],
        "elev": [0, 0, 0, 0,
                 30, 30, 30, 30, 30, 30, 30, 30,
                 60, 60, 60, 60, 60, 60, 60, 60,
                 15, 15, 15, 15, 15, 15, 15, 15,
                 45, 45, 45, 45, 45, 45, 45, 45],
        "sector": ["front right", "front left", "back right", "back left",
                   "front", "front right", "front left", "right", "left",
                   "back right", "back left", "back",
                   "top front", "top right", "top left", "top right",
                   "top left", "top right", "top left", "top back",
                   "front right", "front left", "front right", "front left",
                   "back right", "back left", "back right", "back left",
                   "front right", "front left", "front right", "front left",
                   "back right", "back left", "back right", "back left"]
    },
    68: {
        "azim": [45, 315, 135, 225,
                 0, 45, 315, 90, 270, 135, 225, 180,
                 0, 45, 315, 90, 270, 135, 225, 180,
                 22.5, 337.5, 67.5, 292.5, 112.5, 247.5, 157.5, 202.5,
                 22.5, 337.5, 67.5, 292.5, 112.5, 247.5, 157.5, 202.5,
                 0, 45, 315, 90, 270, 135, 225, 180,
                 0, 45, 315, 90, 270, 135, 225, 180,
                 22.5, 337.5, 67.5, 292.5, 112.5, 247.5, 157.5, 202.5,
                 22.5, 337.5, 67.5, 292.5, 112.5, 247.5, 157.5, 202.5],
        "elev": [0, 0, 0, 0,
                 30, 30, 30, 30, 30, 30, 30, 30,
                 60, 60, 60, 60, 60, 60, 60, 60,
                 15, 15, 15, 15, 15, 15, 15, 15,
                 45, 45, 45, 45, 45, 45, 45, 45,
                 -30, -30, -30, -30, -30, -30, -30, -30,
                 -60, -60, -60, -60, -60, -60, -60, -60,
                 -15, -15, -15, -15, -15, -15, -15, -15,
                 -45, -45, -45, -45, -45, -45, -45, -45],
        "sector": ["front right", "front left", "back right", "back left",
                   "front", "front right", "front left", "right", "left",
                   "back right", "back left", "back",
                   "top front", "top right", "top left", "top right",
                   "top left", "top right", "top left", "top back",
                   "front right", "front left", "front right", "front left",
                   "back right", "back left", "back right", "back left",
                   "front right", "front left", "front right", "front left",
                   "back right", "back left", "back right", "back left",
                   "front", "front right", "front left", "right", "left",
                   "back right", "back left", "back",
                   "bottom front", "bottom right", "bottom left", "bottom right",
                   "bottom left", "bottom right", "bottom left", "bottom back",
                   "bottom front right", "bottom front left", "bottom front right", "bottom front left",
                   "bottom back right", "bottom back left", "bottom back right", "bottom back left",
                   "bottom front right", "bottom front left", "bottom front right", "bottom front left",
                   "bottom back right", "bottom back left", "bottom back right", "bottom back left"]
    }
}
text2tex/lib/diffusion_helper.py
DELETED
@@ -1,189 +0,0 @@
import torch

import cv2
import numpy as np

from PIL import Image
from torchvision import transforms

# Stable Diffusion 2
from diffusers import (
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
    EulerDiscreteScheduler
)

# customized
import sys
sys.path.append(".")

from models.ControlNet.gradio_depth2image import init_model, process


def get_controlnet_depth():
    print("=> initializing ControlNet Depth...")
    model, ddim_sampler = init_model()

    return model, ddim_sampler


def get_inpainting(device):
    print("=> initializing Inpainting...")

    model = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    ).to(device)

    return model

def get_text2image(device):
    print("=> initializing Inpainting...")

    model_id = "stabilityai/stable-diffusion-2"
    scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
    model = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16).to(device)

    return model


@torch.no_grad()
def apply_controlnet_depth(model, ddim_sampler,
    init_image, prompt, strength, ddim_steps,
    generate_mask_image, keep_mask_image, depth_map_np,
    a_prompt, n_prompt, guidance_scale, seed, eta, num_samples,
    device, blend=0, save_memory=False):
    """
    Use Stable Diffusion 2 to generate image

    Arguments:
        args: input arguments
        model: Stable Diffusion 2 model
        init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
        mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
        depth_map_np: depth map of the input image, torch.FloatTensor of shape (1, H, W)
    """

    print("=> generating ControlNet Depth RePaint image...")

    # Stable Diffusion 2 receives PIL.Image
    # NOTE Stable Diffusion 2 returns a PIL.Image object
    # image and mask_image should be PIL images.
    # The mask structure is white for inpainting and black for keeping as is
    diffused_image_np = process(
        model, ddim_sampler,
        np.array(init_image), prompt, a_prompt, n_prompt, num_samples,
        ddim_steps, guidance_scale, seed, eta,
        strength=strength, detected_map=depth_map_np, unknown_mask=np.array(generate_mask_image), save_memory=save_memory
    )[0]

    init_image = init_image.convert("RGB")
    diffused_image = Image.fromarray(diffused_image_np).convert("RGB")

    if blend > 0 and transforms.ToTensor()(keep_mask_image).sum() > 0:
        print("=> blending the generated region...")
        kernel_size = 3
        kernel = np.ones((kernel_size, kernel_size), np.uint8)

        keep_image_np = np.array(init_image).astype(np.uint8)
        keep_image_np_dilate = cv2.dilate(keep_image_np, kernel, iterations=1)

        keep_mask_np = np.array(keep_mask_image).astype(np.uint8)
        keep_mask_np_dilate = cv2.dilate(keep_mask_np, kernel, iterations=1)

        generate_image_np = np.array(diffused_image).astype(np.uint8)

        overlap_mask_np = np.array(generate_mask_image).astype(np.uint8)
        overlap_mask_np *= keep_mask_np_dilate
        print("=> blending {} pixels...".format(np.sum(overlap_mask_np)))

        overlap_keep = keep_image_np_dilate[overlap_mask_np == 1]
        overlap_generate = generate_image_np[overlap_mask_np == 1]

        overlap_np = overlap_keep * blend + overlap_generate * (1 - blend)

        generate_image_np[overlap_mask_np == 1] = overlap_np

        diffused_image = Image.fromarray(generate_image_np.astype(np.uint8)).convert("RGB")

    init_image_masked = init_image
    diffused_image_masked = diffused_image

    return diffused_image, init_image_masked, diffused_image_masked


@torch.no_grad()
def apply_inpainting(model,
    init_image, mask_image_tensor, prompt, height, width, device):
    """
    Use Stable Diffusion 2 to generate image

    Arguments:
        args: input arguments
        model: Stable Diffusion 2 model
        init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
        mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
        depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
    """

    print("=> generating Inpainting image...")

    mask_image = mask_image_tensor[0].cpu()
    mask_image = mask_image.permute(2, 0, 1)
    mask_image = transforms.ToPILImage()(mask_image).convert("L")

    # NOTE Stable Diffusion 2 returns a PIL.Image object
    # image and mask_image should be PIL images.
    # The mask structure is white for inpainting and black for keeping as is
    diffused_image = model(
        prompt=prompt,
        image=init_image.resize((512, 512)),
        mask_image=mask_image.resize((512, 512)),
        height=512,
        width=512
    ).images[0].resize((height, width))

    return diffused_image


@torch.no_grad()
def apply_inpainting_postprocess(model,
    init_image, mask_image_tensor, prompt, height, width, device):
    """
    Use Stable Diffusion 2 to generate image

    Arguments:
        args: input arguments
        model: Stable Diffusion 2 model
        init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
        mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
        depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
    """

    print("=> generating Inpainting image...")

    mask_image = mask_image_tensor[0].cpu()
    mask_image = mask_image.permute(2, 0, 1)
    mask_image = transforms.ToPILImage()(mask_image).convert("L")

    # NOTE Stable Diffusion 2 returns a PIL.Image object
    # image and mask_image should be PIL images.
    # The mask structure is white for inpainting and black for keeping as is
    diffused_image = model(
        prompt=prompt,
        image=init_image.resize((512, 512)),
        mask_image=mask_image.resize((512, 512)),
        height=512,
        width=512
    ).images[0].resize((height, width))

    diffused_image_tensor = torch.from_numpy(np.array(diffused_image)).to(device)

    init_images_tensor = torch.from_numpy(np.array(init_image)).to(device)

    init_images_tensor = diffused_image_tensor * mask_image_tensor[0] + init_images_tensor * (1 - mask_image_tensor[0])
    init_image = Image.fromarray(init_images_tensor.cpu().numpy().astype(np.uint8)).convert("RGB")

    return init_image
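Both inpainting helpers above follow the diffusers convention that the mask is a single-channel image in which white pixels are regenerated and black pixels are kept. A minimal sketch of that call against the same stabilityai/stable-diffusion-2-inpainting checkpoint; the file paths and prompt are placeholders:

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Load the inpainting pipeline once and keep it on GPU.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("render.png").convert("RGB").resize((512, 512))   # placeholder path
mask_image = Image.open("mask.png").convert("L").resize((512, 512))       # white = repaint, black = keep

result = pipe(
    prompt="a weathered leather texture",   # placeholder prompt
    image=init_image,
    mask_image=mask_image,
    height=512,
    width=512,
).images[0]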
text2tex/lib/io_helper.py
DELETED
@@ -1,78 +0,0 @@
# common utils
import os
import json

# numpy
import numpy as np

# visualization
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt

matplotlib.use("Agg")

from pytorch3d.io import save_obj

from torchvision import transforms


def save_depth(fragments, output_dir, init_image, view_idx):
    print("=> saving depth...")
    width, height = init_image.size
    dpi = 100
    figsize = width / float(dpi), height / float(dpi)

    depth_np = fragments.zbuf[0].cpu().numpy()

    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])
    # Hide spines, ticks, etc.
    ax.axis('off')
    # Display the image.
    ax.imshow(depth_np, cmap='gray')

    plt.savefig(os.path.join(output_dir, "{}.png".format(view_idx)), bbox_inches='tight', pad_inches=0)
    np.save(os.path.join(output_dir, "{}.npy".format(view_idx)), depth_np[..., 0])


def save_backproject_obj(output_dir, obj_name,
    verts, faces, verts_uvs, faces_uvs, projected_texture,
    device):
    print("=> saving OBJ file...")
    texture_map = transforms.ToTensor()(projected_texture).to(device)
    texture_map = texture_map.permute(1, 2, 0)
    obj_path = os.path.join(output_dir, obj_name)

    save_obj(
        obj_path,
        verts=verts,
        faces=faces,
        decimal_places=5,
        verts_uvs=verts_uvs,
        faces_uvs=faces_uvs,
        texture_map=texture_map
    )


def save_args(args, output_dir):
    with open(os.path.join(output_dir, "args.json"), "w") as f:
        json.dump(
            {k: v for k, v in vars(args).items()},
            f,
            indent=4
        )


def save_viewpoints(args, output_dir, dist_list, elev_list, azim_list, view_list):
    with open(os.path.join(output_dir, "viewpoints.json"), "w") as f:
        json.dump(
            {
                "dist": dist_list,
                "elev": elev_list,
                "azim": azim_list,
                "view": view_list
            },
            f,
            indent=4
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
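save_args above is simply vars() applied to the argument namespace and dumped as JSON; a self-contained sketch of that pattern, with made-up argument names purely for illustration:

import argparse
import json

args = argparse.Namespace(obj_path="mesh.obj", uv_size=1024, num_views=36)  # hypothetical arguments
print(json.dumps(vars(args), indent=4))   # same structure as the args.json written by save_args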
text2tex/lib/mesh_helper.py
DELETED
@@ -1,148 +0,0 @@
import os
import torch
import trimesh
import xatlas

import numpy as np

from sklearn.decomposition import PCA

from torchvision import transforms

from tqdm import tqdm

from pytorch3d.io import (
    load_obj,
    load_objs_as_meshes
)


def compute_principle_directions(model_path, num_points=20000):
    mesh = trimesh.load_mesh(model_path, force="mesh")
    pc, _ = trimesh.sample.sample_surface_even(mesh, num_points)

    pc -= np.mean(pc, axis=0, keepdims=True)

    principle_directions = PCA(n_components=3).fit(pc).components_

    return principle_directions


def init_mesh(input_path, cache_path, device):
    print("=> parameterizing target mesh...")

    mesh = trimesh.load_mesh(input_path, force='mesh')
    try:
        vertices, faces = mesh.vertices, mesh.faces
    except AttributeError:
        print("multiple materials in {} are not supported".format(input_path))
        exit()

    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
    xatlas.export(str(cache_path), vertices[vmapping], indices, uvs)

    print("=> loading target mesh...")

    # principle_directions = compute_principle_directions(cache_path)
    principle_directions = None

    _, faces, aux = load_obj(cache_path, device=device)
    mesh = load_objs_as_meshes([cache_path], device=device)

    num_verts = mesh.verts_packed().shape[0]

    # make sure mesh center is at origin
    bbox = mesh.get_bounding_boxes()
    mesh_center = bbox.mean(dim=2).repeat(num_verts, 1)
    mesh = apply_offsets_to_mesh(mesh, -mesh_center)

    # make sure mesh size is normalized
    box_size = bbox[..., 1] - bbox[..., 0]
    box_max = box_size.max(dim=1, keepdim=True)[0].repeat(num_verts, 3)
    mesh = apply_scale_to_mesh(mesh, 1 / box_max)

    return mesh, mesh.verts_packed(), faces, aux, principle_directions, mesh_center, box_max


def apply_offsets_to_mesh(mesh, offsets):
    new_mesh = mesh.offset_verts(offsets)

    return new_mesh

def apply_scale_to_mesh(mesh, scale):
    new_mesh = mesh.scale_verts(scale)

    return new_mesh


def adjust_uv_map(faces, aux, init_texture, uv_size):
    """
    adjust UV map to be compatible with multiple textures.
    UVs for different materials will be decomposed and placed horizontally

        +-----+-----+-----+--
        |  1  |  2  |  3  |
        +-----+-----+-----+--

    """

    textures_ids = faces.textures_idx
    materials_idx = faces.materials_idx
    verts_uvs = aux.verts_uvs

    num_materials = torch.unique(materials_idx).shape[0]

    new_verts_uvs = verts_uvs.clone()
    for material_id in range(num_materials):
        # apply offsets to horizontal axis
        faces_ids = textures_ids[materials_idx == material_id].unique()
        new_verts_uvs[faces_ids, 0] += material_id

    new_verts_uvs[:, 0] /= num_materials

    init_texture_tensor = transforms.ToTensor()(init_texture)
    init_texture_tensor = torch.cat([init_texture_tensor for _ in range(num_materials)], dim=-1)
    init_texture = transforms.ToPILImage()(init_texture_tensor).resize((uv_size, uv_size))

    return new_verts_uvs, init_texture


@torch.no_grad()
def update_face_angles(mesh, cameras, fragments):
    def get_angle(x, y):
        x = torch.nn.functional.normalize(x)
        y = torch.nn.functional.normalize(y)
        inner_product = (x * y).sum(dim=1)
        x_norm = x.pow(2).sum(dim=1).pow(0.5)
        y_norm = y.pow(2).sum(dim=1).pow(0.5)
        cos = inner_product / (x_norm * y_norm)
        angle = torch.acos(cos)
        angle = angle * 180 / 3.14159

        return angle

    # face normals
    face_normals = mesh.faces_normals_padded()[0]

    # view vector (object center -> camera center)
    camera_center = cameras.get_camera_center()

    face_angles = get_angle(
        face_normals,
        camera_center.repeat(face_normals.shape[0], 1)
    )  # (F)

    face_angles_rev = get_angle(
        face_normals,
        -camera_center.repeat(face_normals.shape[0], 1)
    )  # (F)

    face_angles = torch.minimum(face_angles, face_angles_rev)

    # Indices of unique visible faces
    visible_map = fragments.pix_to_face.unique()  # (num_visible_faces)
    invisible_mask = torch.ones_like(face_angles)
    invisible_mask[visible_map] = 0
    face_angles[invisible_mask == 1] = 10000.  # angles of invisible faces are ignored

    return face_angles
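The normalization inside init_mesh centers the mesh at the origin and divides by the longest bounding-box edge. A NumPy sketch of the same idea on a dummy vertex array (values invented for illustration):

import numpy as np

verts = np.array([[0.0, 0.0, 0.0], [2.0, 1.0, 0.5], [4.0, 2.0, 1.0]])
bbox_min, bbox_max = verts.min(axis=0), verts.max(axis=0)
center = (bbox_min + bbox_max) / 2        # midpoint of the bounding box
scale = (bbox_max - bbox_min).max()       # longest bounding-box edge
verts_normalized = (verts - center) / scale
print(verts_normalized)                   # now fits inside a unit cube around the origin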
text2tex/lib/projection_helper.py
DELETED
@@ -1,464 +0,0 @@
import os
import torch

import cv2
import random

import numpy as np

from torchvision import transforms

from pytorch3d.renderer import TexturesUV
from pytorch3d.ops import interpolate_face_attributes

from PIL import Image

from tqdm import tqdm

# customized
import sys
sys.path.append(".")

from lib.camera_helper import init_camera
from lib.render_helper import init_renderer, render
from lib.shading_helper import (
    BlendParams,
    init_soft_phong_shader,
    init_flat_texel_shader,
)
from lib.vis_helper import visualize_outputs, visualize_quad_mask
from lib.constants import *


def get_all_4_locations(values_y, values_x):
    y_0 = torch.floor(values_y)
    y_1 = torch.ceil(values_y)
    x_0 = torch.floor(values_x)
    x_1 = torch.ceil(values_x)

    return torch.cat([y_0, y_0, y_1, y_1], 0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()


def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device):
    """
    compose quad mask:
        -> 0: background
        -> 1: old
        -> 2: update
        -> 3: new
    """

    new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
    update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
    old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)

    all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor

    quad_mask_tensor = torch.zeros_like(all_mask_tensor)
    quad_mask_tensor[old_mask_tensor == 1] = 1
    quad_mask_tensor[update_mask_tensor == 1] = 2
    quad_mask_tensor[new_mask_tensor == 1] = 3

    return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor


def compute_view_heat(similarity_tensor, quad_mask_tensor):
    num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]
    heat = 0
    for idx in QUAD_WEIGHTS:
        heat += (quad_mask_tensor == idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels

    return heat


def select_viewpoint(selected_view_ids, view_punishments,
    mode, dist_list, elev_list, azim_list, sector_list, view_idx,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, use_principle=False
):
    if mode == "sequential":

        num_views = len(dist_list)

        dist = dist_list[view_idx % num_views]
        elev = elev_list[view_idx % num_views]
        azim = azim_list[view_idx % num_views]
        sector = sector_list[view_idx % num_views]

        selected_view_ids.append(view_idx % num_views)

    elif mode == "heuristic":

        if use_principle and view_idx < 6:

            selected_view_idx = view_idx

        else:

            selected_view_idx = None
            max_heat = 0

            print("=> selecting next view...")
            view_heat_list = []
            for sample_idx in tqdm(range(len(dist_list))):

                view_heat, *_ = render_one_view_and_build_masks(dist_list[sample_idx], elev_list[sample_idx], azim_list[sample_idx],
                    sample_idx, sample_idx, view_punishments,
                    similarity_texture_cache, exist_texture,
                    mesh, faces, verts_uvs,
                    image_size, faces_per_pixel,
                    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
                    device)

                if view_heat > max_heat:
                    selected_view_idx = sample_idx
                    max_heat = view_heat

                view_heat_list.append(view_heat.item())

            print(view_heat_list)
            print("select view {} with heat {}".format(selected_view_idx, max_heat))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

        view_punishments[selected_view_idx] *= 0.01

    elif mode == "random":

        selected_view_idx = random.choice(range(len(dist_list)))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

    else:
        raise NotImplementedError()

    return dist, elev, azim, sector, selected_view_ids, view_punishments


@torch.no_grad()
def build_backproject_mask(mesh, faces, verts_uvs,
    cameras, reference_image, faces_per_pixel,
    image_size, uv_size, device):
    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)

    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = faces_per_pixel

    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size)))).float() / 255.
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    # texture
    texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values.reshape(-1, 3)

    return texture_tensor[:, :, 0]


@torch.no_grad()
def build_diffusion_mask(mesh_stuff,
    renderer, exist_texture, similarity_texture_cache, target_value, device, image_size,
    smooth_mask=False, view_threshold=0.01):

    mesh, faces, verts_uvs = mesh_stuff
    mask_mesh = mesh.clone()  # NOTE in-place operation - DANGER!!!

    # visible mask => the whole region
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=torch.ones_like(exist_texture_expand),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    # visible_mask_tensor, *_ = render(mask_mesh, renderer)
    visible_mask_tensor, _, similarity_map_tensor, *_ = render(mask_mesh, renderer)
    # faces that are too rotated away from the viewpoint will be treated as invisible
    valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
    visible_mask_tensor *= valid_mask_tensor

    # nonexist mask <=> new mask
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=1 - exist_texture_expand,
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    new_mask_tensor, *_ = render(mask_mesh, renderer)
    new_mask_tensor *= valid_mask_tensor

    # exist mask => visible mask - new mask
    exist_mask_tensor = visible_mask_tensor - new_mask_tensor
    exist_mask_tensor[exist_mask_tensor < 0] = 0  # NOTE dilate can lead to overflow

    # all update mask
    mask_mesh.textures = TexturesUV(
        maps=(
            similarity_texture_cache.argmax(0) == target_value
            # # only consider the views that have already appeared before
            # similarity_texture_cache[0:target_value+1].argmax(0) == target_value
        ).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    all_update_mask_tensor, *_ = render(mask_mesh, renderer)

    # current update mask => intersection between all update mask and exist mask
    update_mask_tensor = exist_mask_tensor * all_update_mask_tensor

    # keep mask => exist mask - update mask
    old_mask_tensor = exist_mask_tensor - update_mask_tensor

    # convert
    new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
    new_mask = transforms.ToPILImage()(new_mask).convert("L")

    update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
    update_mask = transforms.ToPILImage()(update_mask).convert("L")

    old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
    old_mask = transforms.ToPILImage()(old_mask).convert("L")

    exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
    exist_mask = transforms.ToPILImage()(exist_mask).convert("L")

    return new_mask, update_mask, old_mask, exist_mask


@torch.no_grad()
def render_one_view(mesh,
    dist, elev, azim,
    image_size, faces_per_pixel,
    device):

    # render the view
    cameras = init_camera(
        dist, elev, azim,
        image_size, device
    )
    renderer = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )

    init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(mesh, renderer)

    return (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    )


@torch.no_grad()
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
    dist_list, elev_list, azim_list,
    image_size, image_size_scaled, uv_size, faces_per_pixel,
    device):

    num_candidate_views = len(dist_list)
    similarity_texture_cache = torch.zeros(num_candidate_views, uv_size, uv_size).to(device)

    print("=> building similarity texture cache for all views...")
    for i in tqdm(range(num_candidate_views)):
        cameras, _, _, _, similarity_tensor, _, _ = render_one_view(mesh,
            dist_list[i], elev_list[i], azim_list[i],
            image_size, faces_per_pixel, device)

        similarity_texture_cache[i] = build_backproject_mask(mesh, faces, verts_uvs,
            cameras, transforms.ToPILImage()(similarity_tensor[0, :, :, 0]).convert("RGB"), faces_per_pixel,
            image_size_scaled, uv_size, device)

    return similarity_texture_cache


@torch.no_grad()
def render_one_view_and_build_masks(dist, elev, azim,
    selected_view_idx, view_idx, view_punishments,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, save_intermediate=False, smooth_mask=False, view_threshold=0.01):

    # render the view
    (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    ) = render_one_view(mesh,
        dist, elev, azim,
        image_size, faces_per_pixel,
        device
    )

    init_image = init_images_tensor[0].cpu()
    init_image = init_image.permute(2, 0, 1)
    init_image = transforms.ToPILImage()(init_image).convert("RGB")

    normal_map = normal_maps_tensor[0].cpu()
    normal_map = normal_map.permute(2, 0, 1)
    normal_map = transforms.ToPILImage()(normal_map).convert("RGB")

    depth_map = depth_maps_tensor[0].cpu().numpy()
    depth_map = Image.fromarray(depth_map).convert("L")

    similarity_map = similarity_tensor[0, :, :, 0].cpu()
    similarity_map = transforms.ToPILImage()(similarity_map).convert("L")


    flat_renderer = init_renderer(cameras,
        shader=init_flat_texel_shader(
            camera=cameras,
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
        (mesh, faces, verts_uvs),
        flat_renderer, exist_texture, similarity_texture_cache, selected_view_idx, device, image_size,
        smooth_mask=smooth_mask, view_threshold=view_threshold
    )
    # NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
    # it should match with `similarity_texture_cache`

    (
        old_mask_tensor,
        update_mask_tensor,
        new_mask_tensor,
        all_mask_tensor,
        quad_mask_tensor
    ) = compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device)

    view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
    view_heat *= view_punishments[selected_view_idx]

    # save intermediate results
    if save_intermediate:
        init_image.save(os.path.join(init_image_dir, "{}.png".format(view_idx)))
        normal_map.save(os.path.join(normal_map_dir, "{}.png".format(view_idx)))
        depth_map.save(os.path.join(depth_map_dir, "{}.png".format(view_idx)))
        similarity_map.save(os.path.join(similarity_map_dir, "{}.png".format(view_idx)))

        new_mask_image.save(os.path.join(mask_image_dir, "{}_new.png".format(view_idx)))
        update_mask_image.save(os.path.join(mask_image_dir, "{}_update.png".format(view_idx)))
        old_mask_image.save(os.path.join(mask_image_dir, "{}_old.png".format(view_idx)))
        exist_mask_image.save(os.path.join(mask_image_dir, "{}_exist.png".format(view_idx)))

        visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_heat, device)

    return (
        view_heat,
        renderer, cameras, fragments,
        init_image, normal_map, depth_map,
        init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
        old_mask_image, update_mask_image, new_mask_image,
        old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
    )


@torch.no_grad()
def backproject_from_image(mesh, faces, verts_uvs, cameras,
    reference_image, new_mask_image, update_mask_image,
    init_texture, exist_texture,
    image_size, uv_size, faces_per_pixel,
    device):

    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(pixel_uvs.shape[-2], pixel_uvs.shape[1], pixel_uvs.shape[2], 2)

    # the update mask has to be on top of the diffusion mask
    new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(device).unsqueeze(-1)
    update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(device).unsqueeze(-1)

    project_mask_image_tensor = torch.logical_or(update_mask_image_tensor, new_mask_image_tensor).float()
    project_mask_image = project_mask_image_tensor * 255.
    project_mask_image = Image.fromarray(project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))

    project_mask_image_scaled = project_mask_image.resize(
        (image_size, image_size),
        Image.Resampling.NEAREST
    )
    project_mask_image_tensor_scaled = transforms.ToTensor()(project_mask_image_scaled).to(device)

    pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]

    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = pixel_uvs.shape[0]
    project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:, None, :, :, None].repeat(1, 4, 1, 1, 3)

    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size))))
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    texture_values_masked = texture_values.reshape(-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(-1, 3)

    # texture
    texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values_masked

    init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(np.uint8))

    # update texture cache
    exist_texture[texture_locations_y, texture_locations_x] = 1

    return init_texture, project_mask_image, exist_texture
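get_all_4_locations above splats each continuous UV sample onto the four surrounding texels (floor/ceil on both axes), which is what lets the back-projection fill the texture without pinholes. A standalone sketch of that behavior with invented coordinates:

import torch

values_y = torch.tensor([3.2, 7.0])
values_x = torch.tensor([5.8, 1.5])

y0, y1 = torch.floor(values_y), torch.ceil(values_y)
x0, x1 = torch.floor(values_x), torch.ceil(values_x)
ys = torch.cat([y0, y0, y1, y1], 0).long()
xs = torch.cat([x0, x1, x0, x1], 0).long()
print(list(zip(ys.tolist(), xs.tolist())))
# [(3, 5), (7, 1), (3, 6), (7, 2), (4, 5), (7, 1), (4, 6), (7, 2)]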
text2tex/lib/render_helper.py
DELETED
@@ -1,108 +0,0 @@
import os
import torch

import cv2

import numpy as np

from PIL import Image

from torchvision import transforms
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer import (
    RasterizationSettings,
    MeshRendererWithFragments,
    MeshRasterizer,
)

# customized
import sys
sys.path.append(".")


def init_renderer(camera, shader, image_size, faces_per_pixel):
    raster_settings = RasterizationSettings(image_size=image_size, faces_per_pixel=faces_per_pixel)
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(
            cameras=camera,
            raster_settings=raster_settings
        ),
        shader=shader
    )

    return renderer


@torch.no_grad()
def render(mesh, renderer, pad_value=10):
    def phong_normal_shading(meshes, fragments) -> torch.Tensor:
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]
        pixel_normals = interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, faces_normals
        )

        return pixel_normals

    def similarity_shading(meshes, fragments):
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]
        vertices = meshes.verts_packed()  # (V, 3)
        face_positions = vertices[faces]
        view_directions = torch.nn.functional.normalize((renderer.shader.cameras.get_camera_center().reshape(1, 1, 3) - face_positions), p=2, dim=2)
        cosine_similarity = torch.nn.CosineSimilarity(dim=2)(faces_normals, view_directions)
        pixel_similarity = interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, cosine_similarity.unsqueeze(-1)
        )

        return pixel_similarity

    def get_relative_depth_map(fragments, pad_value=pad_value):
        absolute_depth = fragments.zbuf[..., 0]  # B, H, W
        no_depth = -1

        depth_min, depth_max = absolute_depth[absolute_depth != no_depth].min(), absolute_depth[absolute_depth != no_depth].max()
        target_min, target_max = 50, 255

        depth_value = absolute_depth[absolute_depth != no_depth]
        depth_value = depth_max - depth_value  # reverse values

        depth_value /= (depth_max - depth_min)
        depth_value = depth_value * (target_max - target_min) + target_min

        relative_depth = absolute_depth.clone()
        relative_depth[absolute_depth != no_depth] = depth_value
        relative_depth[absolute_depth == no_depth] = pad_value  # not completely black

        return relative_depth


    images, fragments = renderer(mesh)
    normal_maps = phong_normal_shading(mesh, fragments).squeeze(-2)
    similarity_maps = similarity_shading(mesh, fragments).squeeze(-2)  # -1 - 1
    depth_maps = get_relative_depth_map(fragments)

    # normalize similarity mask to 0 - 1
    similarity_maps = torch.abs(similarity_maps)  # 0 - 1

    # HACK erode, eliminate isolated dots
    non_zero_similarity = (similarity_maps > 0).float()
    non_zero_similarity = (non_zero_similarity * 255.).cpu().numpy().astype(np.uint8)[0]
    non_zero_similarity = cv2.erode(non_zero_similarity, kernel=np.ones((3, 3), np.uint8), iterations=2)
    non_zero_similarity = torch.from_numpy(non_zero_similarity).to(similarity_maps.device).unsqueeze(0) / 255.
    similarity_maps = non_zero_similarity.unsqueeze(-1) * similarity_maps

    return images, normal_maps, similarity_maps, depth_maps, fragments


@torch.no_grad()
def check_visible_faces(mesh, fragments):
    pix_to_face = fragments.pix_to_face

    # Indices of unique visible faces
    visible_map = pix_to_face.unique()  # (num_visible_faces)

    return visible_map
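get_relative_depth_map in the file above flips valid depths (nearer surfaces become brighter) and rescales them into [50, 255], while empty pixels get pad_value so the background is not pure black. A tiny sketch of that mapping on a hand-made z-buffer:

import torch

zbuf = torch.tensor([[1.0, 2.0, -1.0]])    # -1 marks pixels with no geometry
valid = zbuf != -1
depth = zbuf[valid]
depth_min, depth_max = depth.min(), depth.max()
depth = depth_max - depth                  # reverse: nearer surfaces get larger values
depth = depth / (depth_max - depth_min)
depth = depth * (255 - 50) + 50

relative = zbuf.clone()
relative[valid] = depth
relative[~valid] = 10                      # default pad_value, keeps background visible
print(relative)                            # tensor([[255.,  50.,  10.]])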
text2tex/lib/shading_helper.py
DELETED
@@ -1,45 +0,0 @@
from typing import NamedTuple, Sequence

from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer import (
    AmbientLights,
    SoftPhongShader
)


class BlendParams(NamedTuple):
    sigma: float = 1e-4
    gamma: float = 1e-4
    background_color: Sequence = (1, 1, 1)


class FlatTexelShader(ShaderBase):

    def __init__(self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None):
        super().__init__(device, cameras, lights, materials, blend_params)

    def forward(self, fragments, meshes, **_kwargs):
        texels = meshes.sample_textures(fragments)
        texels[(fragments.pix_to_face == -1), :] = 0
        return texels.squeeze(-2)


def init_soft_phong_shader(camera, blend_params, device):
    lights = AmbientLights(device=device)
    shader = SoftPhongShader(
        cameras=camera,
        lights=lights,
        device=device,
        blend_params=blend_params
    )

    return shader


def init_flat_texel_shader(camera, device):
    shader = FlatTexelShader(
        cameras=camera,
        device=device
    )

    return shader
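BlendParams above is a plain NamedTuple, so call sites can override individual blend settings (for example the background color) without touching the other defaults. A standalone snippet that re-declares the tuple so it runs without pytorch3d installed:

from typing import NamedTuple, Sequence

class BlendParams(NamedTuple):
    sigma: float = 1e-4
    gamma: float = 1e-4
    background_color: Sequence = (1, 1, 1)

print(BlendParams())                                  # all defaults: white background
print(BlendParams(background_color=(0, 0, 0)))        # black background, sigma/gamma unchanged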