Spaces:
Sleeping
Sleeping
Commit
·
b5042f1
1
Parent(s):
16b2e4d
pushback
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +74 -10
- app_3d.py +21 -0
- app_canny.py +83 -0
- app_matnet.py +83 -0
- app_sd.py +154 -0
- app_texnet.py +259 -0
- cv_utils.py +17 -0
- depth_estimator.py +25 -0
- examples/bunny/mesh.obj +0 -0
- examples/fighter/mesh.obj +0 -0
- examples/highheel/mesh.obj +0 -0
- examples/monkey/mesh.obj +0 -0
- examples/tank/mesh.obj +0 -0
- image_segmentor.py +33 -0
- install.sh +18 -0
- model.py +959 -0
- pre-requirements.txt +9 -0
- preprocessor.py +120 -0
- push_dataset.py +9 -0
- requirements.txt +9 -0
- rgb2x/generate_blend.py +142 -0
- rgb2x/gradio_demo_rgb2x.py +157 -0
- rgb2x/load_image.py +119 -0
- rgb2x/pipeline_rgb2x.py +821 -0
- run.sh +12 -2
- settings.py +23 -0
- text2tex/lib/__init__.py +0 -0
- text2tex/lib/camera_helper.py +231 -0
- text2tex/lib/constants.py +648 -0
- text2tex/lib/diffusion_helper.py +189 -0
- text2tex/lib/io_helper.py +78 -0
- text2tex/lib/mesh_helper.py +148 -0
- text2tex/lib/projection_helper.py +464 -0
- text2tex/lib/render_helper.py +108 -0
- text2tex/lib/shading_helper.py +45 -0
- text2tex/lib/vis_helper.py +209 -0
- text2tex/models/ControlNet/.gitignore +143 -0
- text2tex/models/ControlNet/LICENSE +201 -0
- text2tex/models/ControlNet/README.md +234 -0
- text2tex/models/ControlNet/annotator/canny/__init__.py +5 -0
- text2tex/models/ControlNet/annotator/ckpts/ckpts.txt +1 -0
- text2tex/models/ControlNet/annotator/hed/__init__.py +127 -0
- text2tex/models/ControlNet/annotator/midas/__init__.py +36 -0
- text2tex/models/ControlNet/annotator/midas/api.py +165 -0
- text2tex/models/ControlNet/annotator/midas/midas/__init__.py +0 -0
- text2tex/models/ControlNet/annotator/midas/midas/base_model.py +16 -0
- text2tex/models/ControlNet/annotator/midas/midas/blocks.py +342 -0
- text2tex/models/ControlNet/annotator/midas/midas/dpt_depth.py +109 -0
- text2tex/models/ControlNet/annotator/midas/midas/midas_net.py +76 -0
- text2tex/models/ControlNet/annotator/midas/midas/midas_net_custom.py +128 -0
app.py
CHANGED
@@ -1,16 +1,80 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
|
|
|
|
|
7 |
|
8 |
with gr.Blocks() as demo:
|
9 |
-
gr.Markdown(
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
import gradio as gr
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import sys
|
7 |
+
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
|
8 |
+
version_str="".join([
|
9 |
+
f"py3{sys.version_info.minor}_cu",
|
10 |
+
torch.version.cuda.replace(".",""),
|
11 |
+
f"_pyt{pyt_version_str}"
|
12 |
+
])
|
13 |
+
print(f"Using version: {version_str}") # used to locate pytorch3d version in the requirements.txt for huggingface
|
14 |
+
|
15 |
+
|
16 |
+
from app_canny import create_demo as create_demo_canny
|
17 |
+
from app_texnet import create_demo as create_demo_texnet
|
18 |
+
|
19 |
+
from model import Model
|
20 |
+
from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
|
21 |
|
22 |
+
DESCRIPTION = "# Material Authoring Demo v0.3"
|
23 |
|
24 |
+
if not torch.cuda.is_available():
|
25 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p> Check if the 'CUDA_VISIBLE_DEVICES' are set incorrectly in settings.py"
|
26 |
|
27 |
+
# model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")
|
28 |
+
model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="texnet")
|
29 |
|
30 |
with gr.Blocks() as demo:
|
31 |
+
gr.Markdown(DESCRIPTION)
|
32 |
+
gr.DuplicateButton(
|
33 |
+
value="Duplicate Space for private use",
|
34 |
+
elem_id="duplicate-button",
|
35 |
+
visible=SHOW_DUPLICATE_BUTTON,
|
36 |
+
)
|
37 |
+
|
38 |
+
with gr.Tabs():
|
39 |
+
with gr.Tab("Texnet+Matnet"):
|
40 |
+
create_demo_texnet(model.process_texnet)
|
41 |
+
|
42 |
+
with gr.Accordion(label="Base model", open=False):
|
43 |
+
with gr.Row():
|
44 |
+
with gr.Column(scale=5):
|
45 |
+
current_base_model = gr.Text(label="Current base model")
|
46 |
+
with gr.Column(scale=1):
|
47 |
+
check_base_model_button = gr.Button("Check current base model")
|
48 |
+
with gr.Row():
|
49 |
+
with gr.Column(scale=5):
|
50 |
+
new_base_model_id = gr.Text(
|
51 |
+
label="New base model",
|
52 |
+
max_lines=1,
|
53 |
+
placeholder="stable-diffusion-v1-5/stable-diffusion-v1-5",
|
54 |
+
info="The base model must be compatible with Stable Diffusion v1.5.",
|
55 |
+
interactive=ALLOW_CHANGING_BASE_MODEL,
|
56 |
+
)
|
57 |
+
with gr.Column(scale=1):
|
58 |
+
change_base_model_button = gr.Button("Change base model", interactive=ALLOW_CHANGING_BASE_MODEL)
|
59 |
+
if not ALLOW_CHANGING_BASE_MODEL:
|
60 |
+
gr.Markdown(
|
61 |
+
"""The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space."""
|
62 |
+
)
|
63 |
+
|
64 |
+
check_base_model_button.click(
|
65 |
+
fn=lambda: model.base_model_id,
|
66 |
+
outputs=current_base_model,
|
67 |
+
queue=False,
|
68 |
+
api_name="check_base_model",
|
69 |
+
)
|
70 |
+
gr.on(
|
71 |
+
triggers=[new_base_model_id.submit, change_base_model_button.click],
|
72 |
+
fn=model.set_base_model,
|
73 |
+
inputs=new_base_model_id,
|
74 |
+
outputs=current_base_model,
|
75 |
+
api_name=False,
|
76 |
+
concurrency_id="main",
|
77 |
+
)
|
78 |
+
|
79 |
+
if __name__ == "__main__":
|
80 |
+
demo.queue(max_size=20).launch()
|
app_3d.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
|
4 |
+
def load_mesh(mesh_file_name):
|
5 |
+
return mesh_file_name
|
6 |
+
|
7 |
+
demo = gr.Interface(
|
8 |
+
fn=load_mesh,
|
9 |
+
inputs=gr.Model3D(),
|
10 |
+
outputs=gr.Model3D(
|
11 |
+
clear_color=(255.0, 0.0, 0.0, 0.0), label="3D Model", display_mode="wireframe"),
|
12 |
+
examples=[
|
13 |
+
[os.path.join(os.path.dirname(__file__), "examples/bunny/mesh.obj")],
|
14 |
+
[os.path.join(os.path.dirname(__file__), "examples/monkey/mesh.obj")],
|
15 |
+
[os.path.join(os.path.dirname(__file__), "examples/Bunny.obj")],
|
16 |
+
],
|
17 |
+
cache_examples=True
|
18 |
+
)
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
demo.launch()
|
app_canny.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
from settings import (
|
6 |
+
DEFAULT_IMAGE_RESOLUTION,
|
7 |
+
DEFAULT_NUM_IMAGES,
|
8 |
+
MAX_IMAGE_RESOLUTION,
|
9 |
+
MAX_NUM_IMAGES,
|
10 |
+
MAX_SEED,
|
11 |
+
)
|
12 |
+
from utils import randomize_seed_fn
|
13 |
+
|
14 |
+
|
15 |
+
def create_demo(process):
|
16 |
+
with gr.Blocks() as demo:
|
17 |
+
with gr.Row():
|
18 |
+
with gr.Column():
|
19 |
+
image = gr.Image()
|
20 |
+
prompt = gr.Textbox(label="Prompt", submit_btn=True)
|
21 |
+
with gr.Accordion("Advanced options", open=False):
|
22 |
+
num_samples = gr.Slider(
|
23 |
+
label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
|
24 |
+
)
|
25 |
+
image_resolution = gr.Slider(
|
26 |
+
label="Image resolution",
|
27 |
+
minimum=256,
|
28 |
+
maximum=MAX_IMAGE_RESOLUTION,
|
29 |
+
value=DEFAULT_IMAGE_RESOLUTION,
|
30 |
+
step=256,
|
31 |
+
)
|
32 |
+
canny_low_threshold = gr.Slider(
|
33 |
+
label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
|
34 |
+
)
|
35 |
+
canny_high_threshold = gr.Slider(
|
36 |
+
label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
|
37 |
+
)
|
38 |
+
num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
|
39 |
+
guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
40 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
41 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
42 |
+
a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
|
43 |
+
n_prompt = gr.Textbox(
|
44 |
+
label="Negative prompt",
|
45 |
+
value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
|
46 |
+
)
|
47 |
+
with gr.Column():
|
48 |
+
result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
|
49 |
+
inputs = [
|
50 |
+
image,
|
51 |
+
prompt,
|
52 |
+
a_prompt,
|
53 |
+
n_prompt,
|
54 |
+
num_samples,
|
55 |
+
image_resolution,
|
56 |
+
num_steps,
|
57 |
+
guidance_scale,
|
58 |
+
seed,
|
59 |
+
canny_low_threshold,
|
60 |
+
canny_high_threshold,
|
61 |
+
]
|
62 |
+
prompt.submit(
|
63 |
+
fn=randomize_seed_fn,
|
64 |
+
inputs=[seed, randomize_seed],
|
65 |
+
outputs=seed,
|
66 |
+
queue=False,
|
67 |
+
api_name=False,
|
68 |
+
).then(
|
69 |
+
fn=process,
|
70 |
+
inputs=inputs,
|
71 |
+
outputs=result,
|
72 |
+
api_name="canny",
|
73 |
+
concurrency_id="main",
|
74 |
+
)
|
75 |
+
return demo
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
from model import Model
|
80 |
+
|
81 |
+
model = Model(task_name="Canny")
|
82 |
+
demo = create_demo(model.process_canny)
|
83 |
+
demo.queue().launch()
|
app_matnet.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
from settings import (
|
6 |
+
DEFAULT_IMAGE_RESOLUTION,
|
7 |
+
DEFAULT_NUM_IMAGES,
|
8 |
+
MAX_IMAGE_RESOLUTION,
|
9 |
+
MAX_NUM_IMAGES,
|
10 |
+
MAX_SEED,
|
11 |
+
)
|
12 |
+
from utils import randomize_seed_fn
|
13 |
+
|
14 |
+
|
15 |
+
def create_demo(process):
|
16 |
+
with gr.Blocks() as demo:
|
17 |
+
with gr.Row():
|
18 |
+
with gr.Column():
|
19 |
+
image = gr.Image()
|
20 |
+
prompt = gr.Textbox(label="Prompt", submit_btn=True)
|
21 |
+
with gr.Accordion("Advanced options", open=False):
|
22 |
+
num_samples = gr.Slider(
|
23 |
+
label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
|
24 |
+
)
|
25 |
+
image_resolution = gr.Slider(
|
26 |
+
label="Image resolution",
|
27 |
+
minimum=256,
|
28 |
+
maximum=MAX_IMAGE_RESOLUTION,
|
29 |
+
value=DEFAULT_IMAGE_RESOLUTION,
|
30 |
+
step=256,
|
31 |
+
)
|
32 |
+
canny_low_threshold = gr.Slider(
|
33 |
+
label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
|
34 |
+
)
|
35 |
+
canny_high_threshold = gr.Slider(
|
36 |
+
label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
|
37 |
+
)
|
38 |
+
num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
|
39 |
+
guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
40 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
41 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
42 |
+
a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
|
43 |
+
n_prompt = gr.Textbox(
|
44 |
+
label="Negative prompt",
|
45 |
+
value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
|
46 |
+
)
|
47 |
+
with gr.Column():
|
48 |
+
result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
|
49 |
+
inputs = [
|
50 |
+
image,
|
51 |
+
prompt,
|
52 |
+
a_prompt,
|
53 |
+
n_prompt,
|
54 |
+
num_samples,
|
55 |
+
image_resolution,
|
56 |
+
num_steps,
|
57 |
+
guidance_scale,
|
58 |
+
seed,
|
59 |
+
canny_low_threshold,
|
60 |
+
canny_high_threshold,
|
61 |
+
]
|
62 |
+
prompt.submit(
|
63 |
+
fn=randomize_seed_fn,
|
64 |
+
inputs=[seed, randomize_seed],
|
65 |
+
outputs=seed,
|
66 |
+
queue=False,
|
67 |
+
api_name=False,
|
68 |
+
).then(
|
69 |
+
fn=process,
|
70 |
+
inputs=inputs,
|
71 |
+
outputs=result,
|
72 |
+
api_name="canny",
|
73 |
+
concurrency_id="main",
|
74 |
+
)
|
75 |
+
return demo
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
from model import Model
|
80 |
+
|
81 |
+
model = Model(task_name="Canny")
|
82 |
+
demo = create_demo(model.process_canny)
|
83 |
+
demo.queue().launch()
|
app_sd.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
|
5 |
+
import spaces #[uncomment to use ZeroGPU]
|
6 |
+
from diffusers import DiffusionPipeline
|
7 |
+
import torch
|
8 |
+
|
9 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
+
model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
|
11 |
+
|
12 |
+
if torch.cuda.is_available():
|
13 |
+
torch_dtype = torch.float16
|
14 |
+
else:
|
15 |
+
torch_dtype = torch.float32
|
16 |
+
|
17 |
+
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
18 |
+
pipe = pipe.to(device)
|
19 |
+
|
20 |
+
MAX_SEED = np.iinfo(np.int32).max
|
21 |
+
MAX_IMAGE_SIZE = 1024
|
22 |
+
|
23 |
+
|
24 |
+
@spaces.GPU #[uncomment to use ZeroGPU]
|
25 |
+
def infer(
|
26 |
+
prompt,
|
27 |
+
negative_prompt,
|
28 |
+
seed,
|
29 |
+
randomize_seed,
|
30 |
+
width,
|
31 |
+
height,
|
32 |
+
guidance_scale,
|
33 |
+
num_inference_steps,
|
34 |
+
progress=gr.Progress(track_tqdm=True),
|
35 |
+
):
|
36 |
+
if randomize_seed:
|
37 |
+
seed = random.randint(0, MAX_SEED)
|
38 |
+
|
39 |
+
generator = torch.Generator().manual_seed(seed)
|
40 |
+
|
41 |
+
image = pipe(
|
42 |
+
prompt=prompt,
|
43 |
+
negative_prompt=negative_prompt,
|
44 |
+
guidance_scale=guidance_scale,
|
45 |
+
num_inference_steps=num_inference_steps,
|
46 |
+
width=width,
|
47 |
+
height=height,
|
48 |
+
generator=generator,
|
49 |
+
).images[0]
|
50 |
+
|
51 |
+
return image, seed
|
52 |
+
|
53 |
+
|
54 |
+
examples = [
|
55 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
56 |
+
"An astronaut riding a green horse",
|
57 |
+
"A delicious ceviche cheesecake slice",
|
58 |
+
]
|
59 |
+
|
60 |
+
css = """
|
61 |
+
#col-container {
|
62 |
+
margin: 0 auto;
|
63 |
+
max-width: 640px;
|
64 |
+
}
|
65 |
+
"""
|
66 |
+
|
67 |
+
with gr.Blocks(css=css) as demo:
|
68 |
+
with gr.Column(elem_id="col-container"):
|
69 |
+
gr.Markdown(" # Text-to-Image Gradio Template")
|
70 |
+
|
71 |
+
with gr.Row():
|
72 |
+
prompt = gr.Text(
|
73 |
+
label="Prompt",
|
74 |
+
show_label=False,
|
75 |
+
max_lines=1,
|
76 |
+
placeholder="Enter your prompt",
|
77 |
+
container=False,
|
78 |
+
)
|
79 |
+
|
80 |
+
run_button = gr.Button("Run", scale=0, variant="primary")
|
81 |
+
|
82 |
+
result = gr.Image(label="Result", show_label=False)
|
83 |
+
|
84 |
+
with gr.Accordion("Advanced Settings", open=False):
|
85 |
+
negative_prompt = gr.Text(
|
86 |
+
label="Negative prompt",
|
87 |
+
max_lines=1,
|
88 |
+
placeholder="Enter a negative prompt",
|
89 |
+
visible=False,
|
90 |
+
)
|
91 |
+
|
92 |
+
seed = gr.Slider(
|
93 |
+
label="Seed",
|
94 |
+
minimum=0,
|
95 |
+
maximum=MAX_SEED,
|
96 |
+
step=1,
|
97 |
+
value=0,
|
98 |
+
)
|
99 |
+
|
100 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
101 |
+
|
102 |
+
with gr.Row():
|
103 |
+
width = gr.Slider(
|
104 |
+
label="Width",
|
105 |
+
minimum=256,
|
106 |
+
maximum=MAX_IMAGE_SIZE,
|
107 |
+
step=32,
|
108 |
+
value=1024, # Replace with defaults that work for your model
|
109 |
+
)
|
110 |
+
|
111 |
+
height = gr.Slider(
|
112 |
+
label="Height",
|
113 |
+
minimum=256,
|
114 |
+
maximum=MAX_IMAGE_SIZE,
|
115 |
+
step=32,
|
116 |
+
value=1024, # Replace with defaults that work for your model
|
117 |
+
)
|
118 |
+
|
119 |
+
with gr.Row():
|
120 |
+
guidance_scale = gr.Slider(
|
121 |
+
label="Guidance scale",
|
122 |
+
minimum=0.0,
|
123 |
+
maximum=10.0,
|
124 |
+
step=0.1,
|
125 |
+
value=0.0, # Replace with defaults that work for your model
|
126 |
+
)
|
127 |
+
|
128 |
+
num_inference_steps = gr.Slider(
|
129 |
+
label="Number of inference steps",
|
130 |
+
minimum=1,
|
131 |
+
maximum=50,
|
132 |
+
step=1,
|
133 |
+
value=2, # Replace with defaults that work for your model
|
134 |
+
)
|
135 |
+
|
136 |
+
gr.Examples(examples=examples, inputs=[prompt])
|
137 |
+
gr.on(
|
138 |
+
triggers=[run_button.click, prompt.submit],
|
139 |
+
fn=infer,
|
140 |
+
inputs=[
|
141 |
+
prompt,
|
142 |
+
negative_prompt,
|
143 |
+
seed,
|
144 |
+
randomize_seed,
|
145 |
+
width,
|
146 |
+
height,
|
147 |
+
guidance_scale,
|
148 |
+
num_inference_steps,
|
149 |
+
],
|
150 |
+
outputs=[result, seed],
|
151 |
+
)
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
demo.launch()
|
app_texnet.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import tempfile
|
6 |
+
import gradio as gr
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from settings import (
|
11 |
+
DEFAULT_IMAGE_RESOLUTION,
|
12 |
+
DEFAULT_NUM_IMAGES,
|
13 |
+
MAX_IMAGE_RESOLUTION,
|
14 |
+
MAX_NUM_IMAGES,
|
15 |
+
MAX_SEED,
|
16 |
+
)
|
17 |
+
from utils import randomize_seed_fn
|
18 |
+
|
19 |
+
# ---- helper to build a quick textured copy of the mesh ---------------
|
20 |
+
def apply_texture(src_mesh:str, texture:str, tag:str)->str:
|
21 |
+
"""
|
22 |
+
Writes a copy of `src_mesh` and tiny .mtl that points to `texture`.
|
23 |
+
Returns the new OBJ/GLB path for viewing.
|
24 |
+
"""
|
25 |
+
tmp_dir = tempfile.mkdtemp()
|
26 |
+
mesh_copy = os.path.join(tmp_dir, f"{tag}.obj")
|
27 |
+
mtl_name = f"{tag}.mtl"
|
28 |
+
|
29 |
+
# copy geometry
|
30 |
+
shutil.copy(src_mesh, mesh_copy)
|
31 |
+
|
32 |
+
# write minimal MTL
|
33 |
+
with open(os.path.join(tmp_dir, mtl_name), "w") as f:
|
34 |
+
f.write(f"newmtl material_0\nmap_Kd {os.path.basename(texture)}\n")
|
35 |
+
|
36 |
+
# ensure texture lives next to OBJ
|
37 |
+
shutil.copy(texture, os.path.join(tmp_dir, os.path.basename(texture)))
|
38 |
+
|
39 |
+
# patch OBJ to reference our new MTL
|
40 |
+
with open(mesh_copy, "r+") as f:
|
41 |
+
lines = f.readlines()
|
42 |
+
if not lines[0].startswith("mtllib"):
|
43 |
+
lines.insert(0, f"mtllib {mtl_name}\n")
|
44 |
+
f.seek(0); f.writelines(lines)
|
45 |
+
|
46 |
+
return mesh_copy
|
47 |
+
|
48 |
+
def image_to_temp_path(img_like, tag, out_dir=None):
|
49 |
+
"""
|
50 |
+
Convert various image-like objects (str, PIL.Image, list, tuple) to temp PNG path.
|
51 |
+
Returns the path to the saved image file.
|
52 |
+
"""
|
53 |
+
# Handle tuple or list input
|
54 |
+
if isinstance(img_like, (list, tuple)):
|
55 |
+
if len(img_like) == 0:
|
56 |
+
raise ValueError("Empty image list/tuple.")
|
57 |
+
img_like = img_like[0]
|
58 |
+
|
59 |
+
# If it's already a file path
|
60 |
+
if isinstance(img_like, str):
|
61 |
+
return img_like
|
62 |
+
|
63 |
+
# If it's a PIL Image
|
64 |
+
if isinstance(img_like, Image.Image):
|
65 |
+
temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
|
66 |
+
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
|
67 |
+
img_like.save(temp_path)
|
68 |
+
return temp_path
|
69 |
+
|
70 |
+
# if it's numpy array
|
71 |
+
if isinstance(img_like, np.ndarray):
|
72 |
+
temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
|
73 |
+
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
|
74 |
+
img_like = Image.fromarray(img_like)
|
75 |
+
img_like.save(temp_path)
|
76 |
+
return temp_path
|
77 |
+
|
78 |
+
raise ValueError(f"Expected PIL.Image, str, list, or tuple — got {type(img_like)}")
|
79 |
+
|
80 |
+
def show_mesh(which, mesh, inp, coarse, fine):
|
81 |
+
"""Switch the displayed texture based on dropdown change."""
|
82 |
+
print()
|
83 |
+
tex_map = {
|
84 |
+
"Input": image_to_temp_path(inp, "input"),
|
85 |
+
"Coarse": coarse[0] if isinstance(coarse, tuple) else coarse,
|
86 |
+
"Fine": fine[0] if isinstance(fine, tuple) else fine,
|
87 |
+
}
|
88 |
+
texture_path = tex_map[which]
|
89 |
+
return apply_texture(mesh, texture_path, which.lower())
|
90 |
+
# ----------------------------------------------------------------------
|
91 |
+
|
92 |
+
|
93 |
+
def create_demo(process):
|
94 |
+
with gr.Blocks() as demo:
|
95 |
+
with gr.Row():
|
96 |
+
with gr.Column():
|
97 |
+
gr.Markdown("## Select preset from the example list, and modify the prompt accordingly")
|
98 |
+
with gr.Row():
|
99 |
+
name = gr.Textbox(label="Name", interactive=False, visible=False)
|
100 |
+
representative = gr.Image(label="Geometry", interactive=False)
|
101 |
+
image = gr.Image(label="UV Normal", interactive=False)
|
102 |
+
prompt = gr.Textbox(label="Prompt", submit_btn=True)
|
103 |
+
with gr.Accordion("Advanced options", open=False):
|
104 |
+
num_samples = gr.Slider(
|
105 |
+
label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
|
106 |
+
)
|
107 |
+
image_resolution = gr.Slider(
|
108 |
+
label="Image resolution",
|
109 |
+
minimum=256,
|
110 |
+
maximum=MAX_IMAGE_RESOLUTION,
|
111 |
+
value=DEFAULT_IMAGE_RESOLUTION,
|
112 |
+
step=256,
|
113 |
+
)
|
114 |
+
num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=10, step=1)
|
115 |
+
guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
116 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
117 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
118 |
+
a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
|
119 |
+
n_prompt = gr.Textbox(
|
120 |
+
label="Negative prompt",
|
121 |
+
value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
|
122 |
+
)
|
123 |
+
with gr.Column():
|
124 |
+
# 2x2 grid of images for the output textures
|
125 |
+
gr.Markdown("### Output BRDF")
|
126 |
+
with gr.Row():
|
127 |
+
base_color = gr.Gallery(label="Base Color", show_label=True, columns=1, object_fit="scale-down")
|
128 |
+
normal = gr.Gallery(label="Displacement Map", show_label=True, columns=1, object_fit="scale-down")
|
129 |
+
with gr.Row():
|
130 |
+
roughness = gr.Gallery(label="Roughness Map", show_label=True, columns=1, object_fit="scale-down")
|
131 |
+
metallic = gr.Gallery(label="Metallic Map", show_label=True, columns=1, object_fit="scale-down")
|
132 |
+
|
133 |
+
gr.Markdown("### Download Packed Blender Files for 3D Visualization")
|
134 |
+
out_blender_path = gr.File(label="Generated Blender File", file_types=[".blend"])
|
135 |
+
|
136 |
+
inputs = [
|
137 |
+
name, # Name of the object
|
138 |
+
representative, # Geometry mesh
|
139 |
+
image,
|
140 |
+
prompt,
|
141 |
+
a_prompt,
|
142 |
+
n_prompt,
|
143 |
+
num_samples,
|
144 |
+
image_resolution,
|
145 |
+
num_steps,
|
146 |
+
guidance_scale,
|
147 |
+
seed,
|
148 |
+
]
|
149 |
+
|
150 |
+
# first call → run diffusion / texture network
|
151 |
+
prompt.submit(
|
152 |
+
fn=randomize_seed_fn,
|
153 |
+
inputs=[seed, randomize_seed],
|
154 |
+
outputs=seed,
|
155 |
+
queue=False,
|
156 |
+
api_name=False,
|
157 |
+
).then(
|
158 |
+
fn=process,
|
159 |
+
inputs=inputs,
|
160 |
+
outputs=[base_color, normal, roughness, metallic, out_blender_path],
|
161 |
+
api_name="canny",
|
162 |
+
concurrency_id="main",
|
163 |
+
)
|
164 |
+
|
165 |
+
gr.Examples(
|
166 |
+
fn=process,
|
167 |
+
inputs=inputs,
|
168 |
+
outputs=[base_color, normal, roughness, metallic],
|
169 |
+
examples=[
|
170 |
+
[
|
171 |
+
"bunny",
|
172 |
+
"examples/bunny/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
|
173 |
+
"examples/bunny/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
|
174 |
+
"feather",
|
175 |
+
a_prompt.value,
|
176 |
+
n_prompt.value,
|
177 |
+
num_samples.value,
|
178 |
+
image_resolution.value,
|
179 |
+
num_steps.value,
|
180 |
+
guidance_scale.value,
|
181 |
+
seed.value,
|
182 |
+
],
|
183 |
+
[
|
184 |
+
"monkey",
|
185 |
+
"examples/monkey/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
186 |
+
"examples/monkey/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
187 |
+
"wood",
|
188 |
+
a_prompt.value,
|
189 |
+
n_prompt.value,
|
190 |
+
num_samples.value,
|
191 |
+
image_resolution.value,
|
192 |
+
num_steps.value,
|
193 |
+
guidance_scale.value,
|
194 |
+
seed.value,
|
195 |
+
],
|
196 |
+
[
|
197 |
+
"tshirt",
|
198 |
+
"examples/tshirt/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
199 |
+
"examples/tshirt/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
200 |
+
"wood",
|
201 |
+
a_prompt.value,
|
202 |
+
n_prompt.value,
|
203 |
+
num_samples.value,
|
204 |
+
image_resolution.value,
|
205 |
+
num_steps.value,
|
206 |
+
guidance_scale.value,
|
207 |
+
seed.value,
|
208 |
+
],
|
209 |
+
# [
|
210 |
+
# "highheel",
|
211 |
+
# "examples/highheel/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
212 |
+
# "examples/highheel/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
213 |
+
# "wood",
|
214 |
+
# a_prompt.value,
|
215 |
+
# n_prompt.value,
|
216 |
+
# num_samples.value,
|
217 |
+
# image_resolution.value,
|
218 |
+
# num_steps.value,
|
219 |
+
# guidance_scale.value,
|
220 |
+
# seed.value,
|
221 |
+
# ],
|
222 |
+
[
|
223 |
+
"tank",
|
224 |
+
"examples/tank/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
225 |
+
"examples/tank/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
226 |
+
"wood",
|
227 |
+
a_prompt.value,
|
228 |
+
n_prompt.value,
|
229 |
+
num_samples.value,
|
230 |
+
image_resolution.value,
|
231 |
+
num_steps.value,
|
232 |
+
guidance_scale.value,
|
233 |
+
seed.value,
|
234 |
+
],
|
235 |
+
[
|
236 |
+
"fighter",
|
237 |
+
"examples/fighter/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
238 |
+
"examples/fighter/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
|
239 |
+
"wood",
|
240 |
+
a_prompt.value,
|
241 |
+
n_prompt.value,
|
242 |
+
num_samples.value,
|
243 |
+
image_resolution.value,
|
244 |
+
num_steps.value,
|
245 |
+
guidance_scale.value,
|
246 |
+
seed.value,
|
247 |
+
],
|
248 |
+
],
|
249 |
+
)
|
250 |
+
|
251 |
+
return demo
|
252 |
+
|
253 |
+
|
254 |
+
if __name__ == "__main__":
|
255 |
+
from model import Model
|
256 |
+
|
257 |
+
model = Model(task_name="Texnet")
|
258 |
+
demo = create_demo(model.process_texnet)
|
259 |
+
demo.queue().launch()
|
cv_utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def resize_image(input_image, resolution, interpolation=None):
|
6 |
+
H, W, C = input_image.shape
|
7 |
+
H = float(H)
|
8 |
+
W = float(W)
|
9 |
+
k = float(resolution) / max(H, W)
|
10 |
+
H *= k
|
11 |
+
W *= k
|
12 |
+
H = int(np.round(H / 64.0)) * 64
|
13 |
+
W = int(np.round(W / 64.0)) * 64
|
14 |
+
if interpolation is None:
|
15 |
+
interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
|
16 |
+
img = cv2.resize(input_image, (W, H), interpolation=interpolation)
|
17 |
+
return img
|
depth_estimator.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import PIL.Image
|
3 |
+
from controlnet_aux.util import HWC3
|
4 |
+
from transformers import pipeline
|
5 |
+
|
6 |
+
from cv_utils import resize_image
|
7 |
+
|
8 |
+
|
9 |
+
class DepthEstimator:
|
10 |
+
def __init__(self):
|
11 |
+
self.model = pipeline("depth-estimation")
|
12 |
+
|
13 |
+
def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
|
14 |
+
detect_resolution = kwargs.pop("detect_resolution", 512)
|
15 |
+
image_resolution = kwargs.pop("image_resolution", 512)
|
16 |
+
image = np.array(image)
|
17 |
+
image = HWC3(image)
|
18 |
+
image = resize_image(image, resolution=detect_resolution)
|
19 |
+
image = PIL.Image.fromarray(image)
|
20 |
+
image = self.model(image)
|
21 |
+
image = image["depth"]
|
22 |
+
image = np.array(image)
|
23 |
+
image = HWC3(image)
|
24 |
+
image = resize_image(image, resolution=image_resolution)
|
25 |
+
return PIL.Image.fromarray(image)
|
examples/bunny/mesh.obj
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
examples/fighter/mesh.obj
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
examples/highheel/mesh.obj
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
examples/monkey/mesh.obj
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
examples/tank/mesh.obj
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
image_segmentor.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import PIL.Image
|
4 |
+
import torch
|
5 |
+
from controlnet_aux.util import HWC3, ade_palette
|
6 |
+
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
|
7 |
+
|
8 |
+
from cv_utils import resize_image
|
9 |
+
|
10 |
+
|
11 |
+
class ImageSegmentor:
|
12 |
+
def __init__(self):
|
13 |
+
self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
|
14 |
+
self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
|
15 |
+
|
16 |
+
@torch.inference_mode()
|
17 |
+
def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
|
18 |
+
detect_resolution = kwargs.pop("detect_resolution", 512)
|
19 |
+
image_resolution = kwargs.pop("image_resolution", 512)
|
20 |
+
image = HWC3(image)
|
21 |
+
image = resize_image(image, resolution=detect_resolution)
|
22 |
+
image = PIL.Image.fromarray(image)
|
23 |
+
|
24 |
+
pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
|
25 |
+
outputs = self.image_segmentor(pixel_values)
|
26 |
+
seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
27 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
28 |
+
for label, color in enumerate(ade_palette()):
|
29 |
+
color_seg[seg == label, :] = color
|
30 |
+
color_seg = color_seg.astype(np.uint8)
|
31 |
+
|
32 |
+
color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
|
33 |
+
return PIL.Image.fromarray(color_seg)
|
install.sh
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
eval "$(conda shell.bash hook)"
|
3 |
+
# conda activate base
|
4 |
+
# conda remove -n matgen-plus --all
|
5 |
+
|
6 |
+
conda create -n matgen-plus python=3.11
|
7 |
+
conda activate matgen-plus
|
8 |
+
|
9 |
+
pip install diffusers["torch"] transformers accelerate xformers
|
10 |
+
pip install gradio
|
11 |
+
pip install controlnet-aux
|
12 |
+
|
13 |
+
# text2tex
|
14 |
+
conda install pytorch3d -c pytorch -c conda-forge
|
15 |
+
conda install -c conda-forge open-clip-torch pytorch-lightning
|
16 |
+
pip install trimesh xatlas scikit-learn opencv-python omegaconf
|
17 |
+
|
18 |
+
python app.py
|
model.py
ADDED
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
|
3 |
+
# get socket and check if the name is vgldgx01
|
4 |
+
import socket
|
5 |
+
if socket.gethostname() != "vgldgx01":
|
6 |
+
import spaces #[uncomment to use ZeroGPU]
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import torch
|
11 |
+
from controlnet_aux.util import HWC3
|
12 |
+
from diffusers import (
|
13 |
+
ControlNetModel,
|
14 |
+
DiffusionPipeline,
|
15 |
+
StableDiffusionControlNetPipeline,
|
16 |
+
StableDiffusionImg2ImgPipeline,
|
17 |
+
UniPCMultistepScheduler,
|
18 |
+
DDIMScheduler, #rgb2x
|
19 |
+
)
|
20 |
+
import torchvision
|
21 |
+
from torchvision import transforms
|
22 |
+
from cv_utils import resize_image
|
23 |
+
from preprocessor import Preprocessor
|
24 |
+
from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
|
25 |
+
from tqdm.auto import tqdm
|
26 |
+
import subprocess
|
27 |
+
|
28 |
+
from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
|
29 |
+
from app_texnet import image_to_temp_path
|
30 |
+
import os
|
31 |
+
import time
|
32 |
+
import tempfile
|
33 |
+
from text2tex.scripts.generate_texture import text2tex_call, init_args
|
34 |
+
from glob import glob
|
35 |
+
|
36 |
+
CONTROLNET_MODEL_IDS = {
|
37 |
+
# "Openpose": "lllyasviel/control_v11p_sd15_openpose",
|
38 |
+
# "Canny": "lllyasviel/control_v11p_sd15_canny",
|
39 |
+
# "MLSD": "lllyasviel/control_v11p_sd15_mlsd",
|
40 |
+
# "scribble": "lllyasviel/control_v11p_sd15_scribble",
|
41 |
+
# "softedge": "lllyasviel/control_v11p_sd15_softedge",
|
42 |
+
# "segmentation": "lllyasviel/control_v11p_sd15_seg",
|
43 |
+
# "depth": "lllyasviel/control_v11f1p_sd15_depth",
|
44 |
+
# "NormalBae": "lllyasviel/control_v11p_sd15_normalbae",
|
45 |
+
# "lineart": "lllyasviel/control_v11p_sd15_lineart",
|
46 |
+
# "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
|
47 |
+
# "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
|
48 |
+
# "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
|
49 |
+
# "inpaint": "lllyasviel/control_v11e_sd15_inpaint",
|
50 |
+
# "texnet": "/home/jyang/projects/ObjectReal/logs/train_texnet_deploy/checkpoint-55000/controlnet" # load and call
|
51 |
+
"texnet": "jingyangcarl/texnet",
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
def download_all_controlnet_weights() -> None:
|
56 |
+
for model_id in CONTROLNET_MODEL_IDS.values():
|
57 |
+
ControlNetModel.from_pretrained(model_id)
|
58 |
+
|
59 |
+
|
60 |
+
class Model:
|
61 |
+
def __init__(
|
62 |
+
self, base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", task_name: str = "Canny"
|
63 |
+
) -> None:
|
64 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
65 |
+
self.base_model_id = ""
|
66 |
+
self.task_name = ""
|
67 |
+
self.pipe = self.load_pipe(base_model_id, task_name)
|
68 |
+
self.pipe_base = StableDiffusionImg2ImgPipeline.from_pretrained(
|
69 |
+
'runwayml/stable-diffusion-v1-5', safety_checker=None, torch_dtype=torch.float16
|
70 |
+
).to(self.device)
|
71 |
+
self.preprocessor = Preprocessor()
|
72 |
+
|
73 |
+
# set up pipe_rgb2x
|
74 |
+
self.pipe_rgb2x = StableDiffusionAOVMatEstPipeline.from_pretrained(
|
75 |
+
"zheng95z/rgb-to-x",
|
76 |
+
torch_dtype=torch.float16,
|
77 |
+
).to(self.device)
|
78 |
+
self.pipe_rgb2x.scheduler = DDIMScheduler.from_config(
|
79 |
+
self.pipe_rgb2x.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
80 |
+
)
|
81 |
+
self.pipe_rgb2x.set_progress_bar_config(disable=True)
|
82 |
+
|
83 |
+
# setup blender
|
84 |
+
self.blender_path = '/tmp/blender-3.2.2-linux-x64/blender'
|
85 |
+
if not os.path.exists(self.blender_path):
|
86 |
+
print("Downloading Blender...")
|
87 |
+
subprocess.run(["wget", "https://download.blender.org/release/Blender3.2/blender-3.2.2-linux-x64.tar.xz", "-O", "/tmp/blender-3.2.2-linux-x64.tar.xz"], check=True)
|
88 |
+
subprocess.run(["tar", "-xf", "/tmp/blender-3.2.2-linux-x64.tar.xz", "-C", "/tmp"], check=True)
|
89 |
+
print("Blender downloaded and extracted.")
|
90 |
+
|
91 |
+
def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
|
92 |
+
if (
|
93 |
+
base_model_id == self.base_model_id
|
94 |
+
and task_name == self.task_name
|
95 |
+
and hasattr(self, "pipe")
|
96 |
+
and self.pipe is not None
|
97 |
+
):
|
98 |
+
return self.pipe
|
99 |
+
model_id = CONTROLNET_MODEL_IDS[task_name]
|
100 |
+
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
101 |
+
to_upload = False
|
102 |
+
if to_upload:
|
103 |
+
# confirm before uploading
|
104 |
+
confirm = input(f"Do you want to upload {model_id} to the hub? (y/n): ")
|
105 |
+
if confirm.lower() == "y":
|
106 |
+
controlnet.push_to_hub("jingyangcarl/texnet")
|
107 |
+
else:
|
108 |
+
print("Upload cancelled.")
|
109 |
+
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
110 |
+
base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
|
111 |
+
)
|
112 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
113 |
+
pipe.to(self.device)
|
114 |
+
if self.device.type == "cuda":
|
115 |
+
import os
|
116 |
+
if os.environ.get("SPACES_ZERO_GPU", "0") == "1":
|
117 |
+
# when running on ZeroGPU, enable CPU offload
|
118 |
+
# pipe.enable_xformers_memory_efficient_attention() doens't work
|
119 |
+
# pipe.enable_model_cpu_offload()
|
120 |
+
pass
|
121 |
+
else:
|
122 |
+
pipe.enable_xformers_memory_efficient_attention()
|
123 |
+
torch.cuda.empty_cache()
|
124 |
+
gc.collect()
|
125 |
+
self.base_model_id = base_model_id
|
126 |
+
self.task_name = task_name
|
127 |
+
return pipe
|
128 |
+
|
129 |
+
def set_base_model(self, base_model_id: str) -> str:
|
130 |
+
if not base_model_id or base_model_id == self.base_model_id:
|
131 |
+
return self.base_model_id
|
132 |
+
del self.pipe
|
133 |
+
torch.cuda.empty_cache()
|
134 |
+
gc.collect()
|
135 |
+
try:
|
136 |
+
self.pipe = self.load_pipe(base_model_id, self.task_name)
|
137 |
+
except Exception: # noqa: BLE001
|
138 |
+
self.pipe = self.load_pipe(self.base_model_id, self.task_name)
|
139 |
+
return self.base_model_id
|
140 |
+
|
141 |
+
def load_controlnet_weight(self, task_name: str) -> None:
|
142 |
+
if task_name == self.task_name:
|
143 |
+
return
|
144 |
+
if self.pipe is not None and hasattr(self.pipe, "controlnet"):
|
145 |
+
del self.pipe.controlnet
|
146 |
+
torch.cuda.empty_cache()
|
147 |
+
gc.collect()
|
148 |
+
model_id = CONTROLNET_MODEL_IDS[task_name]
|
149 |
+
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
150 |
+
controlnet.to(self.device)
|
151 |
+
torch.cuda.empty_cache()
|
152 |
+
gc.collect()
|
153 |
+
self.pipe.controlnet = controlnet
|
154 |
+
self.task_name = task_name
|
155 |
+
|
156 |
+
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
157 |
+
return additional_prompt if not prompt else f"{prompt}, {additional_prompt}"
|
158 |
+
|
159 |
+
# @spaces.GPU #[uncomment to use ZeroGPU]
|
160 |
+
@torch.autocast("cuda")
|
161 |
+
def run_pipe(
|
162 |
+
self,
|
163 |
+
prompt: str,
|
164 |
+
negative_prompt: str,
|
165 |
+
control_image: PIL.Image.Image,
|
166 |
+
num_images: int,
|
167 |
+
num_steps: int,
|
168 |
+
guidance_scale: float,
|
169 |
+
seed: int,
|
170 |
+
) -> list[PIL.Image.Image]:
|
171 |
+
generator = torch.Generator().manual_seed(seed)
|
172 |
+
# self.pipe.to(self.device)
|
173 |
+
return self.pipe(
|
174 |
+
prompt=prompt,
|
175 |
+
negative_prompt=negative_prompt,
|
176 |
+
guidance_scale=guidance_scale,
|
177 |
+
num_images_per_prompt=num_images,
|
178 |
+
num_inference_steps=num_steps,
|
179 |
+
generator=generator,
|
180 |
+
image=control_image,
|
181 |
+
).images
|
182 |
+
|
183 |
+
# @spaces.GPU #[uncomment to use ZeroGPU]
|
184 |
+
@torch.inference_mode()
|
185 |
+
def process_texnet(
|
186 |
+
self,
|
187 |
+
obj_name: str,
|
188 |
+
represented_image: np.ndarray | None, # not used
|
189 |
+
image: np.ndarray,
|
190 |
+
prompt: str,
|
191 |
+
additional_prompt: str,
|
192 |
+
negative_prompt: str,
|
193 |
+
num_images: int,
|
194 |
+
image_resolution: int,
|
195 |
+
num_steps: int,
|
196 |
+
guidance_scale: float,
|
197 |
+
seed: int,
|
198 |
+
low_threshold: int,
|
199 |
+
high_threshold: int,
|
200 |
+
) -> list[PIL.Image.Image]:
|
201 |
+
if image is None:
|
202 |
+
raise ValueError
|
203 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
204 |
+
raise ValueError
|
205 |
+
if num_images > MAX_NUM_IMAGES:
|
206 |
+
raise ValueError
|
207 |
+
|
208 |
+
prompt_nospace = prompt.replace(' ', '_')
|
209 |
+
|
210 |
+
# self.preprocessor.load("texnet")
|
211 |
+
# control_image = self.preprocessor(
|
212 |
+
# image=image, low_threshold=low_threshold, high_threshold=high_threshold, image_resolution=image_resolution, output_type="pil"
|
213 |
+
# )
|
214 |
+
|
215 |
+
# self.load_controlnet_weight("texnet")
|
216 |
+
# tex_coarse = self.run_pipe(
|
217 |
+
# prompt=self.get_prompt(prompt, additional_prompt),
|
218 |
+
# negative_prompt=negative_prompt,
|
219 |
+
# control_image=control_image,
|
220 |
+
# num_images=num_images,
|
221 |
+
# num_steps=num_steps,
|
222 |
+
# guidance_scale=guidance_scale,
|
223 |
+
# seed=seed,
|
224 |
+
# )
|
225 |
+
|
226 |
+
# # use img2img pipeline
|
227 |
+
# self.pipe_backup = self.pipe
|
228 |
+
# self.pipe = self.pipe_base
|
229 |
+
|
230 |
+
# # refine
|
231 |
+
tex_fine = []
|
232 |
+
mesh_fine = []
|
233 |
+
# for result_coarse in tex_coarse:
|
234 |
+
# # clean up GPU cache
|
235 |
+
# torch.cuda.empty_cache()
|
236 |
+
# gc.collect()
|
237 |
+
|
238 |
+
# # masking
|
239 |
+
# mask = (np.array(control_image).sum(axis=-1) == 0)[...,None]
|
240 |
+
# image_masked = PIL.Image.fromarray(np.where(mask, control_image, result_coarse))
|
241 |
+
# image_blurry = transforms.GaussianBlur(kernel_size=5, sigma=1)(image_masked)
|
242 |
+
# result_fine = self.run_pipe(
|
243 |
+
# # prompt=prompt,
|
244 |
+
# prompt=self.get_prompt(prompt, additional_prompt),
|
245 |
+
# negative_prompt=negative_prompt,
|
246 |
+
# control_image=image_blurry,
|
247 |
+
# num_images=1,
|
248 |
+
# num_steps=num_steps,
|
249 |
+
# guidance_scale=guidance_scale,
|
250 |
+
# seed=seed,
|
251 |
+
# )[0]
|
252 |
+
# result_fine = PIL.Image.fromarray(np.where(mask, control_image, result_fine))
|
253 |
+
# tex_fine.append(result_fine)
|
254 |
+
|
255 |
+
temp_out_path = tempfile.mkdtemp()
|
256 |
+
temp_out_path = 'output'
|
257 |
+
|
258 |
+
# put text2tex here,
|
259 |
+
args = init_args()
|
260 |
+
args.input_dir = f'examples/{obj_name}/'
|
261 |
+
args.output_dir = os.path.join(temp_out_path, f'{obj_name}/{prompt_nospace}')
|
262 |
+
args.obj_name = obj_name
|
263 |
+
args.obj_file = 'mesh.obj'
|
264 |
+
args.prompt = f'{prompt} {obj_name}'
|
265 |
+
args.add_view_to_prompt = True
|
266 |
+
args.ddim_steps = 5
|
267 |
+
# args.ddim_steps = 50
|
268 |
+
args.new_strength = 1.0
|
269 |
+
args.update_strength = 0.3
|
270 |
+
args.view_threshold = 0.1
|
271 |
+
args.blend = 0
|
272 |
+
args.dist = 1
|
273 |
+
args.num_viewpoints = 2
|
274 |
+
# args.num_viewpoints = 36
|
275 |
+
args.viewpoint_mode = 'predefined'
|
276 |
+
args.use_principle = True
|
277 |
+
args.update_steps = 2
|
278 |
+
# args.update_steps = 20
|
279 |
+
args.update_mode = 'heuristic'
|
280 |
+
args.seed = 42
|
281 |
+
args.post_process = True
|
282 |
+
args.device = '2080'
|
283 |
+
args.uv_size = 1000
|
284 |
+
args.image_size = 512
|
285 |
+
# args.image_size = 768
|
286 |
+
args.use_objaverse = True # assume the mesh is normalized with y-axis as up
|
287 |
+
output_dir = text2tex_call(args)
|
288 |
+
|
289 |
+
# get the texture and mesh with underscore '_post', which is the id of the last mesh, should be good for the visual
|
290 |
+
post_idx = glob(os.path.join(output_dir, 'update', 'mesh', "*_post.png"))[0].split('/')[-1].split('_')[0]
|
291 |
+
|
292 |
+
tex_fine.append(PIL.Image.open(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.png")).convert("RGB"))
|
293 |
+
mesh_fine.append(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.obj"))
|
294 |
+
torch.cuda.empty_cache()
|
295 |
+
|
296 |
+
# restore the original pipe
|
297 |
+
# self.pipe = self.pipe_backup
|
298 |
+
|
299 |
+
# use rgb2x for now for generating the texture
|
300 |
+
def rgb2x(
|
301 |
+
pipeline,
|
302 |
+
photo,
|
303 |
+
inference_step = 50,
|
304 |
+
num_samples = 1,
|
305 |
+
):
|
306 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
307 |
+
|
308 |
+
# Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
|
309 |
+
old_height = photo.shape[1]
|
310 |
+
old_width = photo.shape[2]
|
311 |
+
new_height = old_height
|
312 |
+
new_width = old_width
|
313 |
+
radio = old_height / old_width
|
314 |
+
max_side = 1000
|
315 |
+
if old_height > old_width:
|
316 |
+
new_height = max_side
|
317 |
+
new_width = int(new_height / radio)
|
318 |
+
else:
|
319 |
+
new_width = max_side
|
320 |
+
new_height = int(new_width * radio)
|
321 |
+
|
322 |
+
if new_width % 8 != 0 or new_height % 8 != 0:
|
323 |
+
new_width = new_width // 8 * 8
|
324 |
+
new_height = new_height // 8 * 8
|
325 |
+
|
326 |
+
photo = torchvision.transforms.Resize((new_height, new_width))(photo)
|
327 |
+
|
328 |
+
required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
|
329 |
+
prompts = {
|
330 |
+
"albedo": "Albedo (diffuse basecolor)",
|
331 |
+
"normal": "Camera-space Normal",
|
332 |
+
"roughness": "Roughness",
|
333 |
+
"metallic": "Metallicness",
|
334 |
+
"irradiance": "Irradiance (diffuse lighting)",
|
335 |
+
}
|
336 |
+
|
337 |
+
return_list = []
|
338 |
+
for i in tqdm(range(num_samples), desc="Running Pipeline", leave=False):
|
339 |
+
for aov_name in required_aovs:
|
340 |
+
prompt = prompts[aov_name]
|
341 |
+
generated_image = pipeline(
|
342 |
+
prompt=prompt,
|
343 |
+
photo=photo,
|
344 |
+
num_inference_steps=inference_step,
|
345 |
+
height=new_height,
|
346 |
+
width=new_width,
|
347 |
+
generator=generator,
|
348 |
+
required_aovs=[aov_name],
|
349 |
+
).images[0][0]
|
350 |
+
|
351 |
+
generated_image = torchvision.transforms.Resize(
|
352 |
+
(old_height, old_width)
|
353 |
+
)(generated_image)
|
354 |
+
|
355 |
+
# generated_image = (generated_image, f"Generated {aov_name} {i}")
|
356 |
+
# generated_image = (generated_image, f"{aov_name}")
|
357 |
+
return_list.append(generated_image)
|
358 |
+
|
359 |
+
return photo, return_list, prompts
|
360 |
+
|
361 |
+
# Load rgb2x pipeline
|
362 |
+
_, preds, prompts = rgb2x(self.pipe_rgb2x, torchvision.transforms.PILToTensor()(tex_fine[0]).to(self.pipe.device), inference_step=num_steps, num_samples=num_images)
|
363 |
+
|
364 |
+
intrinsic_dir = os.path.join(output_dir, 'intrinsic')
|
365 |
+
use_text2tex = True
|
366 |
+
if use_text2tex:
|
367 |
+
base_color_path = image_to_temp_path(tex_fine[0], "base_color", out_dir=intrinsic_dir)
|
368 |
+
normal_map_path = image_to_temp_path(preds[0], "normal_map", out_dir=intrinsic_dir)
|
369 |
+
roughness_path = image_to_temp_path(preds[1], "roughness", out_dir=intrinsic_dir)
|
370 |
+
metallic_path = image_to_temp_path(preds[2], "metallic", out_dir=intrinsic_dir)
|
371 |
+
else:
|
372 |
+
base_color_path = image_to_temp_path(tex_fine[0].rotate(90), "base_color", out_dir=intrinsic_dir)
|
373 |
+
normal_map_path = image_to_temp_path(preds[0].rotate(90), "normal_map", out_dir=intrinsic_dir)
|
374 |
+
roughness_path = image_to_temp_path(preds[1].rotate(90), "roughness", out_dir=intrinsic_dir)
|
375 |
+
metallic_path = image_to_temp_path(preds[2].rotate(90), "metallic", out_dir=intrinsic_dir)
|
376 |
+
current_timecode = time.strftime("%Y%m%d_%H%M%S")
|
377 |
+
# output_blend_path = os.path.join(os.getcwd(), "output", f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
|
378 |
+
output_blend_path = os.path.join(tempfile.mkdtemp(), f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
|
379 |
+
os.makedirs(os.path.dirname(output_blend_path), exist_ok=True)
|
380 |
+
|
381 |
+
def run_blend_generation(
|
382 |
+
blender_path,
|
383 |
+
generate_script_path,
|
384 |
+
obj_path,
|
385 |
+
base_color_path,
|
386 |
+
normal_map_path,
|
387 |
+
roughness_path,
|
388 |
+
metallic_path,
|
389 |
+
output_blend
|
390 |
+
):
|
391 |
+
cmd = [
|
392 |
+
blender_path, "--background", "--python", generate_script_path, "--",
|
393 |
+
obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend
|
394 |
+
]
|
395 |
+
subprocess.run(cmd, check=True)
|
396 |
+
|
397 |
+
# check if the blender_path exists, if not download
|
398 |
+
run_blend_generation(
|
399 |
+
blender_path=self.blender_path,
|
400 |
+
generate_script_path="rgb2x/generate_blend.py",
|
401 |
+
# obj_path=f"examples/{obj_name}/mesh.obj", # replace with actual mesh path
|
402 |
+
obj_path=mesh_fine[0], # replace with actual mesh path
|
403 |
+
base_color_path=base_color_path,
|
404 |
+
normal_map_path=normal_map_path,
|
405 |
+
roughness_path=roughness_path,
|
406 |
+
metallic_path=metallic_path,
|
407 |
+
output_blend=output_blend_path # replace with desired output path
|
408 |
+
)
|
409 |
+
|
410 |
+
# gallary
|
411 |
+
return [*tex_fine], [preds[1]], [preds[2]], [preds[3]], [output_blend_path]
|
412 |
+
|
413 |
+
# @spaces.GPU #[uncomment to use ZeroGPU]
|
414 |
+
@torch.inference_mode()
|
415 |
+
def process_canny(
|
416 |
+
self,
|
417 |
+
image: np.ndarray,
|
418 |
+
prompt: str,
|
419 |
+
additional_prompt: str,
|
420 |
+
negative_prompt: str,
|
421 |
+
num_images: int,
|
422 |
+
image_resolution: int,
|
423 |
+
num_steps: int,
|
424 |
+
guidance_scale: float,
|
425 |
+
seed: int,
|
426 |
+
low_threshold: int,
|
427 |
+
high_threshold: int,
|
428 |
+
) -> list[PIL.Image.Image]:
|
429 |
+
if image is None:
|
430 |
+
raise ValueError
|
431 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
432 |
+
raise ValueError
|
433 |
+
if num_images > MAX_NUM_IMAGES:
|
434 |
+
raise ValueError
|
435 |
+
|
436 |
+
self.preprocessor.load("Canny")
|
437 |
+
control_image = self.preprocessor(
|
438 |
+
image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
|
439 |
+
)
|
440 |
+
|
441 |
+
self.load_controlnet_weight("Canny")
|
442 |
+
results = self.run_pipe(
|
443 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
444 |
+
negative_prompt=negative_prompt,
|
445 |
+
control_image=control_image,
|
446 |
+
num_images=num_images,
|
447 |
+
num_steps=num_steps,
|
448 |
+
guidance_scale=guidance_scale,
|
449 |
+
seed=seed,
|
450 |
+
)
|
451 |
+
return [control_image, *results]
|
452 |
+
|
453 |
+
@torch.inference_mode()
|
454 |
+
def process_mlsd(
|
455 |
+
self,
|
456 |
+
image: np.ndarray,
|
457 |
+
prompt: str,
|
458 |
+
additional_prompt: str,
|
459 |
+
negative_prompt: str,
|
460 |
+
num_images: int,
|
461 |
+
image_resolution: int,
|
462 |
+
preprocess_resolution: int,
|
463 |
+
num_steps: int,
|
464 |
+
guidance_scale: float,
|
465 |
+
seed: int,
|
466 |
+
value_threshold: float,
|
467 |
+
distance_threshold: float,
|
468 |
+
) -> list[PIL.Image.Image]:
|
469 |
+
if image is None:
|
470 |
+
raise ValueError
|
471 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
472 |
+
raise ValueError
|
473 |
+
if num_images > MAX_NUM_IMAGES:
|
474 |
+
raise ValueError
|
475 |
+
|
476 |
+
self.preprocessor.load("MLSD")
|
477 |
+
control_image = self.preprocessor(
|
478 |
+
image=image,
|
479 |
+
image_resolution=image_resolution,
|
480 |
+
detect_resolution=preprocess_resolution,
|
481 |
+
thr_v=value_threshold,
|
482 |
+
thr_d=distance_threshold,
|
483 |
+
)
|
484 |
+
self.load_controlnet_weight("MLSD")
|
485 |
+
results = self.run_pipe(
|
486 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
487 |
+
negative_prompt=negative_prompt,
|
488 |
+
control_image=control_image,
|
489 |
+
num_images=num_images,
|
490 |
+
num_steps=num_steps,
|
491 |
+
guidance_scale=guidance_scale,
|
492 |
+
seed=seed,
|
493 |
+
)
|
494 |
+
return [control_image, *results]
|
495 |
+
|
496 |
+
@torch.inference_mode()
|
497 |
+
def process_scribble(
|
498 |
+
self,
|
499 |
+
image: np.ndarray,
|
500 |
+
prompt: str,
|
501 |
+
additional_prompt: str,
|
502 |
+
negative_prompt: str,
|
503 |
+
num_images: int,
|
504 |
+
image_resolution: int,
|
505 |
+
preprocess_resolution: int,
|
506 |
+
num_steps: int,
|
507 |
+
guidance_scale: float,
|
508 |
+
seed: int,
|
509 |
+
preprocessor_name: str,
|
510 |
+
) -> list[PIL.Image.Image]:
|
511 |
+
if image is None:
|
512 |
+
raise ValueError
|
513 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
514 |
+
raise ValueError
|
515 |
+
if num_images > MAX_NUM_IMAGES:
|
516 |
+
raise ValueError
|
517 |
+
|
518 |
+
if preprocessor_name == "None":
|
519 |
+
image = HWC3(image)
|
520 |
+
image = resize_image(image, resolution=image_resolution)
|
521 |
+
control_image = PIL.Image.fromarray(image)
|
522 |
+
elif preprocessor_name == "HED":
|
523 |
+
self.preprocessor.load(preprocessor_name)
|
524 |
+
control_image = self.preprocessor(
|
525 |
+
image=image,
|
526 |
+
image_resolution=image_resolution,
|
527 |
+
detect_resolution=preprocess_resolution,
|
528 |
+
scribble=False,
|
529 |
+
)
|
530 |
+
elif preprocessor_name == "PidiNet":
|
531 |
+
self.preprocessor.load(preprocessor_name)
|
532 |
+
control_image = self.preprocessor(
|
533 |
+
image=image,
|
534 |
+
image_resolution=image_resolution,
|
535 |
+
detect_resolution=preprocess_resolution,
|
536 |
+
safe=False,
|
537 |
+
)
|
538 |
+
self.load_controlnet_weight("scribble")
|
539 |
+
results = self.run_pipe(
|
540 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
541 |
+
negative_prompt=negative_prompt,
|
542 |
+
control_image=control_image,
|
543 |
+
num_images=num_images,
|
544 |
+
num_steps=num_steps,
|
545 |
+
guidance_scale=guidance_scale,
|
546 |
+
seed=seed,
|
547 |
+
)
|
548 |
+
return [control_image, *results]
|
549 |
+
|
550 |
+
@torch.inference_mode()
|
551 |
+
def process_scribble_interactive(
|
552 |
+
self,
|
553 |
+
image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
|
554 |
+
prompt: str,
|
555 |
+
additional_prompt: str,
|
556 |
+
negative_prompt: str,
|
557 |
+
num_images: int,
|
558 |
+
image_resolution: int,
|
559 |
+
num_steps: int,
|
560 |
+
guidance_scale: float,
|
561 |
+
seed: int,
|
562 |
+
) -> list[PIL.Image.Image]:
|
563 |
+
if image_and_mask is None:
|
564 |
+
raise ValueError
|
565 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
566 |
+
raise ValueError
|
567 |
+
if num_images > MAX_NUM_IMAGES:
|
568 |
+
raise ValueError
|
569 |
+
|
570 |
+
image = 255 - image_and_mask["composite"] # type: ignore
|
571 |
+
image = HWC3(image)
|
572 |
+
image = resize_image(image, resolution=image_resolution)
|
573 |
+
control_image = PIL.Image.fromarray(image)
|
574 |
+
|
575 |
+
self.load_controlnet_weight("scribble")
|
576 |
+
results = self.run_pipe(
|
577 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
578 |
+
negative_prompt=negative_prompt,
|
579 |
+
control_image=control_image,
|
580 |
+
num_images=num_images,
|
581 |
+
num_steps=num_steps,
|
582 |
+
guidance_scale=guidance_scale,
|
583 |
+
seed=seed,
|
584 |
+
)
|
585 |
+
return [control_image, *results]
|
586 |
+
|
587 |
+
@torch.inference_mode()
|
588 |
+
def process_softedge(
|
589 |
+
self,
|
590 |
+
image: np.ndarray,
|
591 |
+
prompt: str,
|
592 |
+
additional_prompt: str,
|
593 |
+
negative_prompt: str,
|
594 |
+
num_images: int,
|
595 |
+
image_resolution: int,
|
596 |
+
preprocess_resolution: int,
|
597 |
+
num_steps: int,
|
598 |
+
guidance_scale: float,
|
599 |
+
seed: int,
|
600 |
+
preprocessor_name: str,
|
601 |
+
) -> list[PIL.Image.Image]:
|
602 |
+
if image is None:
|
603 |
+
raise ValueError
|
604 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
605 |
+
raise ValueError
|
606 |
+
if num_images > MAX_NUM_IMAGES:
|
607 |
+
raise ValueError
|
608 |
+
|
609 |
+
if preprocessor_name == "None":
|
610 |
+
image = HWC3(image)
|
611 |
+
image = resize_image(image, resolution=image_resolution)
|
612 |
+
control_image = PIL.Image.fromarray(image)
|
613 |
+
elif preprocessor_name in ["HED", "HED safe"]:
|
614 |
+
safe = "safe" in preprocessor_name
|
615 |
+
self.preprocessor.load("HED")
|
616 |
+
control_image = self.preprocessor(
|
617 |
+
image=image,
|
618 |
+
image_resolution=image_resolution,
|
619 |
+
detect_resolution=preprocess_resolution,
|
620 |
+
scribble=safe,
|
621 |
+
)
|
622 |
+
elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
|
623 |
+
safe = "safe" in preprocessor_name
|
624 |
+
self.preprocessor.load("PidiNet")
|
625 |
+
control_image = self.preprocessor(
|
626 |
+
image=image,
|
627 |
+
image_resolution=image_resolution,
|
628 |
+
detect_resolution=preprocess_resolution,
|
629 |
+
safe=safe,
|
630 |
+
)
|
631 |
+
else:
|
632 |
+
raise ValueError
|
633 |
+
self.load_controlnet_weight("softedge")
|
634 |
+
results = self.run_pipe(
|
635 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
636 |
+
negative_prompt=negative_prompt,
|
637 |
+
control_image=control_image,
|
638 |
+
num_images=num_images,
|
639 |
+
num_steps=num_steps,
|
640 |
+
guidance_scale=guidance_scale,
|
641 |
+
seed=seed,
|
642 |
+
)
|
643 |
+
return [control_image, *results]
|
644 |
+
|
645 |
+
@torch.inference_mode()
|
646 |
+
def process_openpose(
|
647 |
+
self,
|
648 |
+
image: np.ndarray,
|
649 |
+
prompt: str,
|
650 |
+
additional_prompt: str,
|
651 |
+
negative_prompt: str,
|
652 |
+
num_images: int,
|
653 |
+
image_resolution: int,
|
654 |
+
preprocess_resolution: int,
|
655 |
+
num_steps: int,
|
656 |
+
guidance_scale: float,
|
657 |
+
seed: int,
|
658 |
+
preprocessor_name: str,
|
659 |
+
) -> list[PIL.Image.Image]:
|
660 |
+
if image is None:
|
661 |
+
raise ValueError
|
662 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
663 |
+
raise ValueError
|
664 |
+
if num_images > MAX_NUM_IMAGES:
|
665 |
+
raise ValueError
|
666 |
+
|
667 |
+
if preprocessor_name == "None":
|
668 |
+
image = HWC3(image)
|
669 |
+
image = resize_image(image, resolution=image_resolution)
|
670 |
+
control_image = PIL.Image.fromarray(image)
|
671 |
+
else:
|
672 |
+
self.preprocessor.load("Openpose")
|
673 |
+
control_image = self.preprocessor(
|
674 |
+
image=image,
|
675 |
+
image_resolution=image_resolution,
|
676 |
+
detect_resolution=preprocess_resolution,
|
677 |
+
hand_and_face=True,
|
678 |
+
)
|
679 |
+
self.load_controlnet_weight("Openpose")
|
680 |
+
results = self.run_pipe(
|
681 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
682 |
+
negative_prompt=negative_prompt,
|
683 |
+
control_image=control_image,
|
684 |
+
num_images=num_images,
|
685 |
+
num_steps=num_steps,
|
686 |
+
guidance_scale=guidance_scale,
|
687 |
+
seed=seed,
|
688 |
+
)
|
689 |
+
return [control_image, *results]
|
690 |
+
|
691 |
+
@torch.inference_mode()
|
692 |
+
def process_segmentation(
|
693 |
+
self,
|
694 |
+
image: np.ndarray,
|
695 |
+
prompt: str,
|
696 |
+
additional_prompt: str,
|
697 |
+
negative_prompt: str,
|
698 |
+
num_images: int,
|
699 |
+
image_resolution: int,
|
700 |
+
preprocess_resolution: int,
|
701 |
+
num_steps: int,
|
702 |
+
guidance_scale: float,
|
703 |
+
seed: int,
|
704 |
+
preprocessor_name: str,
|
705 |
+
) -> list[PIL.Image.Image]:
|
706 |
+
if image is None:
|
707 |
+
raise ValueError
|
708 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
709 |
+
raise ValueError
|
710 |
+
if num_images > MAX_NUM_IMAGES:
|
711 |
+
raise ValueError
|
712 |
+
|
713 |
+
if preprocessor_name == "None":
|
714 |
+
image = HWC3(image)
|
715 |
+
image = resize_image(image, resolution=image_resolution)
|
716 |
+
control_image = PIL.Image.fromarray(image)
|
717 |
+
else:
|
718 |
+
self.preprocessor.load(preprocessor_name)
|
719 |
+
control_image = self.preprocessor(
|
720 |
+
image=image,
|
721 |
+
image_resolution=image_resolution,
|
722 |
+
detect_resolution=preprocess_resolution,
|
723 |
+
)
|
724 |
+
self.load_controlnet_weight("segmentation")
|
725 |
+
results = self.run_pipe(
|
726 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
727 |
+
negative_prompt=negative_prompt,
|
728 |
+
control_image=control_image,
|
729 |
+
num_images=num_images,
|
730 |
+
num_steps=num_steps,
|
731 |
+
guidance_scale=guidance_scale,
|
732 |
+
seed=seed,
|
733 |
+
)
|
734 |
+
return [control_image, *results]
|
735 |
+
|
736 |
+
@torch.inference_mode()
|
737 |
+
def process_depth(
|
738 |
+
self,
|
739 |
+
image: np.ndarray,
|
740 |
+
prompt: str,
|
741 |
+
additional_prompt: str,
|
742 |
+
negative_prompt: str,
|
743 |
+
num_images: int,
|
744 |
+
image_resolution: int,
|
745 |
+
preprocess_resolution: int,
|
746 |
+
num_steps: int,
|
747 |
+
guidance_scale: float,
|
748 |
+
seed: int,
|
749 |
+
preprocessor_name: str,
|
750 |
+
) -> list[PIL.Image.Image]:
|
751 |
+
if image is None:
|
752 |
+
raise ValueError
|
753 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
754 |
+
raise ValueError
|
755 |
+
if num_images > MAX_NUM_IMAGES:
|
756 |
+
raise ValueError
|
757 |
+
|
758 |
+
if preprocessor_name == "None":
|
759 |
+
image = HWC3(image)
|
760 |
+
image = resize_image(image, resolution=image_resolution)
|
761 |
+
control_image = PIL.Image.fromarray(image)
|
762 |
+
else:
|
763 |
+
self.preprocessor.load(preprocessor_name)
|
764 |
+
control_image = self.preprocessor(
|
765 |
+
image=image,
|
766 |
+
image_resolution=image_resolution,
|
767 |
+
detect_resolution=preprocess_resolution,
|
768 |
+
)
|
769 |
+
self.load_controlnet_weight("depth")
|
770 |
+
results = self.run_pipe(
|
771 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
772 |
+
negative_prompt=negative_prompt,
|
773 |
+
control_image=control_image,
|
774 |
+
num_images=num_images,
|
775 |
+
num_steps=num_steps,
|
776 |
+
guidance_scale=guidance_scale,
|
777 |
+
seed=seed,
|
778 |
+
)
|
779 |
+
return [control_image, *results]
|
780 |
+
|
781 |
+
@torch.inference_mode()
|
782 |
+
def process_normal(
|
783 |
+
self,
|
784 |
+
image: np.ndarray,
|
785 |
+
prompt: str,
|
786 |
+
additional_prompt: str,
|
787 |
+
negative_prompt: str,
|
788 |
+
num_images: int,
|
789 |
+
image_resolution: int,
|
790 |
+
preprocess_resolution: int,
|
791 |
+
num_steps: int,
|
792 |
+
guidance_scale: float,
|
793 |
+
seed: int,
|
794 |
+
preprocessor_name: str,
|
795 |
+
) -> list[PIL.Image.Image]:
|
796 |
+
if image is None:
|
797 |
+
raise ValueError
|
798 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
799 |
+
raise ValueError
|
800 |
+
if num_images > MAX_NUM_IMAGES:
|
801 |
+
raise ValueError
|
802 |
+
|
803 |
+
if preprocessor_name == "None":
|
804 |
+
image = HWC3(image)
|
805 |
+
image = resize_image(image, resolution=image_resolution)
|
806 |
+
control_image = PIL.Image.fromarray(image)
|
807 |
+
else:
|
808 |
+
self.preprocessor.load("NormalBae")
|
809 |
+
control_image = self.preprocessor(
|
810 |
+
image=image,
|
811 |
+
image_resolution=image_resolution,
|
812 |
+
detect_resolution=preprocess_resolution,
|
813 |
+
)
|
814 |
+
self.load_controlnet_weight("NormalBae")
|
815 |
+
results = self.run_pipe(
|
816 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
817 |
+
negative_prompt=negative_prompt,
|
818 |
+
control_image=control_image,
|
819 |
+
num_images=num_images,
|
820 |
+
num_steps=num_steps,
|
821 |
+
guidance_scale=guidance_scale,
|
822 |
+
seed=seed,
|
823 |
+
)
|
824 |
+
return [control_image, *results]
|
825 |
+
|
826 |
+
@torch.inference_mode()
|
827 |
+
def process_lineart(
|
828 |
+
self,
|
829 |
+
image: np.ndarray,
|
830 |
+
prompt: str,
|
831 |
+
additional_prompt: str,
|
832 |
+
negative_prompt: str,
|
833 |
+
num_images: int,
|
834 |
+
image_resolution: int,
|
835 |
+
preprocess_resolution: int,
|
836 |
+
num_steps: int,
|
837 |
+
guidance_scale: float,
|
838 |
+
seed: int,
|
839 |
+
preprocessor_name: str,
|
840 |
+
) -> list[PIL.Image.Image]:
|
841 |
+
if image is None:
|
842 |
+
raise ValueError
|
843 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
844 |
+
raise ValueError
|
845 |
+
if num_images > MAX_NUM_IMAGES:
|
846 |
+
raise ValueError
|
847 |
+
|
848 |
+
if preprocessor_name in ["None", "None (anime)"]:
|
849 |
+
image = HWC3(image)
|
850 |
+
image = resize_image(image, resolution=image_resolution)
|
851 |
+
control_image = PIL.Image.fromarray(image)
|
852 |
+
elif preprocessor_name in ["Lineart", "Lineart coarse"]:
|
853 |
+
coarse = "coarse" in preprocessor_name
|
854 |
+
self.preprocessor.load("Lineart")
|
855 |
+
control_image = self.preprocessor(
|
856 |
+
image=image,
|
857 |
+
image_resolution=image_resolution,
|
858 |
+
detect_resolution=preprocess_resolution,
|
859 |
+
coarse=coarse,
|
860 |
+
)
|
861 |
+
elif preprocessor_name == "Lineart (anime)":
|
862 |
+
self.preprocessor.load("LineartAnime")
|
863 |
+
control_image = self.preprocessor(
|
864 |
+
image=image,
|
865 |
+
image_resolution=image_resolution,
|
866 |
+
detect_resolution=preprocess_resolution,
|
867 |
+
)
|
868 |
+
if "anime" in preprocessor_name:
|
869 |
+
self.load_controlnet_weight("lineart_anime")
|
870 |
+
else:
|
871 |
+
self.load_controlnet_weight("lineart")
|
872 |
+
results = self.run_pipe(
|
873 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
874 |
+
negative_prompt=negative_prompt,
|
875 |
+
control_image=control_image,
|
876 |
+
num_images=num_images,
|
877 |
+
num_steps=num_steps,
|
878 |
+
guidance_scale=guidance_scale,
|
879 |
+
seed=seed,
|
880 |
+
)
|
881 |
+
return [control_image, *results]
|
882 |
+
|
883 |
+
@torch.inference_mode()
|
884 |
+
def process_shuffle(
|
885 |
+
self,
|
886 |
+
image: np.ndarray,
|
887 |
+
prompt: str,
|
888 |
+
additional_prompt: str,
|
889 |
+
negative_prompt: str,
|
890 |
+
num_images: int,
|
891 |
+
image_resolution: int,
|
892 |
+
num_steps: int,
|
893 |
+
guidance_scale: float,
|
894 |
+
seed: int,
|
895 |
+
preprocessor_name: str,
|
896 |
+
) -> list[PIL.Image.Image]:
|
897 |
+
if image is None:
|
898 |
+
raise ValueError
|
899 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
900 |
+
raise ValueError
|
901 |
+
if num_images > MAX_NUM_IMAGES:
|
902 |
+
raise ValueError
|
903 |
+
|
904 |
+
if preprocessor_name == "None":
|
905 |
+
image = HWC3(image)
|
906 |
+
image = resize_image(image, resolution=image_resolution)
|
907 |
+
control_image = PIL.Image.fromarray(image)
|
908 |
+
else:
|
909 |
+
self.preprocessor.load(preprocessor_name)
|
910 |
+
control_image = self.preprocessor(
|
911 |
+
image=image,
|
912 |
+
image_resolution=image_resolution,
|
913 |
+
)
|
914 |
+
self.load_controlnet_weight("shuffle")
|
915 |
+
results = self.run_pipe(
|
916 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
917 |
+
negative_prompt=negative_prompt,
|
918 |
+
control_image=control_image,
|
919 |
+
num_images=num_images,
|
920 |
+
num_steps=num_steps,
|
921 |
+
guidance_scale=guidance_scale,
|
922 |
+
seed=seed,
|
923 |
+
)
|
924 |
+
return [control_image, *results]
|
925 |
+
|
926 |
+
@torch.inference_mode()
|
927 |
+
def process_ip2p(
|
928 |
+
self,
|
929 |
+
image: np.ndarray,
|
930 |
+
prompt: str,
|
931 |
+
additional_prompt: str,
|
932 |
+
negative_prompt: str,
|
933 |
+
num_images: int,
|
934 |
+
image_resolution: int,
|
935 |
+
num_steps: int,
|
936 |
+
guidance_scale: float,
|
937 |
+
seed: int,
|
938 |
+
) -> list[PIL.Image.Image]:
|
939 |
+
if image is None:
|
940 |
+
raise ValueError
|
941 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
942 |
+
raise ValueError
|
943 |
+
if num_images > MAX_NUM_IMAGES:
|
944 |
+
raise ValueError
|
945 |
+
|
946 |
+
image = HWC3(image)
|
947 |
+
image = resize_image(image, resolution=image_resolution)
|
948 |
+
control_image = PIL.Image.fromarray(image)
|
949 |
+
self.load_controlnet_weight("ip2p")
|
950 |
+
results = self.run_pipe(
|
951 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
952 |
+
negative_prompt=negative_prompt,
|
953 |
+
control_image=control_image,
|
954 |
+
num_images=num_images,
|
955 |
+
num_steps=num_steps,
|
956 |
+
guidance_scale=guidance_scale,
|
957 |
+
seed=seed,
|
958 |
+
)
|
959 |
+
return [control_image, *results]
|
pre-requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate
|
2 |
+
diffusers
|
3 |
+
invisible_watermark
|
4 |
+
torch
|
5 |
+
torchvision
|
6 |
+
transformers
|
7 |
+
xformers
|
8 |
+
controlnet-aux # for controlnet
|
9 |
+
spaces # no need to specify here
|
preprocessor.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from typing import TYPE_CHECKING
|
3 |
+
|
4 |
+
if TYPE_CHECKING:
|
5 |
+
from collections.abc import Callable
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import PIL.Image
|
9 |
+
import torch
|
10 |
+
from controlnet_aux import (
|
11 |
+
CannyDetector,
|
12 |
+
ContentShuffleDetector,
|
13 |
+
HEDdetector,
|
14 |
+
LineartAnimeDetector,
|
15 |
+
LineartDetector,
|
16 |
+
MidasDetector,
|
17 |
+
MLSDdetector,
|
18 |
+
NormalBaeDetector,
|
19 |
+
OpenposeDetector,
|
20 |
+
PidiNetDetector,
|
21 |
+
)
|
22 |
+
from controlnet_aux.util import HWC3
|
23 |
+
|
24 |
+
from cv_utils import resize_image
|
25 |
+
from depth_estimator import DepthEstimator
|
26 |
+
from image_segmentor import ImageSegmentor
|
27 |
+
|
28 |
+
|
29 |
+
class Preprocessor:
|
30 |
+
MODEL_ID = "lllyasviel/Annotators"
|
31 |
+
|
32 |
+
def __init__(self) -> None:
|
33 |
+
self.model: Callable = None # type: ignore
|
34 |
+
self.name = ""
|
35 |
+
|
36 |
+
def load(self, name: str) -> None: # noqa: C901, PLR0912
|
37 |
+
if name == self.name:
|
38 |
+
return
|
39 |
+
if name == "HED":
|
40 |
+
self.model = HEDdetector.from_pretrained(self.MODEL_ID)
|
41 |
+
elif name == "Midas":
|
42 |
+
self.model = MidasDetector.from_pretrained(self.MODEL_ID)
|
43 |
+
elif name == "MLSD":
|
44 |
+
self.model = MLSDdetector.from_pretrained(self.MODEL_ID)
|
45 |
+
elif name == "Openpose":
|
46 |
+
self.model = OpenposeDetector.from_pretrained(self.MODEL_ID)
|
47 |
+
elif name == "PidiNet":
|
48 |
+
self.model = PidiNetDetector.from_pretrained(self.MODEL_ID)
|
49 |
+
elif name == "NormalBae":
|
50 |
+
self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID)
|
51 |
+
elif name == "Lineart":
|
52 |
+
self.model = LineartDetector.from_pretrained(self.MODEL_ID)
|
53 |
+
elif name == "LineartAnime":
|
54 |
+
self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
|
55 |
+
elif name == "Canny":
|
56 |
+
self.model = CannyDetector()
|
57 |
+
elif name == "ContentShuffle":
|
58 |
+
self.model = ContentShuffleDetector()
|
59 |
+
elif name == "DPT":
|
60 |
+
self.model = DepthEstimator()
|
61 |
+
elif name == "UPerNet":
|
62 |
+
self.model = ImageSegmentor()
|
63 |
+
elif name == 'texnet':
|
64 |
+
self.model = TexnetPreprocessor()
|
65 |
+
else:
|
66 |
+
raise ValueError
|
67 |
+
torch.cuda.empty_cache()
|
68 |
+
gc.collect()
|
69 |
+
self.name = name
|
70 |
+
|
71 |
+
def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image: # noqa: ANN003
|
72 |
+
if self.name == "Canny":
|
73 |
+
if "detect_resolution" in kwargs:
|
74 |
+
detect_resolution = kwargs.pop("detect_resolution")
|
75 |
+
image = np.array(image)
|
76 |
+
image = HWC3(image)
|
77 |
+
image = resize_image(image, resolution=detect_resolution)
|
78 |
+
image = self.model(image, **kwargs)
|
79 |
+
return PIL.Image.fromarray(image)
|
80 |
+
if self.name == "Midas":
|
81 |
+
detect_resolution = kwargs.pop("detect_resolution", 512)
|
82 |
+
image_resolution = kwargs.pop("image_resolution", 512)
|
83 |
+
image = np.array(image)
|
84 |
+
image = HWC3(image)
|
85 |
+
image = resize_image(image, resolution=detect_resolution)
|
86 |
+
image = self.model(image, **kwargs)
|
87 |
+
image = HWC3(image)
|
88 |
+
image = resize_image(image, resolution=image_resolution)
|
89 |
+
return PIL.Image.fromarray(image)
|
90 |
+
return self.model(image, **kwargs)
|
91 |
+
|
92 |
+
|
93 |
+
# https://github.com/huggingface/controlnet_aux/blob/master/src/controlnet_aux/canny/__init__.py
|
94 |
+
class TexnetPreprocessor:
|
95 |
+
def __call__(self, input_image=None, low_threshold=100, high_threshold=200, image_resolution=512, output_type=None, **kwargs):
|
96 |
+
if "img" in kwargs:
|
97 |
+
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
|
98 |
+
input_image = kwargs.pop("img")
|
99 |
+
|
100 |
+
if input_image is None:
|
101 |
+
raise ValueError("input_image must be defined.")
|
102 |
+
|
103 |
+
if not isinstance(input_image, np.ndarray):
|
104 |
+
input_image = np.array(input_image, dtype=np.uint8)
|
105 |
+
output_type = output_type or "pil"
|
106 |
+
else:
|
107 |
+
output_type = output_type or "np"
|
108 |
+
|
109 |
+
input_image = HWC3(input_image)
|
110 |
+
input_image = resize_image(input_image, image_resolution)
|
111 |
+
H, W, C = input_image.shape
|
112 |
+
|
113 |
+
# detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
114 |
+
output_image = input_image.copy()
|
115 |
+
|
116 |
+
if output_type == "pil":
|
117 |
+
# detected_map = Image.fromarray(detected_map)
|
118 |
+
output_image = PIL.Image.fromarray(output_image)
|
119 |
+
|
120 |
+
return output_image
|
push_dataset.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import HfApi
|
2 |
+
api = HfApi()
|
3 |
+
|
4 |
+
api.upload_folder(
|
5 |
+
folder_path="./examples",
|
6 |
+
repo_id="jingyangcarl/matgen",
|
7 |
+
repo_type="space",
|
8 |
+
path_in_repo="examples", # Upload to a specific folder
|
9 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@stable
|
4 |
+
trimesh
|
5 |
+
xatlas
|
6 |
+
scikit-learn
|
7 |
+
opencv-python
|
8 |
+
matplotlib
|
9 |
+
omegaconf
|
rgb2x/generate_blend.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bpy
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
|
5 |
+
def create_tex_node(nodes, img_path, label, color_space, location):
|
6 |
+
img = bpy.data.images.load(img_path)
|
7 |
+
tex = nodes.new(type='ShaderNodeTexImage')
|
8 |
+
tex.image = img
|
9 |
+
tex.label = label
|
10 |
+
tex.location = location
|
11 |
+
tex.image.colorspace_settings.name = color_space
|
12 |
+
return tex
|
13 |
+
|
14 |
+
def setup_environment_lighting(hdri_path):
|
15 |
+
if not bpy.data.worlds:
|
16 |
+
bpy.data.worlds.new(name="World")
|
17 |
+
if bpy.context.scene.world is None:
|
18 |
+
bpy.context.scene.world = bpy.data.worlds[0]
|
19 |
+
world = bpy.context.scene.world
|
20 |
+
|
21 |
+
world.use_nodes = True
|
22 |
+
nodes = world.node_tree.nodes
|
23 |
+
links = world.node_tree.links
|
24 |
+
nodes.clear()
|
25 |
+
|
26 |
+
env_tex = nodes.new(type="ShaderNodeTexEnvironment")
|
27 |
+
env_tex.image = bpy.data.images.load(hdri_path)
|
28 |
+
env_tex.location = (-300, 0)
|
29 |
+
|
30 |
+
bg = nodes.new(type="ShaderNodeBackground")
|
31 |
+
bg.location = (0, 0)
|
32 |
+
|
33 |
+
output = nodes.new(type="ShaderNodeOutputWorld")
|
34 |
+
output.location = (300, 0)
|
35 |
+
|
36 |
+
links.new(env_tex.outputs["Color"], bg.inputs["Color"])
|
37 |
+
links.new(bg.outputs["Background"], output.inputs["Surface"])
|
38 |
+
|
39 |
+
def setup_gpu_rendering():
|
40 |
+
bpy.context.scene.render.engine = 'CYCLES'
|
41 |
+
prefs = bpy.context.preferences
|
42 |
+
cprefs = prefs.addons['cycles'].preferences
|
43 |
+
|
44 |
+
# Choose backend depending on GPU type: 'CUDA', 'OPTIX', 'HIP', 'METAL'
|
45 |
+
cprefs.compute_device_type = 'CUDA'
|
46 |
+
bpy.context.scene.cycles.device = 'GPU'
|
47 |
+
|
48 |
+
def generate_blend(obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend):
|
49 |
+
# Reset scene
|
50 |
+
bpy.ops.wm.read_factory_settings(use_empty=True)
|
51 |
+
|
52 |
+
# Import OBJ
|
53 |
+
bpy.ops.import_scene.obj(filepath=obj_path)
|
54 |
+
obj = bpy.context.selected_objects[0]
|
55 |
+
|
56 |
+
# Create material
|
57 |
+
mat = bpy.data.materials.new(name="BRDF_Material")
|
58 |
+
mat.use_nodes = True
|
59 |
+
nodes = mat.node_tree.nodes
|
60 |
+
links = mat.node_tree.links
|
61 |
+
nodes.clear()
|
62 |
+
|
63 |
+
output = nodes.new(type='ShaderNodeOutputMaterial')
|
64 |
+
output.location = (400, 0)
|
65 |
+
|
66 |
+
principled = nodes.new(type='ShaderNodeBsdfPrincipled')
|
67 |
+
principled.location = (100, 0)
|
68 |
+
links.new(principled.outputs['BSDF'], output.inputs['Surface'])
|
69 |
+
|
70 |
+
# Base Color
|
71 |
+
base_color = create_tex_node(nodes, base_color_path, "Base Color", 'sRGB', (-600, 200))
|
72 |
+
links.new(base_color.outputs['Color'], principled.inputs['Base Color'])
|
73 |
+
|
74 |
+
# Roughness
|
75 |
+
rough = create_tex_node(nodes, roughness_path, "Roughness", 'Non-Color', (-600, 0))
|
76 |
+
links.new(rough.outputs['Color'], principled.inputs['Roughness'])
|
77 |
+
|
78 |
+
# Metallic
|
79 |
+
metal = create_tex_node(nodes, metallic_path, "Metallic", 'Non-Color', (-600, -200))
|
80 |
+
links.new(metal.outputs['Color'], principled.inputs['Metallic'])
|
81 |
+
|
82 |
+
# Normal Map
|
83 |
+
normal_tex = create_tex_node(nodes, normal_map_path, "Normal Map", 'Non-Color', (-800, -400))
|
84 |
+
normal_map = nodes.new(type='ShaderNodeNormalMap')
|
85 |
+
normal_map.location = (-400, -400)
|
86 |
+
links.new(normal_tex.outputs['Color'], normal_map.inputs['Color'])
|
87 |
+
links.new(normal_map.outputs['Normal'], principled.inputs['Normal'])
|
88 |
+
|
89 |
+
# Assign material
|
90 |
+
if obj.data.materials:
|
91 |
+
obj.data.materials[0] = mat
|
92 |
+
else:
|
93 |
+
obj.data.materials.append(mat)
|
94 |
+
|
95 |
+
# Global Illumination using Blender's default forest HDRI
|
96 |
+
blender_data_path = bpy.utils.resource_path('LOCAL')
|
97 |
+
forest_hdri_path = os.path.join(blender_data_path, "datafiles", "studiolights", "world", "forest.exr")
|
98 |
+
print(f"Using HDRI: {forest_hdri_path}")
|
99 |
+
setup_environment_lighting(forest_hdri_path)
|
100 |
+
|
101 |
+
# GPU rendering setup
|
102 |
+
setup_gpu_rendering()
|
103 |
+
|
104 |
+
# Pack textures into .blend
|
105 |
+
bpy.ops.file.pack_all()
|
106 |
+
|
107 |
+
# Set the 3D View to Rendered mode and focus on object
|
108 |
+
for area in bpy.context.screen.areas:
|
109 |
+
if area.type == 'VIEW_3D':
|
110 |
+
for space in area.spaces:
|
111 |
+
if space.type == 'VIEW_3D':
|
112 |
+
space.shading.type = 'RENDERED' # Set viewport shading to Rendered
|
113 |
+
for region in area.regions:
|
114 |
+
if region.type == 'WINDOW':
|
115 |
+
override = {'area': area, 'region': region, 'scene': bpy.context.scene}
|
116 |
+
bpy.ops.view3d.view_all(override, center=True)
|
117 |
+
|
118 |
+
elif area.type == 'NODE_EDITOR':
|
119 |
+
for space in area.spaces:
|
120 |
+
if space.type == 'NODE_EDITOR':
|
121 |
+
space.tree_type = 'ShaderNodeTree' # Switch to Shader Editor
|
122 |
+
space.shader_type = 'OBJECT'
|
123 |
+
|
124 |
+
# Optional: Switch active workspace to Shading (if it exists)
|
125 |
+
for workspace in bpy.data.workspaces:
|
126 |
+
if workspace.name == 'Shading':
|
127 |
+
bpy.context.window.workspace = workspace
|
128 |
+
break
|
129 |
+
|
130 |
+
# Save the .blend file
|
131 |
+
bpy.ops.wm.save_as_mainfile(filepath=output_blend)
|
132 |
+
print(f"✅ Saved .blend file with BRDF, HDRI, GPU: {output_blend}")
|
133 |
+
|
134 |
+
if __name__ == "__main__":
|
135 |
+
argv = sys.argv
|
136 |
+
argv = argv[argv.index("--") + 1:] # Only use args after "--"
|
137 |
+
|
138 |
+
if len(argv) != 6:
|
139 |
+
print("Usage:\n blender --background --python generate_blend.py -- obj base_color normal roughness metallic output.blend")
|
140 |
+
sys.exit(1)
|
141 |
+
|
142 |
+
generate_blend(*argv)
|
rgb2x/gradio_demo_rgb2x.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
from diffusers import DDIMScheduler
|
9 |
+
from load_image import load_exr_image, load_ldr_image
|
10 |
+
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
|
11 |
+
|
12 |
+
current_directory = os.path.dirname(os.path.abspath(__file__))
|
13 |
+
|
14 |
+
|
15 |
+
def get_rgb2x_demo():
|
16 |
+
# Load pipeline
|
17 |
+
pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
|
18 |
+
"zheng95z/rgb-to-x",
|
19 |
+
torch_dtype=torch.float16,
|
20 |
+
cache_dir=os.path.join(current_directory, "model_cache"),
|
21 |
+
).to("cuda")
|
22 |
+
pipe.scheduler = DDIMScheduler.from_config(
|
23 |
+
pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
24 |
+
)
|
25 |
+
pipe.set_progress_bar_config(disable=True)
|
26 |
+
pipe.to("cuda")
|
27 |
+
|
28 |
+
# Augmentation
|
29 |
+
def callback(
|
30 |
+
photo,
|
31 |
+
seed,
|
32 |
+
inference_step,
|
33 |
+
num_samples,
|
34 |
+
):
|
35 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
36 |
+
|
37 |
+
if photo.name.endswith(".exr"):
|
38 |
+
photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
|
39 |
+
elif (
|
40 |
+
photo.name.endswith(".png")
|
41 |
+
or photo.name.endswith(".jpg")
|
42 |
+
or photo.name.endswith(".jpeg")
|
43 |
+
):
|
44 |
+
photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
|
45 |
+
|
46 |
+
# Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
|
47 |
+
old_height = photo.shape[1]
|
48 |
+
old_width = photo.shape[2]
|
49 |
+
new_height = old_height
|
50 |
+
new_width = old_width
|
51 |
+
radio = old_height / old_width
|
52 |
+
max_side = 1000
|
53 |
+
if old_height > old_width:
|
54 |
+
new_height = max_side
|
55 |
+
new_width = int(new_height / radio)
|
56 |
+
else:
|
57 |
+
new_width = max_side
|
58 |
+
new_height = int(new_width * radio)
|
59 |
+
|
60 |
+
if new_width % 8 != 0 or new_height % 8 != 0:
|
61 |
+
new_width = new_width // 8 * 8
|
62 |
+
new_height = new_height // 8 * 8
|
63 |
+
|
64 |
+
photo = torchvision.transforms.Resize((new_height, new_width))(photo)
|
65 |
+
|
66 |
+
required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
|
67 |
+
prompts = {
|
68 |
+
"albedo": "Albedo (diffuse basecolor)",
|
69 |
+
"normal": "Camera-space Normal",
|
70 |
+
"roughness": "Roughness",
|
71 |
+
"metallic": "Metallicness",
|
72 |
+
"irradiance": "Irradiance (diffuse lighting)",
|
73 |
+
}
|
74 |
+
|
75 |
+
return_list = []
|
76 |
+
for i in range(num_samples):
|
77 |
+
for aov_name in required_aovs:
|
78 |
+
prompt = prompts[aov_name]
|
79 |
+
generated_image = pipe(
|
80 |
+
prompt=prompt,
|
81 |
+
photo=photo,
|
82 |
+
num_inference_steps=inference_step,
|
83 |
+
height=new_height,
|
84 |
+
width=new_width,
|
85 |
+
generator=generator,
|
86 |
+
required_aovs=[aov_name],
|
87 |
+
).images[0][0]
|
88 |
+
|
89 |
+
generated_image = torchvision.transforms.Resize(
|
90 |
+
(old_height, old_width)
|
91 |
+
)(generated_image)
|
92 |
+
|
93 |
+
generated_image = (generated_image, f"Generated {aov_name} {i}")
|
94 |
+
return_list.append(generated_image)
|
95 |
+
|
96 |
+
return return_list
|
97 |
+
|
98 |
+
block = gr.Blocks()
|
99 |
+
with block:
|
100 |
+
with gr.Row():
|
101 |
+
gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
|
102 |
+
with gr.Row():
|
103 |
+
# Input side
|
104 |
+
with gr.Column():
|
105 |
+
gr.Markdown("### Given Image")
|
106 |
+
photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
|
107 |
+
|
108 |
+
gr.Markdown("### Parameters")
|
109 |
+
run_button = gr.Button(value="Run")
|
110 |
+
with gr.Accordion("Advanced options", open=False):
|
111 |
+
seed = gr.Slider(
|
112 |
+
label="Seed",
|
113 |
+
minimum=-1,
|
114 |
+
maximum=2147483647,
|
115 |
+
step=1,
|
116 |
+
randomize=True,
|
117 |
+
)
|
118 |
+
inference_step = gr.Slider(
|
119 |
+
label="Inference Step",
|
120 |
+
minimum=1,
|
121 |
+
maximum=100,
|
122 |
+
step=1,
|
123 |
+
value=50,
|
124 |
+
)
|
125 |
+
num_samples = gr.Slider(
|
126 |
+
label="Samples",
|
127 |
+
minimum=1,
|
128 |
+
maximum=100,
|
129 |
+
step=1,
|
130 |
+
value=1,
|
131 |
+
)
|
132 |
+
|
133 |
+
# Output side
|
134 |
+
with gr.Column():
|
135 |
+
gr.Markdown("### Output Gallery")
|
136 |
+
result_gallery = gr.Gallery(
|
137 |
+
label="Output",
|
138 |
+
show_label=False,
|
139 |
+
elem_id="gallery",
|
140 |
+
columns=2,
|
141 |
+
)
|
142 |
+
|
143 |
+
inputs = [
|
144 |
+
photo,
|
145 |
+
seed,
|
146 |
+
inference_step,
|
147 |
+
num_samples,
|
148 |
+
]
|
149 |
+
run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)
|
150 |
+
|
151 |
+
return block
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
|
155 |
+
demo = get_rgb2x_demo()
|
156 |
+
demo.queue(max_size=1)
|
157 |
+
demo.launch()
|
rgb2x/load_image.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def convert_rgb_2_XYZ(rgb):
|
11 |
+
# Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
|
12 |
+
# rgb: (h, w, 3)
|
13 |
+
# XYZ: (h, w, 3)
|
14 |
+
XYZ = torch.ones_like(rgb)
|
15 |
+
XYZ[:, :, 0] = (
|
16 |
+
0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
|
17 |
+
)
|
18 |
+
XYZ[:, :, 1] = (
|
19 |
+
0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
|
20 |
+
)
|
21 |
+
XYZ[:, :, 2] = (
|
22 |
+
0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
|
23 |
+
)
|
24 |
+
return XYZ
|
25 |
+
|
26 |
+
|
27 |
+
def convert_XYZ_2_Yxy(XYZ):
|
28 |
+
# XYZ: (h, w, 3)
|
29 |
+
# Yxy: (h, w, 3)
|
30 |
+
Yxy = torch.ones_like(XYZ)
|
31 |
+
Yxy[:, :, 0] = XYZ[:, :, 1]
|
32 |
+
sum = torch.sum(XYZ, dim=2)
|
33 |
+
inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
|
34 |
+
Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
|
35 |
+
Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
|
36 |
+
return Yxy
|
37 |
+
|
38 |
+
|
39 |
+
def convert_rgb_2_Yxy(rgb):
|
40 |
+
# rgb: (h, w, 3)
|
41 |
+
# Yxy: (h, w, 3)
|
42 |
+
return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))
|
43 |
+
|
44 |
+
|
45 |
+
def convert_XYZ_2_rgb(XYZ):
|
46 |
+
# XYZ: (h, w, 3)
|
47 |
+
# rgb: (h, w, 3)
|
48 |
+
rgb = torch.ones_like(XYZ)
|
49 |
+
rgb[:, :, 0] = (
|
50 |
+
3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
|
51 |
+
)
|
52 |
+
rgb[:, :, 1] = (
|
53 |
+
-0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
|
54 |
+
)
|
55 |
+
rgb[:, :, 2] = (
|
56 |
+
0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
|
57 |
+
)
|
58 |
+
return rgb
|
59 |
+
|
60 |
+
|
61 |
+
def convert_Yxy_2_XYZ(Yxy):
|
62 |
+
# Yxy: (h, w, 3)
|
63 |
+
# XYZ: (h, w, 3)
|
64 |
+
XYZ = torch.ones_like(Yxy)
|
65 |
+
XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
|
66 |
+
XYZ[:, :, 1] = Yxy[:, :, 0]
|
67 |
+
XYZ[:, :, 2] = (
|
68 |
+
(1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
|
69 |
+
/ torch.clamp(Yxy[:, :, 2], min=1e-4)
|
70 |
+
* Yxy[:, :, 0]
|
71 |
+
)
|
72 |
+
return XYZ
|
73 |
+
|
74 |
+
|
75 |
+
def convert_Yxy_2_rgb(Yxy):
|
76 |
+
# Yxy: (h, w, 3)
|
77 |
+
# rgb: (h, w, 3)
|
78 |
+
return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))
|
79 |
+
|
80 |
+
|
81 |
+
def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
|
82 |
+
# Load png or jpg image
|
83 |
+
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
84 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c)
|
85 |
+
image[~torch.isfinite(image)] = 0
|
86 |
+
if from_srgb:
|
87 |
+
# Convert from sRGB to linear RGB
|
88 |
+
image = image**2.2
|
89 |
+
if clamp:
|
90 |
+
image = torch.clamp(image, min=0.0, max=1.0)
|
91 |
+
if normalize:
|
92 |
+
# Normalize to [-1, 1]
|
93 |
+
image = image * 2.0 - 1.0
|
94 |
+
image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
|
95 |
+
return image.permute(2, 0, 1) # returns (c, h, w)
|
96 |
+
|
97 |
+
|
98 |
+
def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
|
99 |
+
image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
|
100 |
+
image = torch.from_numpy(image.astype("float32")) # (h, w, c)
|
101 |
+
image[~torch.isfinite(image)] = 0
|
102 |
+
if tonemaping:
|
103 |
+
# Exposure adjuestment
|
104 |
+
image_Yxy = convert_rgb_2_Yxy(image)
|
105 |
+
lum = (
|
106 |
+
image[:, :, 0:1] * 0.2125
|
107 |
+
+ image[:, :, 1:2] * 0.7154
|
108 |
+
+ image[:, :, 2:3] * 0.0721
|
109 |
+
)
|
110 |
+
lum = torch.log(torch.clamp(lum, min=1e-6))
|
111 |
+
lum_mean = torch.exp(torch.mean(lum))
|
112 |
+
lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
|
113 |
+
image_Yxy[:, :, 0:1] = lp
|
114 |
+
image = convert_Yxy_2_rgb(image_Yxy)
|
115 |
+
if clamp:
|
116 |
+
image = torch.clamp(image, min=0.0, max=1.0)
|
117 |
+
if normalize:
|
118 |
+
image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
|
119 |
+
return image.permute(2, 0, 1) # returns (c, h, w)
|
rgb2x/pipeline_rgb2x.py
ADDED
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Callable, List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import PIL
|
7 |
+
import torch
|
8 |
+
from diffusers.configuration_utils import register_to_config
|
9 |
+
from diffusers.image_processor import VaeImageProcessor
|
10 |
+
from diffusers.loaders import (
|
11 |
+
LoraLoaderMixin,
|
12 |
+
TextualInversionLoaderMixin,
|
13 |
+
)
|
14 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
15 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
16 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
17 |
+
rescale_noise_cfg,
|
18 |
+
)
|
19 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
20 |
+
from diffusers.utils import (
|
21 |
+
CONFIG_NAME,
|
22 |
+
BaseOutput,
|
23 |
+
deprecate,
|
24 |
+
logging,
|
25 |
+
)
|
26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
27 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class VaeImageProcrssorAOV(VaeImageProcessor):
|
33 |
+
"""
|
34 |
+
Image processor for VAE AOV.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
38 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
39 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
40 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
41 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
42 |
+
Resampling filter to use when resizing the image.
|
43 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
44 |
+
Whether to normalize the image to [-1,1].
|
45 |
+
"""
|
46 |
+
|
47 |
+
config_name = CONFIG_NAME
|
48 |
+
|
49 |
+
@register_to_config
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
do_resize: bool = True,
|
53 |
+
vae_scale_factor: int = 8,
|
54 |
+
resample: str = "lanczos",
|
55 |
+
do_normalize: bool = True,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
def postprocess(
|
60 |
+
self,
|
61 |
+
image: torch.FloatTensor,
|
62 |
+
output_type: str = "pil",
|
63 |
+
do_denormalize: Optional[List[bool]] = None,
|
64 |
+
do_gamma_correction: bool = True,
|
65 |
+
):
|
66 |
+
if not isinstance(image, torch.Tensor):
|
67 |
+
raise ValueError(
|
68 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
69 |
+
)
|
70 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
71 |
+
deprecation_message = (
|
72 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
73 |
+
"`pil`, `np`, `pt`, `latent`"
|
74 |
+
)
|
75 |
+
deprecate(
|
76 |
+
"Unsupported output_type",
|
77 |
+
"1.0.0",
|
78 |
+
deprecation_message,
|
79 |
+
standard_warn=False,
|
80 |
+
)
|
81 |
+
output_type = "np"
|
82 |
+
|
83 |
+
if output_type == "latent":
|
84 |
+
return image
|
85 |
+
|
86 |
+
if do_denormalize is None:
|
87 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
88 |
+
|
89 |
+
image = torch.stack(
|
90 |
+
[
|
91 |
+
self.denormalize(image[i]) if do_denormalize[i] else image[i]
|
92 |
+
for i in range(image.shape[0])
|
93 |
+
]
|
94 |
+
)
|
95 |
+
|
96 |
+
# Gamma correction
|
97 |
+
if do_gamma_correction:
|
98 |
+
image = torch.pow(image, 1.0 / 2.2)
|
99 |
+
|
100 |
+
if output_type == "pt":
|
101 |
+
return image
|
102 |
+
|
103 |
+
image = self.pt_to_numpy(image)
|
104 |
+
|
105 |
+
if output_type == "np":
|
106 |
+
return image
|
107 |
+
|
108 |
+
if output_type == "pil":
|
109 |
+
return self.numpy_to_pil(image)
|
110 |
+
|
111 |
+
def preprocess_normal(
|
112 |
+
self,
|
113 |
+
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
114 |
+
height: Optional[int] = None,
|
115 |
+
width: Optional[int] = None,
|
116 |
+
) -> torch.Tensor:
|
117 |
+
image = torch.stack([image], axis=0)
|
118 |
+
return image
|
119 |
+
|
120 |
+
|
121 |
+
@dataclass
|
122 |
+
class StableDiffusionAOVPipelineOutput(BaseOutput):
|
123 |
+
"""
|
124 |
+
Output class for Stable Diffusion AOV pipelines.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
128 |
+
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
129 |
+
num_channels)`.
|
130 |
+
nsfw_content_detected (`List[bool]`)
|
131 |
+
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
132 |
+
`None` if safety checking could not be performed.
|
133 |
+
"""
|
134 |
+
|
135 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
136 |
+
|
137 |
+
|
138 |
+
class StableDiffusionAOVMatEstPipeline(
|
139 |
+
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
|
140 |
+
):
|
141 |
+
r"""
|
142 |
+
Pipeline for AOVs.
|
143 |
+
|
144 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
145 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
146 |
+
|
147 |
+
The pipeline also inherits the following loading methods:
|
148 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
149 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
150 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
151 |
+
|
152 |
+
Args:
|
153 |
+
vae ([`AutoencoderKL`]):
|
154 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
155 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
156 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
157 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
158 |
+
A `CLIPTokenizer` to tokenize text.
|
159 |
+
unet ([`UNet2DConditionModel`]):
|
160 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
161 |
+
scheduler ([`SchedulerMixin`]):
|
162 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
163 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
vae: AutoencoderKL,
|
169 |
+
text_encoder: CLIPTextModel,
|
170 |
+
tokenizer: CLIPTokenizer,
|
171 |
+
unet: UNet2DConditionModel,
|
172 |
+
scheduler: KarrasDiffusionSchedulers,
|
173 |
+
):
|
174 |
+
super().__init__()
|
175 |
+
|
176 |
+
self.register_modules(
|
177 |
+
vae=vae,
|
178 |
+
text_encoder=text_encoder,
|
179 |
+
tokenizer=tokenizer,
|
180 |
+
unet=unet,
|
181 |
+
scheduler=scheduler,
|
182 |
+
)
|
183 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
184 |
+
self.image_processor = VaeImageProcrssorAOV(
|
185 |
+
vae_scale_factor=self.vae_scale_factor
|
186 |
+
)
|
187 |
+
self.register_to_config()
|
188 |
+
|
189 |
+
def _encode_prompt(
|
190 |
+
self,
|
191 |
+
prompt,
|
192 |
+
device,
|
193 |
+
num_images_per_prompt,
|
194 |
+
do_classifier_free_guidance,
|
195 |
+
negative_prompt=None,
|
196 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
197 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
198 |
+
):
|
199 |
+
r"""
|
200 |
+
Encodes the prompt into text encoder hidden states.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
prompt (`str` or `List[str]`, *optional*):
|
204 |
+
prompt to be encoded
|
205 |
+
device: (`torch.device`):
|
206 |
+
torch device
|
207 |
+
num_images_per_prompt (`int`):
|
208 |
+
number of images that should be generated per prompt
|
209 |
+
do_classifier_free_guidance (`bool`):
|
210 |
+
whether to use classifier free guidance or not
|
211 |
+
negative_ prompt (`str` or `List[str]`, *optional*):
|
212 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
213 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
214 |
+
less than `1`).
|
215 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
216 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
217 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
218 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
219 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
220 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
221 |
+
argument.
|
222 |
+
"""
|
223 |
+
if prompt is not None and isinstance(prompt, str):
|
224 |
+
batch_size = 1
|
225 |
+
elif prompt is not None and isinstance(prompt, list):
|
226 |
+
batch_size = len(prompt)
|
227 |
+
else:
|
228 |
+
batch_size = prompt_embeds.shape[0]
|
229 |
+
|
230 |
+
if prompt_embeds is None:
|
231 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
232 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
233 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
234 |
+
|
235 |
+
text_inputs = self.tokenizer(
|
236 |
+
prompt,
|
237 |
+
padding="max_length",
|
238 |
+
max_length=self.tokenizer.model_max_length,
|
239 |
+
truncation=True,
|
240 |
+
return_tensors="pt",
|
241 |
+
)
|
242 |
+
text_input_ids = text_inputs.input_ids
|
243 |
+
untruncated_ids = self.tokenizer(
|
244 |
+
prompt, padding="longest", return_tensors="pt"
|
245 |
+
).input_ids
|
246 |
+
|
247 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
248 |
+
-1
|
249 |
+
] and not torch.equal(text_input_ids, untruncated_ids):
|
250 |
+
removed_text = self.tokenizer.batch_decode(
|
251 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
252 |
+
)
|
253 |
+
logger.warning(
|
254 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
255 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
256 |
+
)
|
257 |
+
|
258 |
+
if (
|
259 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
260 |
+
and self.text_encoder.config.use_attention_mask
|
261 |
+
):
|
262 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
263 |
+
else:
|
264 |
+
attention_mask = None
|
265 |
+
|
266 |
+
prompt_embeds = self.text_encoder(
|
267 |
+
text_input_ids.to(device),
|
268 |
+
attention_mask=attention_mask,
|
269 |
+
)
|
270 |
+
prompt_embeds = prompt_embeds[0]
|
271 |
+
|
272 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
273 |
+
|
274 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
275 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
276 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
277 |
+
prompt_embeds = prompt_embeds.view(
|
278 |
+
bs_embed * num_images_per_prompt, seq_len, -1
|
279 |
+
)
|
280 |
+
|
281 |
+
# get unconditional embeddings for classifier free guidance
|
282 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
283 |
+
uncond_tokens: List[str]
|
284 |
+
if negative_prompt is None:
|
285 |
+
uncond_tokens = [""] * batch_size
|
286 |
+
elif type(prompt) is not type(negative_prompt):
|
287 |
+
raise TypeError(
|
288 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
289 |
+
f" {type(prompt)}."
|
290 |
+
)
|
291 |
+
elif isinstance(negative_prompt, str):
|
292 |
+
uncond_tokens = [negative_prompt]
|
293 |
+
elif batch_size != len(negative_prompt):
|
294 |
+
raise ValueError(
|
295 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
296 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
297 |
+
" the batch size of `prompt`."
|
298 |
+
)
|
299 |
+
else:
|
300 |
+
uncond_tokens = negative_prompt
|
301 |
+
|
302 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
303 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
304 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
305 |
+
|
306 |
+
max_length = prompt_embeds.shape[1]
|
307 |
+
uncond_input = self.tokenizer(
|
308 |
+
uncond_tokens,
|
309 |
+
padding="max_length",
|
310 |
+
max_length=max_length,
|
311 |
+
truncation=True,
|
312 |
+
return_tensors="pt",
|
313 |
+
)
|
314 |
+
|
315 |
+
if (
|
316 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
317 |
+
and self.text_encoder.config.use_attention_mask
|
318 |
+
):
|
319 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
320 |
+
else:
|
321 |
+
attention_mask = None
|
322 |
+
|
323 |
+
negative_prompt_embeds = self.text_encoder(
|
324 |
+
uncond_input.input_ids.to(device),
|
325 |
+
attention_mask=attention_mask,
|
326 |
+
)
|
327 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
328 |
+
|
329 |
+
if do_classifier_free_guidance:
|
330 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
331 |
+
seq_len = negative_prompt_embeds.shape[1]
|
332 |
+
|
333 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
334 |
+
dtype=self.text_encoder.dtype, device=device
|
335 |
+
)
|
336 |
+
|
337 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
338 |
+
1, num_images_per_prompt, 1
|
339 |
+
)
|
340 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
341 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
342 |
+
)
|
343 |
+
|
344 |
+
# For classifier free guidance, we need to do two forward passes.
|
345 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
346 |
+
# to avoid doing two forward passes
|
347 |
+
# pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
|
348 |
+
prompt_embeds = torch.cat(
|
349 |
+
[prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
|
350 |
+
)
|
351 |
+
|
352 |
+
return prompt_embeds
|
353 |
+
|
354 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
355 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
356 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
357 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
358 |
+
# and should be between [0, 1]
|
359 |
+
|
360 |
+
accepts_eta = "eta" in set(
|
361 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
362 |
+
)
|
363 |
+
extra_step_kwargs = {}
|
364 |
+
if accepts_eta:
|
365 |
+
extra_step_kwargs["eta"] = eta
|
366 |
+
|
367 |
+
# check if the scheduler accepts generator
|
368 |
+
accepts_generator = "generator" in set(
|
369 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
370 |
+
)
|
371 |
+
if accepts_generator:
|
372 |
+
extra_step_kwargs["generator"] = generator
|
373 |
+
return extra_step_kwargs
|
374 |
+
|
375 |
+
def check_inputs(
|
376 |
+
self,
|
377 |
+
prompt,
|
378 |
+
callback_steps,
|
379 |
+
negative_prompt=None,
|
380 |
+
prompt_embeds=None,
|
381 |
+
negative_prompt_embeds=None,
|
382 |
+
):
|
383 |
+
if (callback_steps is None) or (
|
384 |
+
callback_steps is not None
|
385 |
+
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
386 |
+
):
|
387 |
+
raise ValueError(
|
388 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
389 |
+
f" {type(callback_steps)}."
|
390 |
+
)
|
391 |
+
|
392 |
+
if prompt is not None and prompt_embeds is not None:
|
393 |
+
raise ValueError(
|
394 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
395 |
+
" only forward one of the two."
|
396 |
+
)
|
397 |
+
elif prompt is None and prompt_embeds is None:
|
398 |
+
raise ValueError(
|
399 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
400 |
+
)
|
401 |
+
elif prompt is not None and (
|
402 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
403 |
+
):
|
404 |
+
raise ValueError(
|
405 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
406 |
+
)
|
407 |
+
|
408 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
409 |
+
raise ValueError(
|
410 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
411 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
412 |
+
)
|
413 |
+
|
414 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
415 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
416 |
+
raise ValueError(
|
417 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
418 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
419 |
+
f" {negative_prompt_embeds.shape}."
|
420 |
+
)
|
421 |
+
|
422 |
+
def prepare_latents(
|
423 |
+
self,
|
424 |
+
batch_size,
|
425 |
+
num_channels_latents,
|
426 |
+
height,
|
427 |
+
width,
|
428 |
+
dtype,
|
429 |
+
device,
|
430 |
+
generator,
|
431 |
+
latents=None,
|
432 |
+
):
|
433 |
+
shape = (
|
434 |
+
batch_size,
|
435 |
+
num_channels_latents,
|
436 |
+
height // self.vae_scale_factor,
|
437 |
+
width // self.vae_scale_factor,
|
438 |
+
)
|
439 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
440 |
+
raise ValueError(
|
441 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
442 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
443 |
+
)
|
444 |
+
|
445 |
+
if latents is None:
|
446 |
+
latents = randn_tensor(
|
447 |
+
shape, generator=generator, device=device, dtype=dtype
|
448 |
+
)
|
449 |
+
else:
|
450 |
+
latents = latents.to(device)
|
451 |
+
|
452 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
453 |
+
latents = latents * self.scheduler.init_noise_sigma
|
454 |
+
return latents
|
455 |
+
|
456 |
+
def prepare_image_latents(
|
457 |
+
self,
|
458 |
+
image,
|
459 |
+
batch_size,
|
460 |
+
num_images_per_prompt,
|
461 |
+
dtype,
|
462 |
+
device,
|
463 |
+
do_classifier_free_guidance,
|
464 |
+
generator=None,
|
465 |
+
):
|
466 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
467 |
+
raise ValueError(
|
468 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
469 |
+
)
|
470 |
+
|
471 |
+
image = image.to(device=device, dtype=dtype)
|
472 |
+
|
473 |
+
batch_size = batch_size * num_images_per_prompt
|
474 |
+
|
475 |
+
if image.shape[1] == 4:
|
476 |
+
image_latents = image
|
477 |
+
else:
|
478 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
479 |
+
raise ValueError(
|
480 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
481 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
482 |
+
)
|
483 |
+
|
484 |
+
if isinstance(generator, list):
|
485 |
+
image_latents = [
|
486 |
+
self.vae.encode(image[i : i + 1]).latent_dist.mode()
|
487 |
+
for i in range(batch_size)
|
488 |
+
]
|
489 |
+
image_latents = torch.cat(image_latents, dim=0)
|
490 |
+
else:
|
491 |
+
image_latents = self.vae.encode(image).latent_dist.mode()
|
492 |
+
|
493 |
+
if (
|
494 |
+
batch_size > image_latents.shape[0]
|
495 |
+
and batch_size % image_latents.shape[0] == 0
|
496 |
+
):
|
497 |
+
# expand image_latents for batch_size
|
498 |
+
deprecation_message = (
|
499 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
|
500 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
501 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
502 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
503 |
+
)
|
504 |
+
deprecate(
|
505 |
+
"len(prompt) != len(image)",
|
506 |
+
"1.0.0",
|
507 |
+
deprecation_message,
|
508 |
+
standard_warn=False,
|
509 |
+
)
|
510 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
511 |
+
image_latents = torch.cat(
|
512 |
+
[image_latents] * additional_image_per_prompt, dim=0
|
513 |
+
)
|
514 |
+
elif (
|
515 |
+
batch_size > image_latents.shape[0]
|
516 |
+
and batch_size % image_latents.shape[0] != 0
|
517 |
+
):
|
518 |
+
raise ValueError(
|
519 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
520 |
+
)
|
521 |
+
else:
|
522 |
+
image_latents = torch.cat([image_latents], dim=0)
|
523 |
+
|
524 |
+
if do_classifier_free_guidance:
|
525 |
+
uncond_image_latents = torch.zeros_like(image_latents)
|
526 |
+
image_latents = torch.cat(
|
527 |
+
[image_latents, image_latents, uncond_image_latents], dim=0
|
528 |
+
)
|
529 |
+
|
530 |
+
return image_latents
|
531 |
+
|
532 |
+
@torch.no_grad()
|
533 |
+
def __call__(
|
534 |
+
self,
|
535 |
+
prompt: Union[str, List[str]] = None,
|
536 |
+
photo: Union[
|
537 |
+
torch.FloatTensor,
|
538 |
+
PIL.Image.Image,
|
539 |
+
np.ndarray,
|
540 |
+
List[torch.FloatTensor],
|
541 |
+
List[PIL.Image.Image],
|
542 |
+
List[np.ndarray],
|
543 |
+
] = None,
|
544 |
+
height: Optional[int] = None,
|
545 |
+
width: Optional[int] = None,
|
546 |
+
num_inference_steps: int = 100,
|
547 |
+
required_aovs: List[str] = ["albedo"],
|
548 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
549 |
+
num_images_per_prompt: Optional[int] = 1,
|
550 |
+
use_default_scaling_factor: Optional[bool] = False,
|
551 |
+
guidance_scale: float = 0.0,
|
552 |
+
image_guidance_scale: float = 0.0,
|
553 |
+
guidance_rescale: float = 0.0,
|
554 |
+
eta: float = 0.0,
|
555 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
556 |
+
latents: Optional[torch.FloatTensor] = None,
|
557 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
558 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
559 |
+
output_type: Optional[str] = "pil",
|
560 |
+
return_dict: bool = True,
|
561 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
562 |
+
callback_steps: int = 1,
|
563 |
+
):
|
564 |
+
r"""
|
565 |
+
The call function to the pipeline for generation.
|
566 |
+
|
567 |
+
Args:
|
568 |
+
prompt (`str` or `List[str]`, *optional*):
|
569 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
570 |
+
image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
571 |
+
`Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
|
572 |
+
image latents as `image`, but if passing latents directly it is not encoded again.
|
573 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
574 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
575 |
+
expense of slower inference.
|
576 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
577 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
578 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
579 |
+
image_guidance_scale (`float`, *optional*, defaults to 1.5):
|
580 |
+
Push the generated image towards the inital `image`. Image guidance scale is enabled by setting
|
581 |
+
`image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
|
582 |
+
linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
|
583 |
+
value of at least `1`.
|
584 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
585 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
586 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
587 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
588 |
+
The number of images to generate per prompt.
|
589 |
+
eta (`float`, *optional*, defaults to 0.0):
|
590 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
591 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
592 |
+
generator (`torch.Generator`, *optional*):
|
593 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
594 |
+
generation deterministic.
|
595 |
+
latents (`torch.FloatTensor`, *optional*):
|
596 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
597 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
598 |
+
tensor is generated by sampling using the supplied random `generator`.
|
599 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
600 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
601 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
602 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
603 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
604 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
605 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
606 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
607 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
608 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
609 |
+
plain tuple.
|
610 |
+
callback (`Callable`, *optional*):
|
611 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
612 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
613 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
614 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
615 |
+
every step.
|
616 |
+
|
617 |
+
Examples:
|
618 |
+
|
619 |
+
```py
|
620 |
+
>>> import PIL
|
621 |
+
>>> import requests
|
622 |
+
>>> import torch
|
623 |
+
>>> from io import BytesIO
|
624 |
+
|
625 |
+
>>> from diffusers import StableDiffusionInstructPix2PixPipeline
|
626 |
+
|
627 |
+
|
628 |
+
>>> def download_image(url):
|
629 |
+
... response = requests.get(url)
|
630 |
+
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
631 |
+
|
632 |
+
|
633 |
+
>>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
|
634 |
+
|
635 |
+
>>> image = download_image(img_url).resize((512, 512))
|
636 |
+
|
637 |
+
>>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
638 |
+
... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
|
639 |
+
... )
|
640 |
+
>>> pipe = pipe.to("cuda")
|
641 |
+
|
642 |
+
>>> prompt = "make the mountains snowy"
|
643 |
+
>>> image = pipe(prompt=prompt, image=image).images[0]
|
644 |
+
```
|
645 |
+
|
646 |
+
Returns:
|
647 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
648 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
649 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
650 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
651 |
+
"not-safe-for-work" (nsfw) content.
|
652 |
+
"""
|
653 |
+
# 0. Check inputs
|
654 |
+
self.check_inputs(
|
655 |
+
prompt,
|
656 |
+
callback_steps,
|
657 |
+
negative_prompt,
|
658 |
+
prompt_embeds,
|
659 |
+
negative_prompt_embeds,
|
660 |
+
)
|
661 |
+
|
662 |
+
# 1. Define call parameters
|
663 |
+
if prompt is not None and isinstance(prompt, str):
|
664 |
+
batch_size = 1
|
665 |
+
elif prompt is not None and isinstance(prompt, list):
|
666 |
+
batch_size = len(prompt)
|
667 |
+
else:
|
668 |
+
batch_size = prompt_embeds.shape[0]
|
669 |
+
|
670 |
+
device = self._execution_device
|
671 |
+
do_classifier_free_guidance = (
|
672 |
+
guidance_scale > 1.0 and image_guidance_scale >= 1.0
|
673 |
+
)
|
674 |
+
# check if scheduler is in sigmas space
|
675 |
+
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
|
676 |
+
|
677 |
+
# 2. Encode input prompt
|
678 |
+
prompt_embeds = self._encode_prompt(
|
679 |
+
prompt,
|
680 |
+
device,
|
681 |
+
num_images_per_prompt,
|
682 |
+
do_classifier_free_guidance,
|
683 |
+
negative_prompt,
|
684 |
+
prompt_embeds=prompt_embeds,
|
685 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
686 |
+
)
|
687 |
+
|
688 |
+
# 3. Preprocess image
|
689 |
+
# Normalize image to [-1,1]
|
690 |
+
preprocessed_photo = self.image_processor.preprocess(photo)
|
691 |
+
|
692 |
+
# 4. set timesteps
|
693 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
694 |
+
timesteps = self.scheduler.timesteps
|
695 |
+
|
696 |
+
# 5. Prepare Image latents
|
697 |
+
image_latents = self.prepare_image_latents(
|
698 |
+
preprocessed_photo,
|
699 |
+
batch_size,
|
700 |
+
num_images_per_prompt,
|
701 |
+
prompt_embeds.dtype,
|
702 |
+
device,
|
703 |
+
do_classifier_free_guidance,
|
704 |
+
generator,
|
705 |
+
)
|
706 |
+
image_latents = image_latents * self.vae.config.scaling_factor
|
707 |
+
|
708 |
+
height, width = image_latents.shape[-2:]
|
709 |
+
height = height * self.vae_scale_factor
|
710 |
+
width = width * self.vae_scale_factor
|
711 |
+
|
712 |
+
# 6. Prepare latent variables
|
713 |
+
num_channels_latents = self.unet.config.out_channels
|
714 |
+
latents = self.prepare_latents(
|
715 |
+
batch_size * num_images_per_prompt,
|
716 |
+
num_channels_latents,
|
717 |
+
height,
|
718 |
+
width,
|
719 |
+
prompt_embeds.dtype,
|
720 |
+
device,
|
721 |
+
generator,
|
722 |
+
latents,
|
723 |
+
)
|
724 |
+
|
725 |
+
# 7. Check that shapes of latents and image match the UNet channels
|
726 |
+
num_channels_image = image_latents.shape[1]
|
727 |
+
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
|
728 |
+
raise ValueError(
|
729 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
730 |
+
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
731 |
+
f" `num_channels_image`: {num_channels_image} "
|
732 |
+
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
733 |
+
" `pipeline.unet` or your `image` input."
|
734 |
+
)
|
735 |
+
|
736 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
737 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
738 |
+
|
739 |
+
# 9. Denoising loop
|
740 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
741 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
742 |
+
for i, t in enumerate(timesteps):
|
743 |
+
# Expand the latents if we are doing classifier free guidance.
|
744 |
+
# The latents are expanded 3 times because for pix2pix the guidance\
|
745 |
+
# is applied for both the text and the input image.
|
746 |
+
latent_model_input = (
|
747 |
+
torch.cat([latents] * 3) if do_classifier_free_guidance else latents
|
748 |
+
)
|
749 |
+
|
750 |
+
# concat latents, image_latents in the channel dimension
|
751 |
+
scaled_latent_model_input = self.scheduler.scale_model_input(
|
752 |
+
latent_model_input, t
|
753 |
+
)
|
754 |
+
scaled_latent_model_input = torch.cat(
|
755 |
+
[scaled_latent_model_input, image_latents], dim=1
|
756 |
+
)
|
757 |
+
|
758 |
+
# predict the noise residual
|
759 |
+
noise_pred = self.unet(
|
760 |
+
scaled_latent_model_input,
|
761 |
+
t,
|
762 |
+
encoder_hidden_states=prompt_embeds,
|
763 |
+
return_dict=False,
|
764 |
+
)[0]
|
765 |
+
|
766 |
+
# perform guidance
|
767 |
+
if do_classifier_free_guidance:
|
768 |
+
(
|
769 |
+
noise_pred_text,
|
770 |
+
noise_pred_image,
|
771 |
+
noise_pred_uncond,
|
772 |
+
) = noise_pred.chunk(3)
|
773 |
+
noise_pred = (
|
774 |
+
noise_pred_uncond
|
775 |
+
+ guidance_scale * (noise_pred_text - noise_pred_image)
|
776 |
+
+ image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
777 |
+
)
|
778 |
+
|
779 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
780 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
781 |
+
noise_pred = rescale_noise_cfg(
|
782 |
+
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
783 |
+
)
|
784 |
+
|
785 |
+
# compute the previous noisy sample x_t -> x_t-1
|
786 |
+
latents = self.scheduler.step(
|
787 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
788 |
+
)[0]
|
789 |
+
|
790 |
+
# call the callback, if provided
|
791 |
+
if i == len(timesteps) - 1 or (
|
792 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
793 |
+
):
|
794 |
+
progress_bar.update()
|
795 |
+
if callback is not None and i % callback_steps == 0:
|
796 |
+
callback(i, t, latents)
|
797 |
+
|
798 |
+
aov_latents = latents / self.vae.config.scaling_factor
|
799 |
+
aov = self.vae.decode(aov_latents, return_dict=False)[0]
|
800 |
+
do_denormalize = [True] * aov.shape[0]
|
801 |
+
aov_name = required_aovs[0]
|
802 |
+
if aov_name == "albedo" or aov_name == "irradiance":
|
803 |
+
do_gamma_correction = True
|
804 |
+
else:
|
805 |
+
do_gamma_correction = False
|
806 |
+
|
807 |
+
if aov_name == "roughness" or aov_name == "metallic":
|
808 |
+
aov = aov[:, 0:1].repeat(1, 3, 1, 1)
|
809 |
+
|
810 |
+
aov = self.image_processor.postprocess(
|
811 |
+
aov,
|
812 |
+
output_type=output_type,
|
813 |
+
do_denormalize=do_denormalize,
|
814 |
+
do_gamma_correction=do_gamma_correction,
|
815 |
+
)
|
816 |
+
aovs = [aov]
|
817 |
+
|
818 |
+
# Offload last model to CPU
|
819 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
820 |
+
self.final_offload_hook.offload()
|
821 |
+
return StableDiffusionAOVPipelineOutput(images=aovs)
|
run.sh
CHANGED
@@ -1,5 +1,15 @@
|
|
1 |
#!/bin/bash
|
2 |
-
CONDA_ENV=$(head -1 /code/environment.yml | cut -d" " -f2)
|
3 |
eval "$(conda shell.bash hook)"
|
4 |
-
conda
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
python app.py
|
|
|
1 |
#!/bin/bash
|
|
|
2 |
eval "$(conda shell.bash hook)"
|
3 |
+
conda create -n matgen-plus python=3.11
|
4 |
+
conda activate matgen-plus
|
5 |
+
|
6 |
+
pip install diffusers["torch"] transformers accelerate xformers
|
7 |
+
pip install gradio
|
8 |
+
pip install controlnet-aux
|
9 |
+
|
10 |
+
# text2tex
|
11 |
+
conda install pytorch3d -c pytorch -c conda-forge
|
12 |
+
conda install -c conda-forge open-clip-torch pytorch-lightning
|
13 |
+
pip install trimesh xatlas scikit-learn opencv-python omegaconf
|
14 |
+
|
15 |
python app.py
|
settings.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
|
6 |
+
|
7 |
+
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))
|
8 |
+
DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "1")))
|
9 |
+
MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "2048"))
|
10 |
+
DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "1024")))
|
11 |
+
|
12 |
+
ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
|
13 |
+
SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
|
14 |
+
|
15 |
+
MAX_SEED = np.iinfo(np.int32).max
|
16 |
+
|
17 |
+
# Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
|
18 |
+
|
19 |
+
# setup CUDA
|
20 |
+
# disable the following when deployting to hugging face
|
21 |
+
# if os.getenv("CUDA_VISIBLE_DEVICES") is None:
|
22 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
23 |
+
# os.environ["GRADIO_SERVER_PORT"] = "7864"
|
text2tex/lib/__init__.py
ADDED
File without changes
|
text2tex/lib/camera_helper.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
6 |
+
|
7 |
+
from pytorch3d.renderer import (
|
8 |
+
PerspectiveCameras,
|
9 |
+
look_at_view_transform
|
10 |
+
)
|
11 |
+
|
12 |
+
# customized
|
13 |
+
import sys
|
14 |
+
sys.path.append(".")
|
15 |
+
|
16 |
+
from lib.constants import VIEWPOINTS
|
17 |
+
|
18 |
+
# ---------------- UTILS ----------------------
|
19 |
+
|
20 |
+
def degree_to_radian(d):
|
21 |
+
return d * np.pi / 180
|
22 |
+
|
23 |
+
def radian_to_degree(r):
|
24 |
+
return 180 * r / np.pi
|
25 |
+
|
26 |
+
def xyz_to_polar(xyz):
|
27 |
+
""" assume y-axis is the up axis """
|
28 |
+
|
29 |
+
x, y, z = xyz
|
30 |
+
|
31 |
+
theta = 180 * np.arccos(z) / np.pi
|
32 |
+
phi = 180 * np.arccos(y) / np.pi
|
33 |
+
|
34 |
+
return theta, phi
|
35 |
+
|
36 |
+
def polar_to_xyz(theta, phi, dist):
|
37 |
+
""" assume y-axis is the up axis """
|
38 |
+
|
39 |
+
theta = degree_to_radian(theta)
|
40 |
+
phi = degree_to_radian(phi)
|
41 |
+
|
42 |
+
x = np.sin(phi) * np.sin(theta) * dist
|
43 |
+
y = np.cos(phi) * dist
|
44 |
+
z = np.sin(phi) * np.cos(theta) * dist
|
45 |
+
|
46 |
+
return [x, y, z]
|
47 |
+
|
48 |
+
|
49 |
+
# ---------------- VIEWPOINTS ----------------------
|
50 |
+
|
51 |
+
|
52 |
+
def filter_viewpoints(pre_viewpoints: dict, viewpoints: dict):
|
53 |
+
""" return the binary mask of viewpoints to be filtered """
|
54 |
+
|
55 |
+
filter_mask = [0 for _ in viewpoints.keys()]
|
56 |
+
for i, v in viewpoints.items():
|
57 |
+
x_v, y_v, z_v = polar_to_xyz(v["azim"], 90 - v["elev"], v["dist"])
|
58 |
+
|
59 |
+
for _, pv in pre_viewpoints.items():
|
60 |
+
x_pv, y_pv, z_pv = polar_to_xyz(pv["azim"], 90 - pv["elev"], pv["dist"])
|
61 |
+
sim = cosine_similarity(
|
62 |
+
np.array([[x_v, y_v, z_v]]),
|
63 |
+
np.array([[x_pv, y_pv, z_pv]])
|
64 |
+
)[0, 0]
|
65 |
+
|
66 |
+
if sim > 0.9:
|
67 |
+
filter_mask[i] = 1
|
68 |
+
|
69 |
+
return filter_mask
|
70 |
+
|
71 |
+
|
72 |
+
def init_viewpoints(mode, sample_space, init_dist, init_elev, principle_directions,
|
73 |
+
use_principle=True, use_shapenet=False, use_objaverse=False):
|
74 |
+
|
75 |
+
if mode == "predefined":
|
76 |
+
|
77 |
+
(
|
78 |
+
dist_list,
|
79 |
+
elev_list,
|
80 |
+
azim_list,
|
81 |
+
sector_list
|
82 |
+
) = init_predefined_viewpoints(sample_space, init_dist, init_elev)
|
83 |
+
|
84 |
+
elif mode == "hemisphere":
|
85 |
+
|
86 |
+
(
|
87 |
+
dist_list,
|
88 |
+
elev_list,
|
89 |
+
azim_list,
|
90 |
+
sector_list
|
91 |
+
) = init_hemisphere_viewpoints(sample_space, init_dist)
|
92 |
+
|
93 |
+
else:
|
94 |
+
raise NotImplementedError()
|
95 |
+
|
96 |
+
# punishments for views -> in case always selecting the same view
|
97 |
+
view_punishments = [1 for _ in range(len(dist_list))]
|
98 |
+
|
99 |
+
if use_principle:
|
100 |
+
|
101 |
+
(
|
102 |
+
dist_list,
|
103 |
+
elev_list,
|
104 |
+
azim_list,
|
105 |
+
sector_list,
|
106 |
+
view_punishments
|
107 |
+
) = init_principle_viewpoints(
|
108 |
+
principle_directions,
|
109 |
+
dist_list,
|
110 |
+
elev_list,
|
111 |
+
azim_list,
|
112 |
+
sector_list,
|
113 |
+
view_punishments,
|
114 |
+
use_shapenet,
|
115 |
+
use_objaverse
|
116 |
+
)
|
117 |
+
|
118 |
+
return dist_list, elev_list, azim_list, sector_list, view_punishments
|
119 |
+
|
120 |
+
|
121 |
+
def init_principle_viewpoints(
|
122 |
+
principle_directions,
|
123 |
+
dist_list,
|
124 |
+
elev_list,
|
125 |
+
azim_list,
|
126 |
+
sector_list,
|
127 |
+
view_punishments,
|
128 |
+
use_shapenet=False,
|
129 |
+
use_objaverse=False
|
130 |
+
):
|
131 |
+
|
132 |
+
if use_shapenet:
|
133 |
+
key = "shapenet"
|
134 |
+
|
135 |
+
pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
|
136 |
+
pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
|
137 |
+
pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]
|
138 |
+
|
139 |
+
num_principle = 10
|
140 |
+
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
|
141 |
+
pre_view_punishments = [0 for _ in range(num_principle)]
|
142 |
+
|
143 |
+
elif use_objaverse:
|
144 |
+
key = "objaverse"
|
145 |
+
|
146 |
+
pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
|
147 |
+
pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
|
148 |
+
pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]
|
149 |
+
|
150 |
+
num_principle = 10
|
151 |
+
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
|
152 |
+
pre_view_punishments = [0 for _ in range(num_principle)]
|
153 |
+
else:
|
154 |
+
num_principle = 6
|
155 |
+
pre_elev_list = [v for v in VIEWPOINTS[num_principle]["elev"]]
|
156 |
+
pre_azim_list = [v for v in VIEWPOINTS[num_principle]["azim"]]
|
157 |
+
pre_sector_list = [v for v in VIEWPOINTS[num_principle]["sector"]]
|
158 |
+
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
|
159 |
+
pre_view_punishments = [0 for _ in range(num_principle)]
|
160 |
+
|
161 |
+
dist_list = pre_dist_list + dist_list
|
162 |
+
elev_list = pre_elev_list + elev_list
|
163 |
+
azim_list = pre_azim_list + azim_list
|
164 |
+
sector_list = pre_sector_list + sector_list
|
165 |
+
view_punishments = pre_view_punishments + view_punishments
|
166 |
+
|
167 |
+
return dist_list, elev_list, azim_list, sector_list, view_punishments
|
168 |
+
|
169 |
+
|
170 |
+
def init_predefined_viewpoints(sample_space, init_dist, init_elev):
|
171 |
+
|
172 |
+
viewpoints = VIEWPOINTS[sample_space]
|
173 |
+
|
174 |
+
assert sample_space == len(viewpoints["sector"])
|
175 |
+
|
176 |
+
dist_list = [init_dist for _ in range(sample_space)] # always the same dist
|
177 |
+
elev_list = [viewpoints["elev"][i] for i in range(sample_space)]
|
178 |
+
azim_list = [viewpoints["azim"][i] for i in range(sample_space)]
|
179 |
+
sector_list = [viewpoints["sector"][i] for i in range(sample_space)]
|
180 |
+
|
181 |
+
return dist_list, elev_list, azim_list, sector_list
|
182 |
+
|
183 |
+
|
184 |
+
def init_hemisphere_viewpoints(sample_space, init_dist):
|
185 |
+
"""
|
186 |
+
y is up-axis
|
187 |
+
"""
|
188 |
+
|
189 |
+
num_points = 2 * sample_space
|
190 |
+
ga = np.pi * (3. - np.sqrt(5.)) # golden angle in radians
|
191 |
+
|
192 |
+
flags = []
|
193 |
+
elev_list = [] # degree
|
194 |
+
azim_list = [] # degree
|
195 |
+
|
196 |
+
for i in range(num_points):
|
197 |
+
y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1
|
198 |
+
|
199 |
+
# only take the north hemisphere
|
200 |
+
if y >= 0:
|
201 |
+
flags.append(True)
|
202 |
+
else:
|
203 |
+
flags.append(False)
|
204 |
+
|
205 |
+
theta = ga * i # golden angle increment
|
206 |
+
|
207 |
+
elev_list.append(radian_to_degree(np.arcsin(y)))
|
208 |
+
azim_list.append(radian_to_degree(theta))
|
209 |
+
|
210 |
+
radius = np.sqrt(1 - y * y) # radius at y
|
211 |
+
x = np.cos(theta) * radius
|
212 |
+
z = np.sin(theta) * radius
|
213 |
+
|
214 |
+
elev_list = [elev_list[i] for i in range(len(elev_list)) if flags[i]]
|
215 |
+
azim_list = [azim_list[i] for i in range(len(azim_list)) if flags[i]]
|
216 |
+
|
217 |
+
dist_list = [init_dist for _ in elev_list]
|
218 |
+
sector_list = ["good" for _ in elev_list] # HACK don't define sector names for now
|
219 |
+
|
220 |
+
return dist_list, elev_list, azim_list, sector_list
|
221 |
+
|
222 |
+
|
223 |
+
# ---------------- CAMERAS ----------------------
|
224 |
+
|
225 |
+
|
226 |
+
def init_camera(dist, elev, azim, image_size, device):
|
227 |
+
R, T = look_at_view_transform(dist, elev, azim)
|
228 |
+
image_size = torch.tensor([image_size, image_size]).unsqueeze(0)
|
229 |
+
cameras = PerspectiveCameras(R=R, T=T, device=device, image_size=image_size)
|
230 |
+
|
231 |
+
return cameras
|
text2tex/lib/constants.py
ADDED
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PALETTE = {
|
2 |
+
0: [255, 255, 255], # white - background
|
3 |
+
1: [204, 50, 50], # red - old
|
4 |
+
2: [231, 180, 22], # yellow - update
|
5 |
+
3: [45, 201, 55] # green - new
|
6 |
+
}
|
7 |
+
|
8 |
+
QUAD_WEIGHTS = {
|
9 |
+
0: 0, # background
|
10 |
+
1: 0.1, # old
|
11 |
+
2: 0.5, # update
|
12 |
+
3: 1 # new
|
13 |
+
}
|
14 |
+
|
15 |
+
VIEWPOINTS = {
|
16 |
+
1: {
|
17 |
+
"azim": [
|
18 |
+
0
|
19 |
+
],
|
20 |
+
"elev": [
|
21 |
+
0
|
22 |
+
],
|
23 |
+
"sector": [
|
24 |
+
"front"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
2: {
|
28 |
+
"azim": [
|
29 |
+
0,
|
30 |
+
30
|
31 |
+
],
|
32 |
+
"elev": [
|
33 |
+
0,
|
34 |
+
0
|
35 |
+
],
|
36 |
+
"sector": [
|
37 |
+
"front",
|
38 |
+
"front"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
4: {
|
42 |
+
"azim": [
|
43 |
+
45,
|
44 |
+
315,
|
45 |
+
135,
|
46 |
+
225,
|
47 |
+
],
|
48 |
+
"elev": [
|
49 |
+
0,
|
50 |
+
0,
|
51 |
+
0,
|
52 |
+
0,
|
53 |
+
],
|
54 |
+
"sector": [
|
55 |
+
"front right",
|
56 |
+
"front left",
|
57 |
+
"back right",
|
58 |
+
"back left",
|
59 |
+
]
|
60 |
+
},
|
61 |
+
6: {
|
62 |
+
"azim": [
|
63 |
+
0,
|
64 |
+
90,
|
65 |
+
270,
|
66 |
+
0,
|
67 |
+
180,
|
68 |
+
0
|
69 |
+
],
|
70 |
+
"elev": [
|
71 |
+
0,
|
72 |
+
0,
|
73 |
+
0,
|
74 |
+
90,
|
75 |
+
0,
|
76 |
+
-90
|
77 |
+
],
|
78 |
+
"sector": [
|
79 |
+
"front",
|
80 |
+
"right",
|
81 |
+
"left",
|
82 |
+
"top",
|
83 |
+
"back",
|
84 |
+
"bottom",
|
85 |
+
]
|
86 |
+
},
|
87 |
+
"shapenet": {
|
88 |
+
"azim": [
|
89 |
+
270,
|
90 |
+
315,
|
91 |
+
225,
|
92 |
+
0,
|
93 |
+
180,
|
94 |
+
45,
|
95 |
+
135,
|
96 |
+
90,
|
97 |
+
270,
|
98 |
+
270
|
99 |
+
],
|
100 |
+
"elev": [
|
101 |
+
15,
|
102 |
+
15,
|
103 |
+
15,
|
104 |
+
15,
|
105 |
+
15,
|
106 |
+
15,
|
107 |
+
15,
|
108 |
+
15,
|
109 |
+
90,
|
110 |
+
-90
|
111 |
+
],
|
112 |
+
"sector": [
|
113 |
+
"front",
|
114 |
+
"front right",
|
115 |
+
"front left",
|
116 |
+
"right",
|
117 |
+
"left",
|
118 |
+
"back right",
|
119 |
+
"back left",
|
120 |
+
"back",
|
121 |
+
"top",
|
122 |
+
"bottom",
|
123 |
+
]
|
124 |
+
},
|
125 |
+
"objaverse": {
|
126 |
+
"azim": [
|
127 |
+
0,
|
128 |
+
45,
|
129 |
+
315,
|
130 |
+
90,
|
131 |
+
270,
|
132 |
+
135,
|
133 |
+
225,
|
134 |
+
180,
|
135 |
+
0,
|
136 |
+
0
|
137 |
+
],
|
138 |
+
"elev": [
|
139 |
+
15,
|
140 |
+
15,
|
141 |
+
15,
|
142 |
+
15,
|
143 |
+
15,
|
144 |
+
15,
|
145 |
+
15,
|
146 |
+
15,
|
147 |
+
90,
|
148 |
+
-90
|
149 |
+
],
|
150 |
+
"sector": [
|
151 |
+
"front",
|
152 |
+
"front right",
|
153 |
+
"front left",
|
154 |
+
"right",
|
155 |
+
"left",
|
156 |
+
"back right",
|
157 |
+
"back left",
|
158 |
+
"back",
|
159 |
+
"top",
|
160 |
+
"bottom",
|
161 |
+
]
|
162 |
+
},
|
163 |
+
12: {
|
164 |
+
"azim": [
|
165 |
+
45,
|
166 |
+
315,
|
167 |
+
135,
|
168 |
+
225,
|
169 |
+
|
170 |
+
0,
|
171 |
+
45,
|
172 |
+
315,
|
173 |
+
90,
|
174 |
+
270,
|
175 |
+
135,
|
176 |
+
225,
|
177 |
+
180,
|
178 |
+
],
|
179 |
+
"elev": [
|
180 |
+
0,
|
181 |
+
0,
|
182 |
+
0,
|
183 |
+
0,
|
184 |
+
|
185 |
+
45,
|
186 |
+
45,
|
187 |
+
45,
|
188 |
+
45,
|
189 |
+
45,
|
190 |
+
45,
|
191 |
+
45,
|
192 |
+
45,
|
193 |
+
],
|
194 |
+
"sector": [
|
195 |
+
"front right",
|
196 |
+
"front left",
|
197 |
+
"back right",
|
198 |
+
"back left",
|
199 |
+
|
200 |
+
"front",
|
201 |
+
"front right",
|
202 |
+
"front left",
|
203 |
+
"right",
|
204 |
+
"left",
|
205 |
+
"back right",
|
206 |
+
"back left",
|
207 |
+
"back",
|
208 |
+
]
|
209 |
+
},
|
210 |
+
20: {
|
211 |
+
"azim": [
|
212 |
+
45,
|
213 |
+
315,
|
214 |
+
135,
|
215 |
+
225,
|
216 |
+
|
217 |
+
0,
|
218 |
+
45,
|
219 |
+
315,
|
220 |
+
90,
|
221 |
+
270,
|
222 |
+
135,
|
223 |
+
225,
|
224 |
+
180,
|
225 |
+
|
226 |
+
0,
|
227 |
+
45,
|
228 |
+
315,
|
229 |
+
90,
|
230 |
+
270,
|
231 |
+
135,
|
232 |
+
225,
|
233 |
+
180,
|
234 |
+
],
|
235 |
+
"elev": [
|
236 |
+
0,
|
237 |
+
0,
|
238 |
+
0,
|
239 |
+
0,
|
240 |
+
|
241 |
+
30,
|
242 |
+
30,
|
243 |
+
30,
|
244 |
+
30,
|
245 |
+
30,
|
246 |
+
30,
|
247 |
+
30,
|
248 |
+
30,
|
249 |
+
|
250 |
+
60,
|
251 |
+
60,
|
252 |
+
60,
|
253 |
+
60,
|
254 |
+
60,
|
255 |
+
60,
|
256 |
+
60,
|
257 |
+
60,
|
258 |
+
],
|
259 |
+
"sector": [
|
260 |
+
"front right",
|
261 |
+
"front left",
|
262 |
+
"back right",
|
263 |
+
"back left",
|
264 |
+
|
265 |
+
"front",
|
266 |
+
"front right",
|
267 |
+
"front left",
|
268 |
+
"right",
|
269 |
+
"left",
|
270 |
+
"back right",
|
271 |
+
"back left",
|
272 |
+
"back",
|
273 |
+
|
274 |
+
"front",
|
275 |
+
"front right",
|
276 |
+
"front left",
|
277 |
+
"right",
|
278 |
+
"left",
|
279 |
+
"back right",
|
280 |
+
"back left",
|
281 |
+
"back",
|
282 |
+
]
|
283 |
+
},
|
284 |
+
36: {
|
285 |
+
"azim": [
|
286 |
+
45,
|
287 |
+
315,
|
288 |
+
135,
|
289 |
+
225,
|
290 |
+
|
291 |
+
0,
|
292 |
+
45,
|
293 |
+
315,
|
294 |
+
90,
|
295 |
+
270,
|
296 |
+
135,
|
297 |
+
225,
|
298 |
+
180,
|
299 |
+
|
300 |
+
0,
|
301 |
+
45,
|
302 |
+
315,
|
303 |
+
90,
|
304 |
+
270,
|
305 |
+
135,
|
306 |
+
225,
|
307 |
+
180,
|
308 |
+
|
309 |
+
22.5,
|
310 |
+
337.5,
|
311 |
+
67.5,
|
312 |
+
292.5,
|
313 |
+
112.5,
|
314 |
+
247.5,
|
315 |
+
157.5,
|
316 |
+
202.5,
|
317 |
+
|
318 |
+
22.5,
|
319 |
+
337.5,
|
320 |
+
67.5,
|
321 |
+
292.5,
|
322 |
+
112.5,
|
323 |
+
247.5,
|
324 |
+
157.5,
|
325 |
+
202.5,
|
326 |
+
],
|
327 |
+
"elev": [
|
328 |
+
0,
|
329 |
+
0,
|
330 |
+
0,
|
331 |
+
0,
|
332 |
+
|
333 |
+
30,
|
334 |
+
30,
|
335 |
+
30,
|
336 |
+
30,
|
337 |
+
30,
|
338 |
+
30,
|
339 |
+
30,
|
340 |
+
30,
|
341 |
+
|
342 |
+
60,
|
343 |
+
60,
|
344 |
+
60,
|
345 |
+
60,
|
346 |
+
60,
|
347 |
+
60,
|
348 |
+
60,
|
349 |
+
60,
|
350 |
+
|
351 |
+
15,
|
352 |
+
15,
|
353 |
+
15,
|
354 |
+
15,
|
355 |
+
15,
|
356 |
+
15,
|
357 |
+
15,
|
358 |
+
15,
|
359 |
+
|
360 |
+
45,
|
361 |
+
45,
|
362 |
+
45,
|
363 |
+
45,
|
364 |
+
45,
|
365 |
+
45,
|
366 |
+
45,
|
367 |
+
45,
|
368 |
+
],
|
369 |
+
"sector": [
|
370 |
+
"front right",
|
371 |
+
"front left",
|
372 |
+
"back right",
|
373 |
+
"back left",
|
374 |
+
|
375 |
+
"front",
|
376 |
+
"front right",
|
377 |
+
"front left",
|
378 |
+
"right",
|
379 |
+
"left",
|
380 |
+
"back right",
|
381 |
+
"back left",
|
382 |
+
"back",
|
383 |
+
|
384 |
+
"top front",
|
385 |
+
"top right",
|
386 |
+
"top left",
|
387 |
+
"top right",
|
388 |
+
"top left",
|
389 |
+
"top right",
|
390 |
+
"top left",
|
391 |
+
"top back",
|
392 |
+
|
393 |
+
"front right",
|
394 |
+
"front left",
|
395 |
+
"front right",
|
396 |
+
"front left",
|
397 |
+
"back right",
|
398 |
+
"back left",
|
399 |
+
"back right",
|
400 |
+
"back left",
|
401 |
+
|
402 |
+
"front right",
|
403 |
+
"front left",
|
404 |
+
"front right",
|
405 |
+
"front left",
|
406 |
+
"back right",
|
407 |
+
"back left",
|
408 |
+
"back right",
|
409 |
+
"back left",
|
410 |
+
]
|
411 |
+
},
|
412 |
+
68: {
|
413 |
+
"azim": [
|
414 |
+
45,
|
415 |
+
315,
|
416 |
+
135,
|
417 |
+
225,
|
418 |
+
|
419 |
+
0,
|
420 |
+
45,
|
421 |
+
315,
|
422 |
+
90,
|
423 |
+
270,
|
424 |
+
135,
|
425 |
+
225,
|
426 |
+
180,
|
427 |
+
|
428 |
+
0,
|
429 |
+
45,
|
430 |
+
315,
|
431 |
+
90,
|
432 |
+
270,
|
433 |
+
135,
|
434 |
+
225,
|
435 |
+
180,
|
436 |
+
|
437 |
+
22.5,
|
438 |
+
337.5,
|
439 |
+
67.5,
|
440 |
+
292.5,
|
441 |
+
112.5,
|
442 |
+
247.5,
|
443 |
+
157.5,
|
444 |
+
202.5,
|
445 |
+
|
446 |
+
22.5,
|
447 |
+
337.5,
|
448 |
+
67.5,
|
449 |
+
292.5,
|
450 |
+
112.5,
|
451 |
+
247.5,
|
452 |
+
157.5,
|
453 |
+
202.5,
|
454 |
+
|
455 |
+
0,
|
456 |
+
45,
|
457 |
+
315,
|
458 |
+
90,
|
459 |
+
270,
|
460 |
+
135,
|
461 |
+
225,
|
462 |
+
180,
|
463 |
+
|
464 |
+
0,
|
465 |
+
45,
|
466 |
+
315,
|
467 |
+
90,
|
468 |
+
270,
|
469 |
+
135,
|
470 |
+
225,
|
471 |
+
180,
|
472 |
+
|
473 |
+
22.5,
|
474 |
+
337.5,
|
475 |
+
67.5,
|
476 |
+
292.5,
|
477 |
+
112.5,
|
478 |
+
247.5,
|
479 |
+
157.5,
|
480 |
+
202.5,
|
481 |
+
|
482 |
+
22.5,
|
483 |
+
337.5,
|
484 |
+
67.5,
|
485 |
+
292.5,
|
486 |
+
112.5,
|
487 |
+
247.5,
|
488 |
+
157.5,
|
489 |
+
202.5
|
490 |
+
],
|
491 |
+
"elev": [
|
492 |
+
0,
|
493 |
+
0,
|
494 |
+
0,
|
495 |
+
0,
|
496 |
+
|
497 |
+
30,
|
498 |
+
30,
|
499 |
+
30,
|
500 |
+
30,
|
501 |
+
30,
|
502 |
+
30,
|
503 |
+
30,
|
504 |
+
30,
|
505 |
+
|
506 |
+
60,
|
507 |
+
60,
|
508 |
+
60,
|
509 |
+
60,
|
510 |
+
60,
|
511 |
+
60,
|
512 |
+
60,
|
513 |
+
60,
|
514 |
+
|
515 |
+
15,
|
516 |
+
15,
|
517 |
+
15,
|
518 |
+
15,
|
519 |
+
15,
|
520 |
+
15,
|
521 |
+
15,
|
522 |
+
15,
|
523 |
+
|
524 |
+
45,
|
525 |
+
45,
|
526 |
+
45,
|
527 |
+
45,
|
528 |
+
45,
|
529 |
+
45,
|
530 |
+
45,
|
531 |
+
45,
|
532 |
+
|
533 |
+
-30,
|
534 |
+
-30,
|
535 |
+
-30,
|
536 |
+
-30,
|
537 |
+
-30,
|
538 |
+
-30,
|
539 |
+
-30,
|
540 |
+
-30,
|
541 |
+
|
542 |
+
-60,
|
543 |
+
-60,
|
544 |
+
-60,
|
545 |
+
-60,
|
546 |
+
-60,
|
547 |
+
-60,
|
548 |
+
-60,
|
549 |
+
-60,
|
550 |
+
|
551 |
+
-15,
|
552 |
+
-15,
|
553 |
+
-15,
|
554 |
+
-15,
|
555 |
+
-15,
|
556 |
+
-15,
|
557 |
+
-15,
|
558 |
+
-15,
|
559 |
+
|
560 |
+
-45,
|
561 |
+
-45,
|
562 |
+
-45,
|
563 |
+
-45,
|
564 |
+
-45,
|
565 |
+
-45,
|
566 |
+
-45,
|
567 |
+
-45,
|
568 |
+
],
|
569 |
+
"sector": [
|
570 |
+
"front right",
|
571 |
+
"front left",
|
572 |
+
"back right",
|
573 |
+
"back left",
|
574 |
+
|
575 |
+
"front",
|
576 |
+
"front right",
|
577 |
+
"front left",
|
578 |
+
"right",
|
579 |
+
"left",
|
580 |
+
"back right",
|
581 |
+
"back left",
|
582 |
+
"back",
|
583 |
+
|
584 |
+
"top front",
|
585 |
+
"top right",
|
586 |
+
"top left",
|
587 |
+
"top right",
|
588 |
+
"top left",
|
589 |
+
"top right",
|
590 |
+
"top left",
|
591 |
+
"top back",
|
592 |
+
|
593 |
+
"front right",
|
594 |
+
"front left",
|
595 |
+
"front right",
|
596 |
+
"front left",
|
597 |
+
"back right",
|
598 |
+
"back left",
|
599 |
+
"back right",
|
600 |
+
"back left",
|
601 |
+
|
602 |
+
"front right",
|
603 |
+
"front left",
|
604 |
+
"front right",
|
605 |
+
"front left",
|
606 |
+
"back right",
|
607 |
+
"back left",
|
608 |
+
"back right",
|
609 |
+
"back left",
|
610 |
+
|
611 |
+
"front",
|
612 |
+
"front right",
|
613 |
+
"front left",
|
614 |
+
"right",
|
615 |
+
"left",
|
616 |
+
"back right",
|
617 |
+
"back left",
|
618 |
+
"back",
|
619 |
+
|
620 |
+
"bottom front",
|
621 |
+
"bottom right",
|
622 |
+
"bottom left",
|
623 |
+
"bottom right",
|
624 |
+
"bottom left",
|
625 |
+
"bottom right",
|
626 |
+
"bottom left",
|
627 |
+
"bottom back",
|
628 |
+
|
629 |
+
"bottom front right",
|
630 |
+
"bottom front left",
|
631 |
+
"bottom front right",
|
632 |
+
"bottom front left",
|
633 |
+
"bottom back right",
|
634 |
+
"bottom back left",
|
635 |
+
"bottom back right",
|
636 |
+
"bottom back left",
|
637 |
+
|
638 |
+
"bottom front right",
|
639 |
+
"bottom front left",
|
640 |
+
"bottom front right",
|
641 |
+
"bottom front left",
|
642 |
+
"bottom back right",
|
643 |
+
"bottom back left",
|
644 |
+
"bottom back right",
|
645 |
+
"bottom back left",
|
646 |
+
]
|
647 |
+
}
|
648 |
+
}
|
text2tex/lib/diffusion_helper.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
# Stable Diffusion 2
|
10 |
+
from diffusers import (
|
11 |
+
StableDiffusionInpaintPipeline,
|
12 |
+
StableDiffusionPipeline,
|
13 |
+
EulerDiscreteScheduler
|
14 |
+
)
|
15 |
+
|
16 |
+
# customized
|
17 |
+
import sys
|
18 |
+
sys.path.append(".")
|
19 |
+
|
20 |
+
from models.ControlNet.gradio_depth2image import init_model, process
|
21 |
+
|
22 |
+
|
23 |
+
def get_controlnet_depth():
|
24 |
+
print("=> initializing ControlNet Depth...")
|
25 |
+
model, ddim_sampler = init_model()
|
26 |
+
|
27 |
+
return model, ddim_sampler
|
28 |
+
|
29 |
+
|
30 |
+
def get_inpainting(device):
|
31 |
+
print("=> initializing Inpainting...")
|
32 |
+
|
33 |
+
model = StableDiffusionInpaintPipeline.from_pretrained(
|
34 |
+
"stabilityai/stable-diffusion-2-inpainting",
|
35 |
+
torch_dtype=torch.float16,
|
36 |
+
).to(device)
|
37 |
+
|
38 |
+
return model
|
39 |
+
|
40 |
+
def get_text2image(device):
|
41 |
+
print("=> initializing Inpainting...")
|
42 |
+
|
43 |
+
model_id = "stabilityai/stable-diffusion-2"
|
44 |
+
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
|
45 |
+
model = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16).to(device)
|
46 |
+
|
47 |
+
return model
|
48 |
+
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def apply_controlnet_depth(model, ddim_sampler,
|
52 |
+
init_image, prompt, strength, ddim_steps,
|
53 |
+
generate_mask_image, keep_mask_image, depth_map_np,
|
54 |
+
a_prompt, n_prompt, guidance_scale, seed, eta, num_samples,
|
55 |
+
device, blend=0, save_memory=False):
|
56 |
+
"""
|
57 |
+
Use Stable Diffusion 2 to generate image
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
args: input arguments
|
61 |
+
model: Stable Diffusion 2 model
|
62 |
+
init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
|
63 |
+
mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
|
64 |
+
depth_map_np: depth map of the input image, torch.FloatTensor of shape (1, H, W)
|
65 |
+
"""
|
66 |
+
|
67 |
+
print("=> generating ControlNet Depth RePaint image...")
|
68 |
+
|
69 |
+
|
70 |
+
# Stable Diffusion 2 receives PIL.Image
|
71 |
+
# NOTE Stable Diffusion 2 returns a PIL.Image object
|
72 |
+
# image and mask_image should be PIL images.
|
73 |
+
# The mask structure is white for inpainting and black for keeping as is
|
74 |
+
diffused_image_np = process(
|
75 |
+
model, ddim_sampler,
|
76 |
+
np.array(init_image), prompt, a_prompt, n_prompt, num_samples,
|
77 |
+
ddim_steps, guidance_scale, seed, eta,
|
78 |
+
strength=strength, detected_map=depth_map_np, unknown_mask=np.array(generate_mask_image), save_memory=save_memory
|
79 |
+
)[0]
|
80 |
+
|
81 |
+
init_image = init_image.convert("RGB")
|
82 |
+
diffused_image = Image.fromarray(diffused_image_np).convert("RGB")
|
83 |
+
|
84 |
+
if blend > 0 and transforms.ToTensor()(keep_mask_image).sum() > 0:
|
85 |
+
print("=> blending the generated region...")
|
86 |
+
kernel_size = 3
|
87 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
88 |
+
|
89 |
+
keep_image_np = np.array(init_image).astype(np.uint8)
|
90 |
+
keep_image_np_dilate = cv2.dilate(keep_image_np, kernel, iterations=1)
|
91 |
+
|
92 |
+
keep_mask_np = np.array(keep_mask_image).astype(np.uint8)
|
93 |
+
keep_mask_np_dilate = cv2.dilate(keep_mask_np, kernel, iterations=1)
|
94 |
+
|
95 |
+
generate_image_np = np.array(diffused_image).astype(np.uint8)
|
96 |
+
|
97 |
+
overlap_mask_np = np.array(generate_mask_image).astype(np.uint8)
|
98 |
+
overlap_mask_np *= keep_mask_np_dilate
|
99 |
+
print("=> blending {} pixels...".format(np.sum(overlap_mask_np)))
|
100 |
+
|
101 |
+
overlap_keep = keep_image_np_dilate[overlap_mask_np == 1]
|
102 |
+
overlap_generate = generate_image_np[overlap_mask_np == 1]
|
103 |
+
|
104 |
+
overlap_np = overlap_keep * blend + overlap_generate * (1 - blend)
|
105 |
+
|
106 |
+
generate_image_np[overlap_mask_np == 1] = overlap_np
|
107 |
+
|
108 |
+
diffused_image = Image.fromarray(generate_image_np.astype(np.uint8)).convert("RGB")
|
109 |
+
|
110 |
+
init_image_masked = init_image
|
111 |
+
diffused_image_masked = diffused_image
|
112 |
+
|
113 |
+
return diffused_image, init_image_masked, diffused_image_masked
|
114 |
+
|
115 |
+
|
116 |
+
@torch.no_grad()
|
117 |
+
def apply_inpainting(model,
|
118 |
+
init_image, mask_image_tensor, prompt, height, width, device):
|
119 |
+
"""
|
120 |
+
Use Stable Diffusion 2 to generate image
|
121 |
+
|
122 |
+
Arguments:
|
123 |
+
args: input arguments
|
124 |
+
model: Stable Diffusion 2 model
|
125 |
+
init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
|
126 |
+
mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
|
127 |
+
depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
|
128 |
+
"""
|
129 |
+
|
130 |
+
print("=> generating Inpainting image...")
|
131 |
+
|
132 |
+
mask_image = mask_image_tensor[0].cpu()
|
133 |
+
mask_image = mask_image.permute(2, 0, 1)
|
134 |
+
mask_image = transforms.ToPILImage()(mask_image).convert("L")
|
135 |
+
|
136 |
+
# NOTE Stable Diffusion 2 returns a PIL.Image object
|
137 |
+
# image and mask_image should be PIL images.
|
138 |
+
# The mask structure is white for inpainting and black for keeping as is
|
139 |
+
diffused_image = model(
|
140 |
+
prompt=prompt,
|
141 |
+
image=init_image.resize((512, 512)),
|
142 |
+
mask_image=mask_image.resize((512, 512)),
|
143 |
+
height=512,
|
144 |
+
width=512
|
145 |
+
).images[0].resize((height, width))
|
146 |
+
|
147 |
+
return diffused_image
|
148 |
+
|
149 |
+
|
150 |
+
@torch.no_grad()
|
151 |
+
def apply_inpainting_postprocess(model,
|
152 |
+
init_image, mask_image_tensor, prompt, height, width, device):
|
153 |
+
"""
|
154 |
+
Use Stable Diffusion 2 to generate image
|
155 |
+
|
156 |
+
Arguments:
|
157 |
+
args: input arguments
|
158 |
+
model: Stable Diffusion 2 model
|
159 |
+
init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
|
160 |
+
mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
|
161 |
+
depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
|
162 |
+
"""
|
163 |
+
|
164 |
+
print("=> generating Inpainting image...")
|
165 |
+
|
166 |
+
mask_image = mask_image_tensor[0].cpu()
|
167 |
+
mask_image = mask_image.permute(2, 0, 1)
|
168 |
+
mask_image = transforms.ToPILImage()(mask_image).convert("L")
|
169 |
+
|
170 |
+
# NOTE Stable Diffusion 2 returns a PIL.Image object
|
171 |
+
# image and mask_image should be PIL images.
|
172 |
+
# The mask structure is white for inpainting and black for keeping as is
|
173 |
+
diffused_image = model(
|
174 |
+
prompt=prompt,
|
175 |
+
image=init_image.resize((512, 512)),
|
176 |
+
mask_image=mask_image.resize((512, 512)),
|
177 |
+
height=512,
|
178 |
+
width=512
|
179 |
+
).images[0].resize((height, width))
|
180 |
+
|
181 |
+
diffused_image_tensor = torch.from_numpy(np.array(diffused_image)).to(device)
|
182 |
+
|
183 |
+
init_images_tensor = torch.from_numpy(np.array(init_image)).to(device)
|
184 |
+
|
185 |
+
init_images_tensor = diffused_image_tensor * mask_image_tensor[0] + init_images_tensor * (1 - mask_image_tensor[0])
|
186 |
+
init_image = Image.fromarray(init_images_tensor.cpu().numpy().astype(np.uint8)).convert("RGB")
|
187 |
+
|
188 |
+
return init_image
|
189 |
+
|
text2tex/lib/io_helper.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# common utils
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
|
5 |
+
# numpy
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# visualization
|
9 |
+
import matplotlib
|
10 |
+
import matplotlib.cm as cm
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
matplotlib.use("Agg")
|
14 |
+
|
15 |
+
from pytorch3d.io import save_obj
|
16 |
+
|
17 |
+
from torchvision import transforms
|
18 |
+
|
19 |
+
|
20 |
+
def save_depth(fragments, output_dir, init_image, view_idx):
|
21 |
+
print("=> saving depth...")
|
22 |
+
width, height = init_image.size
|
23 |
+
dpi = 100
|
24 |
+
figsize = width / float(dpi), height / float(dpi)
|
25 |
+
|
26 |
+
depth_np = fragments.zbuf[0].cpu().numpy()
|
27 |
+
|
28 |
+
fig = plt.figure(figsize=figsize)
|
29 |
+
ax = fig.add_axes([0, 0, 1, 1])
|
30 |
+
# Hide spines, ticks, etc.
|
31 |
+
ax.axis('off')
|
32 |
+
# Display the image.
|
33 |
+
ax.imshow(depth_np, cmap='gray')
|
34 |
+
|
35 |
+
plt.savefig(os.path.join(output_dir, "{}.png".format(view_idx)), bbox_inches='tight', pad_inches=0)
|
36 |
+
np.save(os.path.join(output_dir, "{}.npy".format(view_idx)), depth_np[..., 0])
|
37 |
+
|
38 |
+
|
39 |
+
def save_backproject_obj(output_dir, obj_name,
|
40 |
+
verts, faces, verts_uvs, faces_uvs, projected_texture,
|
41 |
+
device):
|
42 |
+
print("=> saving OBJ file...")
|
43 |
+
texture_map = transforms.ToTensor()(projected_texture).to(device)
|
44 |
+
texture_map = texture_map.permute(1, 2, 0)
|
45 |
+
obj_path = os.path.join(output_dir, obj_name)
|
46 |
+
|
47 |
+
save_obj(
|
48 |
+
obj_path,
|
49 |
+
verts=verts,
|
50 |
+
faces=faces,
|
51 |
+
decimal_places=5,
|
52 |
+
verts_uvs=verts_uvs,
|
53 |
+
faces_uvs=faces_uvs,
|
54 |
+
texture_map=texture_map
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
def save_args(args, output_dir):
|
59 |
+
with open(os.path.join(output_dir, "args.json"), "w") as f:
|
60 |
+
json.dump(
|
61 |
+
{k: v for k, v in vars(args).items()},
|
62 |
+
f,
|
63 |
+
indent=4
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def save_viewpoints(args, output_dir, dist_list, elev_list, azim_list, view_list):
|
68 |
+
with open(os.path.join(output_dir, "viewpoints.json"), "w") as f:
|
69 |
+
json.dump(
|
70 |
+
{
|
71 |
+
"dist": dist_list,
|
72 |
+
"elev": elev_list,
|
73 |
+
"azim": azim_list,
|
74 |
+
"view": view_list
|
75 |
+
},
|
76 |
+
f,
|
77 |
+
indent=4
|
78 |
+
)
|
text2tex/lib/mesh_helper.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import trimesh
|
4 |
+
import xatlas
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from sklearn.decomposition import PCA
|
9 |
+
|
10 |
+
from torchvision import transforms
|
11 |
+
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from pytorch3d.io import (
|
15 |
+
load_obj,
|
16 |
+
load_objs_as_meshes
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
def compute_principle_directions(model_path, num_points=20000):
|
21 |
+
mesh = trimesh.load_mesh(model_path, force="mesh")
|
22 |
+
pc, _ = trimesh.sample.sample_surface_even(mesh, num_points)
|
23 |
+
|
24 |
+
pc -= np.mean(pc, axis=0, keepdims=True)
|
25 |
+
|
26 |
+
principle_directions = PCA(n_components=3).fit(pc).components_
|
27 |
+
|
28 |
+
return principle_directions
|
29 |
+
|
30 |
+
|
31 |
+
def init_mesh(input_path, cache_path, device):
|
32 |
+
print("=> parameterizing target mesh...")
|
33 |
+
|
34 |
+
mesh = trimesh.load_mesh(input_path, force='mesh')
|
35 |
+
try:
|
36 |
+
vertices, faces = mesh.vertices, mesh.faces
|
37 |
+
except AttributeError:
|
38 |
+
print("multiple materials in {} are not supported".format(input_path))
|
39 |
+
exit()
|
40 |
+
|
41 |
+
vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
|
42 |
+
xatlas.export(str(cache_path), vertices[vmapping], indices, uvs)
|
43 |
+
|
44 |
+
print("=> loading target mesh...")
|
45 |
+
|
46 |
+
# principle_directions = compute_principle_directions(cache_path)
|
47 |
+
principle_directions = None
|
48 |
+
|
49 |
+
_, faces, aux = load_obj(cache_path, device=device)
|
50 |
+
mesh = load_objs_as_meshes([cache_path], device=device)
|
51 |
+
|
52 |
+
num_verts = mesh.verts_packed().shape[0]
|
53 |
+
|
54 |
+
# make sure mesh center is at origin
|
55 |
+
bbox = mesh.get_bounding_boxes()
|
56 |
+
mesh_center = bbox.mean(dim=2).repeat(num_verts, 1)
|
57 |
+
mesh = apply_offsets_to_mesh(mesh, -mesh_center)
|
58 |
+
|
59 |
+
# make sure mesh size is normalized
|
60 |
+
box_size = bbox[..., 1] - bbox[..., 0]
|
61 |
+
box_max = box_size.max(dim=1, keepdim=True)[0].repeat(num_verts, 3)
|
62 |
+
mesh = apply_scale_to_mesh(mesh, 1 / box_max)
|
63 |
+
|
64 |
+
return mesh, mesh.verts_packed(), faces, aux, principle_directions, mesh_center, box_max
|
65 |
+
|
66 |
+
|
67 |
+
def apply_offsets_to_mesh(mesh, offsets):
|
68 |
+
new_mesh = mesh.offset_verts(offsets)
|
69 |
+
|
70 |
+
return new_mesh
|
71 |
+
|
72 |
+
def apply_scale_to_mesh(mesh, scale):
|
73 |
+
new_mesh = mesh.scale_verts(scale)
|
74 |
+
|
75 |
+
return new_mesh
|
76 |
+
|
77 |
+
|
78 |
+
def adjust_uv_map(faces, aux, init_texture, uv_size):
|
79 |
+
"""
|
80 |
+
adjust UV map to be compatiable with multiple textures.
|
81 |
+
UVs for different materials will be decomposed and placed horizontally
|
82 |
+
|
83 |
+
+-----+-----+-----+--
|
84 |
+
| 1 | 2 | 3 |
|
85 |
+
+-----+-----+-----+--
|
86 |
+
|
87 |
+
"""
|
88 |
+
|
89 |
+
textures_ids = faces.textures_idx
|
90 |
+
materials_idx = faces.materials_idx
|
91 |
+
verts_uvs = aux.verts_uvs
|
92 |
+
|
93 |
+
num_materials = torch.unique(materials_idx).shape[0]
|
94 |
+
|
95 |
+
new_verts_uvs = verts_uvs.clone()
|
96 |
+
for material_id in range(num_materials):
|
97 |
+
# apply offsets to horizontal axis
|
98 |
+
faces_ids = textures_ids[materials_idx == material_id].unique()
|
99 |
+
new_verts_uvs[faces_ids, 0] += material_id
|
100 |
+
|
101 |
+
new_verts_uvs[:, 0] /= num_materials
|
102 |
+
|
103 |
+
init_texture_tensor = transforms.ToTensor()(init_texture)
|
104 |
+
init_texture_tensor = torch.cat([init_texture_tensor for _ in range(num_materials)], dim=-1)
|
105 |
+
init_texture = transforms.ToPILImage()(init_texture_tensor).resize((uv_size, uv_size))
|
106 |
+
|
107 |
+
return new_verts_uvs, init_texture
|
108 |
+
|
109 |
+
|
110 |
+
@torch.no_grad()
|
111 |
+
def update_face_angles(mesh, cameras, fragments):
|
112 |
+
def get_angle(x, y):
|
113 |
+
x = torch.nn.functional.normalize(x)
|
114 |
+
y = torch.nn.functional.normalize(y)
|
115 |
+
inner_product = (x * y).sum(dim=1)
|
116 |
+
x_norm = x.pow(2).sum(dim=1).pow(0.5)
|
117 |
+
y_norm = y.pow(2).sum(dim=1).pow(0.5)
|
118 |
+
cos = inner_product / (x_norm * y_norm)
|
119 |
+
angle = torch.acos(cos)
|
120 |
+
angle = angle * 180 / 3.14159
|
121 |
+
|
122 |
+
return angle
|
123 |
+
|
124 |
+
# face normals
|
125 |
+
face_normals = mesh.faces_normals_padded()[0]
|
126 |
+
|
127 |
+
# view vector (object center -> camera center)
|
128 |
+
camera_center = cameras.get_camera_center()
|
129 |
+
|
130 |
+
face_angles = get_angle(
|
131 |
+
face_normals,
|
132 |
+
camera_center.repeat(face_normals.shape[0], 1)
|
133 |
+
) # (F)
|
134 |
+
|
135 |
+
face_angles_rev = get_angle(
|
136 |
+
face_normals,
|
137 |
+
-camera_center.repeat(face_normals.shape[0], 1)
|
138 |
+
) # (F)
|
139 |
+
|
140 |
+
face_angles = torch.minimum(face_angles, face_angles_rev)
|
141 |
+
|
142 |
+
# Indices of unique visible faces
|
143 |
+
visible_map = fragments.pix_to_face.unique() # (num_visible_faces)
|
144 |
+
invisible_mask = torch.ones_like(face_angles)
|
145 |
+
invisible_mask[visible_map] = 0
|
146 |
+
face_angles[invisible_mask == 1] = 10000. # angles of invisible faces are ignored
|
147 |
+
|
148 |
+
return face_angles
|
text2tex/lib/projection_helper.py
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import random
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
from pytorch3d.renderer import TexturesUV
|
12 |
+
from pytorch3d.ops import interpolate_face_attributes
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
# customized
|
19 |
+
import sys
|
20 |
+
sys.path.append(".")
|
21 |
+
|
22 |
+
from lib.camera_helper import init_camera
|
23 |
+
from lib.render_helper import init_renderer, render
|
24 |
+
from lib.shading_helper import (
|
25 |
+
BlendParams,
|
26 |
+
init_soft_phong_shader,
|
27 |
+
init_flat_texel_shader,
|
28 |
+
)
|
29 |
+
from lib.vis_helper import visualize_outputs, visualize_quad_mask
|
30 |
+
from lib.constants import *
|
31 |
+
|
32 |
+
|
33 |
+
def get_all_4_locations(values_y, values_x):
|
34 |
+
y_0 = torch.floor(values_y)
|
35 |
+
y_1 = torch.ceil(values_y)
|
36 |
+
x_0 = torch.floor(values_x)
|
37 |
+
x_1 = torch.ceil(values_x)
|
38 |
+
|
39 |
+
return torch.cat([y_0, y_0, y_1, y_1], 0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()
|
40 |
+
|
41 |
+
|
42 |
+
def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device):
|
43 |
+
"""
|
44 |
+
compose quad mask:
|
45 |
+
-> 0: background
|
46 |
+
-> 1: old
|
47 |
+
-> 2: update
|
48 |
+
-> 3: new
|
49 |
+
"""
|
50 |
+
|
51 |
+
new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
|
52 |
+
update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
|
53 |
+
old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)
|
54 |
+
|
55 |
+
all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor
|
56 |
+
|
57 |
+
quad_mask_tensor = torch.zeros_like(all_mask_tensor)
|
58 |
+
quad_mask_tensor[old_mask_tensor == 1] = 1
|
59 |
+
quad_mask_tensor[update_mask_tensor == 1] = 2
|
60 |
+
quad_mask_tensor[new_mask_tensor == 1] = 3
|
61 |
+
|
62 |
+
return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
|
63 |
+
|
64 |
+
|
65 |
+
def compute_view_heat(similarity_tensor, quad_mask_tensor):
|
66 |
+
num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]
|
67 |
+
heat = 0
|
68 |
+
for idx in QUAD_WEIGHTS:
|
69 |
+
heat += (quad_mask_tensor == idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels
|
70 |
+
|
71 |
+
return heat
|
72 |
+
|
73 |
+
|
74 |
+
def select_viewpoint(selected_view_ids, view_punishments,
|
75 |
+
mode, dist_list, elev_list, azim_list, sector_list, view_idx,
|
76 |
+
similarity_texture_cache, exist_texture,
|
77 |
+
mesh, faces, verts_uvs,
|
78 |
+
image_size, faces_per_pixel,
|
79 |
+
init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
|
80 |
+
device, use_principle=False
|
81 |
+
):
|
82 |
+
if mode == "sequential":
|
83 |
+
|
84 |
+
num_views = len(dist_list)
|
85 |
+
|
86 |
+
dist = dist_list[view_idx % num_views]
|
87 |
+
elev = elev_list[view_idx % num_views]
|
88 |
+
azim = azim_list[view_idx % num_views]
|
89 |
+
sector = sector_list[view_idx % num_views]
|
90 |
+
|
91 |
+
selected_view_ids.append(view_idx % num_views)
|
92 |
+
|
93 |
+
elif mode == "heuristic":
|
94 |
+
|
95 |
+
if use_principle and view_idx < 6:
|
96 |
+
|
97 |
+
selected_view_idx = view_idx
|
98 |
+
|
99 |
+
else:
|
100 |
+
|
101 |
+
selected_view_idx = None
|
102 |
+
max_heat = 0
|
103 |
+
|
104 |
+
print("=> selecting next view...")
|
105 |
+
view_heat_list = []
|
106 |
+
for sample_idx in tqdm(range(len(dist_list))):
|
107 |
+
|
108 |
+
view_heat, *_ = render_one_view_and_build_masks(dist_list[sample_idx], elev_list[sample_idx], azim_list[sample_idx],
|
109 |
+
sample_idx, sample_idx, view_punishments,
|
110 |
+
similarity_texture_cache, exist_texture,
|
111 |
+
mesh, faces, verts_uvs,
|
112 |
+
image_size, faces_per_pixel,
|
113 |
+
init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
|
114 |
+
device)
|
115 |
+
|
116 |
+
if view_heat > max_heat:
|
117 |
+
selected_view_idx = sample_idx
|
118 |
+
max_heat = view_heat
|
119 |
+
|
120 |
+
view_heat_list.append(view_heat.item())
|
121 |
+
|
122 |
+
print(view_heat_list)
|
123 |
+
print("select view {} with heat {}".format(selected_view_idx, max_heat))
|
124 |
+
|
125 |
+
|
126 |
+
dist = dist_list[selected_view_idx]
|
127 |
+
elev = elev_list[selected_view_idx]
|
128 |
+
azim = azim_list[selected_view_idx]
|
129 |
+
sector = sector_list[selected_view_idx]
|
130 |
+
|
131 |
+
selected_view_ids.append(selected_view_idx)
|
132 |
+
|
133 |
+
view_punishments[selected_view_idx] *= 0.01
|
134 |
+
|
135 |
+
elif mode == "random":
|
136 |
+
|
137 |
+
selected_view_idx = random.choice(range(len(dist_list)))
|
138 |
+
|
139 |
+
dist = dist_list[selected_view_idx]
|
140 |
+
elev = elev_list[selected_view_idx]
|
141 |
+
azim = azim_list[selected_view_idx]
|
142 |
+
sector = sector_list[selected_view_idx]
|
143 |
+
|
144 |
+
selected_view_ids.append(selected_view_idx)
|
145 |
+
|
146 |
+
else:
|
147 |
+
raise NotImplementedError()
|
148 |
+
|
149 |
+
return dist, elev, azim, sector, selected_view_ids, view_punishments
|
150 |
+
|
151 |
+
|
152 |
+
@torch.no_grad()
|
153 |
+
def build_backproject_mask(mesh, faces, verts_uvs,
|
154 |
+
cameras, reference_image, faces_per_pixel,
|
155 |
+
image_size, uv_size, device):
|
156 |
+
# construct pixel UVs
|
157 |
+
renderer_scaled = init_renderer(cameras,
|
158 |
+
shader=init_soft_phong_shader(
|
159 |
+
camera=cameras,
|
160 |
+
blend_params=BlendParams(),
|
161 |
+
device=device),
|
162 |
+
image_size=image_size,
|
163 |
+
faces_per_pixel=faces_per_pixel
|
164 |
+
)
|
165 |
+
fragments_scaled = renderer_scaled.rasterizer(mesh)
|
166 |
+
|
167 |
+
# get UV coordinates for each pixel
|
168 |
+
faces_verts_uvs = verts_uvs[faces.textures_idx]
|
169 |
+
|
170 |
+
pixel_uvs = interpolate_face_attributes(
|
171 |
+
fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
|
172 |
+
) # NxHsxWsxKx2
|
173 |
+
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)
|
174 |
+
|
175 |
+
texture_locations_y, texture_locations_x = get_all_4_locations(
|
176 |
+
(1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
|
177 |
+
pixel_uvs[:, 0].reshape(-1) * (uv_size - 1)
|
178 |
+
)
|
179 |
+
|
180 |
+
K = faces_per_pixel
|
181 |
+
|
182 |
+
texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size)))).float() / 255.
|
183 |
+
texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
|
184 |
+
|
185 |
+
# texture
|
186 |
+
texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
|
187 |
+
texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values.reshape(-1, 3)
|
188 |
+
|
189 |
+
return texture_tensor[:, :, 0]
|
190 |
+
|
191 |
+
|
192 |
+
@torch.no_grad()
|
193 |
+
def build_diffusion_mask(mesh_stuff,
|
194 |
+
renderer, exist_texture, similarity_texture_cache, target_value, device, image_size,
|
195 |
+
smooth_mask=False, view_threshold=0.01):
|
196 |
+
|
197 |
+
mesh, faces, verts_uvs = mesh_stuff
|
198 |
+
mask_mesh = mesh.clone() # NOTE in-place operation - DANGER!!!
|
199 |
+
|
200 |
+
# visible mask => the whole region
|
201 |
+
exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
|
202 |
+
mask_mesh.textures = TexturesUV(
|
203 |
+
maps=torch.ones_like(exist_texture_expand),
|
204 |
+
faces_uvs=faces.textures_idx[None, ...],
|
205 |
+
verts_uvs=verts_uvs[None, ...],
|
206 |
+
sampling_mode="nearest"
|
207 |
+
)
|
208 |
+
# visible_mask_tensor, *_ = render(mask_mesh, renderer)
|
209 |
+
visible_mask_tensor, _, similarity_map_tensor, *_ = render(mask_mesh, renderer)
|
210 |
+
# faces that are too rotated away from the viewpoint will be treated as invisible
|
211 |
+
valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
|
212 |
+
visible_mask_tensor *= valid_mask_tensor
|
213 |
+
|
214 |
+
# nonexist mask <=> new mask
|
215 |
+
exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
|
216 |
+
mask_mesh.textures = TexturesUV(
|
217 |
+
maps=1 - exist_texture_expand,
|
218 |
+
faces_uvs=faces.textures_idx[None, ...],
|
219 |
+
verts_uvs=verts_uvs[None, ...],
|
220 |
+
sampling_mode="nearest"
|
221 |
+
)
|
222 |
+
new_mask_tensor, *_ = render(mask_mesh, renderer)
|
223 |
+
new_mask_tensor *= valid_mask_tensor
|
224 |
+
|
225 |
+
# exist mask => visible mask - new mask
|
226 |
+
exist_mask_tensor = visible_mask_tensor - new_mask_tensor
|
227 |
+
exist_mask_tensor[exist_mask_tensor < 0] = 0 # NOTE dilate can lead to overflow
|
228 |
+
|
229 |
+
# all update mask
|
230 |
+
mask_mesh.textures = TexturesUV(
|
231 |
+
maps=(
|
232 |
+
similarity_texture_cache.argmax(0) == target_value
|
233 |
+
# # only consider the views that have already appeared before
|
234 |
+
# similarity_texture_cache[0:target_value+1].argmax(0) == target_value
|
235 |
+
).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
|
236 |
+
faces_uvs=faces.textures_idx[None, ...],
|
237 |
+
verts_uvs=verts_uvs[None, ...],
|
238 |
+
sampling_mode="nearest"
|
239 |
+
)
|
240 |
+
all_update_mask_tensor, *_ = render(mask_mesh, renderer)
|
241 |
+
|
242 |
+
# current update mask => intersection between all update mask and exist mask
|
243 |
+
update_mask_tensor = exist_mask_tensor * all_update_mask_tensor
|
244 |
+
|
245 |
+
# keep mask => exist mask - update mask
|
246 |
+
old_mask_tensor = exist_mask_tensor - update_mask_tensor
|
247 |
+
|
248 |
+
# convert
|
249 |
+
new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
|
250 |
+
new_mask = transforms.ToPILImage()(new_mask).convert("L")
|
251 |
+
|
252 |
+
update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
|
253 |
+
update_mask = transforms.ToPILImage()(update_mask).convert("L")
|
254 |
+
|
255 |
+
old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
|
256 |
+
old_mask = transforms.ToPILImage()(old_mask).convert("L")
|
257 |
+
|
258 |
+
exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
|
259 |
+
exist_mask = transforms.ToPILImage()(exist_mask).convert("L")
|
260 |
+
|
261 |
+
return new_mask, update_mask, old_mask, exist_mask
|
262 |
+
|
263 |
+
|
264 |
+
@torch.no_grad()
|
265 |
+
def render_one_view(mesh,
|
266 |
+
dist, elev, azim,
|
267 |
+
image_size, faces_per_pixel,
|
268 |
+
device):
|
269 |
+
|
270 |
+
# render the view
|
271 |
+
cameras = init_camera(
|
272 |
+
dist, elev, azim,
|
273 |
+
image_size, device
|
274 |
+
)
|
275 |
+
renderer = init_renderer(cameras,
|
276 |
+
shader=init_soft_phong_shader(
|
277 |
+
camera=cameras,
|
278 |
+
blend_params=BlendParams(),
|
279 |
+
device=device),
|
280 |
+
image_size=image_size,
|
281 |
+
faces_per_pixel=faces_per_pixel
|
282 |
+
)
|
283 |
+
|
284 |
+
init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(mesh, renderer)
|
285 |
+
|
286 |
+
return (
|
287 |
+
cameras, renderer,
|
288 |
+
init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
|
289 |
+
)
|
290 |
+
|
291 |
+
|
292 |
+
@torch.no_grad()
|
293 |
+
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
|
294 |
+
dist_list, elev_list, azim_list,
|
295 |
+
image_size, image_size_scaled, uv_size, faces_per_pixel,
|
296 |
+
device):
|
297 |
+
|
298 |
+
num_candidate_views = len(dist_list)
|
299 |
+
similarity_texture_cache = torch.zeros(num_candidate_views, uv_size, uv_size).to(device)
|
300 |
+
|
301 |
+
print("=> building similarity texture cache for all views...")
|
302 |
+
for i in tqdm(range(num_candidate_views)):
|
303 |
+
cameras, _, _, _, similarity_tensor, _, _ = render_one_view(mesh,
|
304 |
+
dist_list[i], elev_list[i], azim_list[i],
|
305 |
+
image_size, faces_per_pixel, device)
|
306 |
+
|
307 |
+
similarity_texture_cache[i] = build_backproject_mask(mesh, faces, verts_uvs,
|
308 |
+
cameras, transforms.ToPILImage()(similarity_tensor[0, :, :, 0]).convert("RGB"), faces_per_pixel,
|
309 |
+
image_size_scaled, uv_size, device)
|
310 |
+
|
311 |
+
return similarity_texture_cache
|
312 |
+
|
313 |
+
|
314 |
+
@torch.no_grad()
|
315 |
+
def render_one_view_and_build_masks(dist, elev, azim,
|
316 |
+
selected_view_idx, view_idx, view_punishments,
|
317 |
+
similarity_texture_cache, exist_texture,
|
318 |
+
mesh, faces, verts_uvs,
|
319 |
+
image_size, faces_per_pixel,
|
320 |
+
init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
|
321 |
+
device, save_intermediate=False, smooth_mask=False, view_threshold=0.01):
|
322 |
+
|
323 |
+
# render the view
|
324 |
+
(
|
325 |
+
cameras, renderer,
|
326 |
+
init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
|
327 |
+
) = render_one_view(mesh,
|
328 |
+
dist, elev, azim,
|
329 |
+
image_size, faces_per_pixel,
|
330 |
+
device
|
331 |
+
)
|
332 |
+
|
333 |
+
init_image = init_images_tensor[0].cpu()
|
334 |
+
init_image = init_image.permute(2, 0, 1)
|
335 |
+
init_image = transforms.ToPILImage()(init_image).convert("RGB")
|
336 |
+
|
337 |
+
normal_map = normal_maps_tensor[0].cpu()
|
338 |
+
normal_map = normal_map.permute(2, 0, 1)
|
339 |
+
normal_map = transforms.ToPILImage()(normal_map).convert("RGB")
|
340 |
+
|
341 |
+
depth_map = depth_maps_tensor[0].cpu().numpy()
|
342 |
+
depth_map = Image.fromarray(depth_map).convert("L")
|
343 |
+
|
344 |
+
similarity_map = similarity_tensor[0, :, :, 0].cpu()
|
345 |
+
similarity_map = transforms.ToPILImage()(similarity_map).convert("L")
|
346 |
+
|
347 |
+
|
348 |
+
flat_renderer = init_renderer(cameras,
|
349 |
+
shader=init_flat_texel_shader(
|
350 |
+
camera=cameras,
|
351 |
+
device=device),
|
352 |
+
image_size=image_size,
|
353 |
+
faces_per_pixel=faces_per_pixel
|
354 |
+
)
|
355 |
+
new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
|
356 |
+
(mesh, faces, verts_uvs),
|
357 |
+
flat_renderer, exist_texture, similarity_texture_cache, selected_view_idx, device, image_size,
|
358 |
+
smooth_mask=smooth_mask, view_threshold=view_threshold
|
359 |
+
)
|
360 |
+
# NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
|
361 |
+
# it should match with `similarity_texture_cache`
|
362 |
+
|
363 |
+
(
|
364 |
+
old_mask_tensor,
|
365 |
+
update_mask_tensor,
|
366 |
+
new_mask_tensor,
|
367 |
+
all_mask_tensor,
|
368 |
+
quad_mask_tensor
|
369 |
+
) = compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device)
|
370 |
+
|
371 |
+
view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
|
372 |
+
view_heat *= view_punishments[selected_view_idx]
|
373 |
+
|
374 |
+
# save intermediate results
|
375 |
+
if save_intermediate:
|
376 |
+
init_image.save(os.path.join(init_image_dir, "{}.png".format(view_idx)))
|
377 |
+
normal_map.save(os.path.join(normal_map_dir, "{}.png".format(view_idx)))
|
378 |
+
depth_map.save(os.path.join(depth_map_dir, "{}.png".format(view_idx)))
|
379 |
+
similarity_map.save(os.path.join(similarity_map_dir, "{}.png".format(view_idx)))
|
380 |
+
|
381 |
+
new_mask_image.save(os.path.join(mask_image_dir, "{}_new.png".format(view_idx)))
|
382 |
+
update_mask_image.save(os.path.join(mask_image_dir, "{}_update.png".format(view_idx)))
|
383 |
+
old_mask_image.save(os.path.join(mask_image_dir, "{}_old.png".format(view_idx)))
|
384 |
+
exist_mask_image.save(os.path.join(mask_image_dir, "{}_exist.png".format(view_idx)))
|
385 |
+
|
386 |
+
visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_heat, device)
|
387 |
+
|
388 |
+
return (
|
389 |
+
view_heat,
|
390 |
+
renderer, cameras, fragments,
|
391 |
+
init_image, normal_map, depth_map,
|
392 |
+
init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
|
393 |
+
old_mask_image, update_mask_image, new_mask_image,
|
394 |
+
old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
|
395 |
+
)
|
396 |
+
|
397 |
+
|
398 |
+
|
399 |
+
@torch.no_grad()
|
400 |
+
def backproject_from_image(mesh, faces, verts_uvs, cameras,
|
401 |
+
reference_image, new_mask_image, update_mask_image,
|
402 |
+
init_texture, exist_texture,
|
403 |
+
image_size, uv_size, faces_per_pixel,
|
404 |
+
device):
|
405 |
+
|
406 |
+
# construct pixel UVs
|
407 |
+
renderer_scaled = init_renderer(cameras,
|
408 |
+
shader=init_soft_phong_shader(
|
409 |
+
camera=cameras,
|
410 |
+
blend_params=BlendParams(),
|
411 |
+
device=device),
|
412 |
+
image_size=image_size,
|
413 |
+
faces_per_pixel=faces_per_pixel
|
414 |
+
)
|
415 |
+
fragments_scaled = renderer_scaled.rasterizer(mesh)
|
416 |
+
|
417 |
+
# get UV coordinates for each pixel
|
418 |
+
faces_verts_uvs = verts_uvs[faces.textures_idx]
|
419 |
+
|
420 |
+
pixel_uvs = interpolate_face_attributes(
|
421 |
+
fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
|
422 |
+
) # NxHsxWsxKx2
|
423 |
+
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(pixel_uvs.shape[-2], pixel_uvs.shape[1], pixel_uvs.shape[2], 2)
|
424 |
+
|
425 |
+
# the update mask has to be on top of the diffusion mask
|
426 |
+
new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(device).unsqueeze(-1)
|
427 |
+
update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(device).unsqueeze(-1)
|
428 |
+
|
429 |
+
project_mask_image_tensor = torch.logical_or(update_mask_image_tensor, new_mask_image_tensor).float()
|
430 |
+
project_mask_image = project_mask_image_tensor * 255.
|
431 |
+
project_mask_image = Image.fromarray(project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))
|
432 |
+
|
433 |
+
project_mask_image_scaled = project_mask_image.resize(
|
434 |
+
(image_size, image_size),
|
435 |
+
Image.Resampling.NEAREST
|
436 |
+
)
|
437 |
+
project_mask_image_tensor_scaled = transforms.ToTensor()(project_mask_image_scaled).to(device)
|
438 |
+
|
439 |
+
pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]
|
440 |
+
|
441 |
+
texture_locations_y, texture_locations_x = get_all_4_locations(
|
442 |
+
(1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
|
443 |
+
pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1)
|
444 |
+
)
|
445 |
+
|
446 |
+
K = pixel_uvs.shape[0]
|
447 |
+
project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:, None, :, :, None].repeat(1, 4, 1, 1, 3)
|
448 |
+
|
449 |
+
texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size))))
|
450 |
+
texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
|
451 |
+
|
452 |
+
texture_values_masked = texture_values.reshape(-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(-1, 3)
|
453 |
+
|
454 |
+
# texture
|
455 |
+
texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
|
456 |
+
texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values_masked
|
457 |
+
|
458 |
+
init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(np.uint8))
|
459 |
+
|
460 |
+
# update texture cache
|
461 |
+
exist_texture[texture_locations_y, texture_locations_x] = 1
|
462 |
+
|
463 |
+
return init_texture, project_mask_image, exist_texture
|
464 |
+
|
text2tex/lib/render_helper.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from torchvision import transforms
|
11 |
+
from pytorch3d.ops import interpolate_face_attributes
|
12 |
+
from pytorch3d.renderer import (
|
13 |
+
RasterizationSettings,
|
14 |
+
MeshRendererWithFragments,
|
15 |
+
MeshRasterizer,
|
16 |
+
)
|
17 |
+
|
18 |
+
# customized
|
19 |
+
import sys
|
20 |
+
sys.path.append(".")
|
21 |
+
|
22 |
+
|
23 |
+
def init_renderer(camera, shader, image_size, faces_per_pixel):
|
24 |
+
raster_settings = RasterizationSettings(image_size=image_size, faces_per_pixel=faces_per_pixel)
|
25 |
+
renderer = MeshRendererWithFragments(
|
26 |
+
rasterizer=MeshRasterizer(
|
27 |
+
cameras=camera,
|
28 |
+
raster_settings=raster_settings
|
29 |
+
),
|
30 |
+
shader=shader
|
31 |
+
)
|
32 |
+
|
33 |
+
return renderer
|
34 |
+
|
35 |
+
|
36 |
+
@torch.no_grad()
|
37 |
+
def render(mesh, renderer, pad_value=10):
|
38 |
+
def phong_normal_shading(meshes, fragments) -> torch.Tensor:
|
39 |
+
faces = meshes.faces_packed() # (F, 3)
|
40 |
+
vertex_normals = meshes.verts_normals_packed() # (V, 3)
|
41 |
+
faces_normals = vertex_normals[faces]
|
42 |
+
pixel_normals = interpolate_face_attributes(
|
43 |
+
fragments.pix_to_face, fragments.bary_coords, faces_normals
|
44 |
+
)
|
45 |
+
|
46 |
+
return pixel_normals
|
47 |
+
|
48 |
+
def similarity_shading(meshes, fragments):
|
49 |
+
faces = meshes.faces_packed() # (F, 3)
|
50 |
+
vertex_normals = meshes.verts_normals_packed() # (V, 3)
|
51 |
+
faces_normals = vertex_normals[faces]
|
52 |
+
vertices = meshes.verts_packed() # (V, 3)
|
53 |
+
face_positions = vertices[faces]
|
54 |
+
view_directions = torch.nn.functional.normalize((renderer.shader.cameras.get_camera_center().reshape(1, 1, 3) - face_positions), p=2, dim=2)
|
55 |
+
cosine_similarity = torch.nn.CosineSimilarity(dim=2)(faces_normals, view_directions)
|
56 |
+
pixel_similarity = interpolate_face_attributes(
|
57 |
+
fragments.pix_to_face, fragments.bary_coords, cosine_similarity.unsqueeze(-1)
|
58 |
+
)
|
59 |
+
|
60 |
+
return pixel_similarity
|
61 |
+
|
62 |
+
def get_relative_depth_map(fragments, pad_value=pad_value):
|
63 |
+
absolute_depth = fragments.zbuf[..., 0] # B, H, W
|
64 |
+
no_depth = -1
|
65 |
+
|
66 |
+
depth_min, depth_max = absolute_depth[absolute_depth != no_depth].min(), absolute_depth[absolute_depth != no_depth].max()
|
67 |
+
target_min, target_max = 50, 255
|
68 |
+
|
69 |
+
depth_value = absolute_depth[absolute_depth != no_depth]
|
70 |
+
depth_value = depth_max - depth_value # reverse values
|
71 |
+
|
72 |
+
depth_value /= (depth_max - depth_min)
|
73 |
+
depth_value = depth_value * (target_max - target_min) + target_min
|
74 |
+
|
75 |
+
relative_depth = absolute_depth.clone()
|
76 |
+
relative_depth[absolute_depth != no_depth] = depth_value
|
77 |
+
relative_depth[absolute_depth == no_depth] = pad_value # not completely black
|
78 |
+
|
79 |
+
return relative_depth
|
80 |
+
|
81 |
+
|
82 |
+
images, fragments = renderer(mesh)
|
83 |
+
normal_maps = phong_normal_shading(mesh, fragments).squeeze(-2)
|
84 |
+
similarity_maps = similarity_shading(mesh, fragments).squeeze(-2) # -1 - 1
|
85 |
+
depth_maps = get_relative_depth_map(fragments)
|
86 |
+
|
87 |
+
# normalize similarity mask to 0 - 1
|
88 |
+
similarity_maps = torch.abs(similarity_maps) # 0 - 1
|
89 |
+
|
90 |
+
# HACK erode, eliminate isolated dots
|
91 |
+
non_zero_similarity = (similarity_maps > 0).float()
|
92 |
+
non_zero_similarity = (non_zero_similarity * 255.).cpu().numpy().astype(np.uint8)[0]
|
93 |
+
non_zero_similarity = cv2.erode(non_zero_similarity, kernel=np.ones((3, 3), np.uint8), iterations=2)
|
94 |
+
non_zero_similarity = torch.from_numpy(non_zero_similarity).to(similarity_maps.device).unsqueeze(0) / 255.
|
95 |
+
similarity_maps = non_zero_similarity.unsqueeze(-1) * similarity_maps
|
96 |
+
|
97 |
+
return images, normal_maps, similarity_maps, depth_maps, fragments
|
98 |
+
|
99 |
+
|
100 |
+
@torch.no_grad()
|
101 |
+
def check_visible_faces(mesh, fragments):
|
102 |
+
pix_to_face = fragments.pix_to_face
|
103 |
+
|
104 |
+
# Indices of unique visible faces
|
105 |
+
visible_map = pix_to_face.unique() # (num_visible_faces)
|
106 |
+
|
107 |
+
return visible_map
|
108 |
+
|
text2tex/lib/shading_helper.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import NamedTuple, Sequence
|
2 |
+
|
3 |
+
from pytorch3d.renderer.mesh.shader import ShaderBase
|
4 |
+
from pytorch3d.renderer import (
|
5 |
+
AmbientLights,
|
6 |
+
SoftPhongShader
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
class BlendParams(NamedTuple):
|
11 |
+
sigma: float = 1e-4
|
12 |
+
gamma: float = 1e-4
|
13 |
+
background_color: Sequence = (1, 1, 1)
|
14 |
+
|
15 |
+
|
16 |
+
class FlatTexelShader(ShaderBase):
|
17 |
+
|
18 |
+
def __init__(self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None):
|
19 |
+
super().__init__(device, cameras, lights, materials, blend_params)
|
20 |
+
|
21 |
+
def forward(self, fragments, meshes, **_kwargs):
|
22 |
+
texels = meshes.sample_textures(fragments)
|
23 |
+
texels[(fragments.pix_to_face == -1), :] = 0
|
24 |
+
return texels.squeeze(-2)
|
25 |
+
|
26 |
+
|
27 |
+
def init_soft_phong_shader(camera, blend_params, device):
|
28 |
+
lights = AmbientLights(device=device)
|
29 |
+
shader = SoftPhongShader(
|
30 |
+
cameras=camera,
|
31 |
+
lights=lights,
|
32 |
+
device=device,
|
33 |
+
blend_params=blend_params
|
34 |
+
)
|
35 |
+
|
36 |
+
return shader
|
37 |
+
|
38 |
+
|
39 |
+
def init_flat_texel_shader(camera, device):
|
40 |
+
shader=FlatTexelShader(
|
41 |
+
cameras=camera,
|
42 |
+
device=device
|
43 |
+
)
|
44 |
+
|
45 |
+
return shader
|
text2tex/lib/vis_helper.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
# visualization
|
7 |
+
import matplotlib
|
8 |
+
import matplotlib.cm as cm
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
|
11 |
+
matplotlib.use("Agg")
|
12 |
+
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
# GIF
|
16 |
+
import imageio.v2 as imageio
|
17 |
+
|
18 |
+
# customized
|
19 |
+
import sys
|
20 |
+
sys.path.append(".")
|
21 |
+
|
22 |
+
from lib.constants import *
|
23 |
+
from lib.camera_helper import polar_to_xyz
|
24 |
+
|
25 |
+
def visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_score, device):
|
26 |
+
quad_mask_tensor = quad_mask_tensor.unsqueeze(-1).repeat(1, 1, 1, 3)
|
27 |
+
quad_mask_image_tensor = torch.zeros_like(quad_mask_tensor)
|
28 |
+
|
29 |
+
for idx in PALETTE:
|
30 |
+
selected = quad_mask_tensor[quad_mask_tensor == idx].reshape(-1, 3)
|
31 |
+
selected = torch.FloatTensor(PALETTE[idx]).to(device).unsqueeze(0).repeat(selected.shape[0], 1)
|
32 |
+
|
33 |
+
quad_mask_image_tensor[quad_mask_tensor == idx] = selected.reshape(-1)
|
34 |
+
|
35 |
+
quad_mask_image_np = quad_mask_image_tensor[0].cpu().numpy().astype(np.uint8)
|
36 |
+
quad_mask_image = Image.fromarray(quad_mask_image_np).convert("RGB")
|
37 |
+
quad_mask_image.save(os.path.join(mask_image_dir, "{}_quad_{:.5f}.png".format(view_idx, view_score)))
|
38 |
+
|
39 |
+
|
40 |
+
def visualize_outputs(output_dir, init_image_dir, mask_image_dir, inpainted_image_dir, num_views):
|
41 |
+
# subplot settings
|
42 |
+
num_col = 3
|
43 |
+
num_row = 1
|
44 |
+
subplot_size = 4
|
45 |
+
|
46 |
+
summary_image_dir = os.path.join(output_dir, "summary")
|
47 |
+
os.makedirs(summary_image_dir, exist_ok=True)
|
48 |
+
|
49 |
+
# graph settings
|
50 |
+
print("=> visualizing results...")
|
51 |
+
for view_idx in range(num_views):
|
52 |
+
plt.switch_backend("agg")
|
53 |
+
fig = plt.figure(dpi=100)
|
54 |
+
fig.set_size_inches(subplot_size * num_col, subplot_size * (num_row + 1))
|
55 |
+
fig.set_facecolor('white')
|
56 |
+
|
57 |
+
# rendering
|
58 |
+
plt.subplot2grid((num_row, num_col), (0, 0))
|
59 |
+
plt.imshow(Image.open(os.path.join(init_image_dir, "{}.png".format(view_idx))))
|
60 |
+
plt.text(0, 0, "Rendering", fontsize=16, color='black', backgroundcolor='white')
|
61 |
+
plt.axis('off')
|
62 |
+
|
63 |
+
# mask
|
64 |
+
plt.subplot2grid((num_row, num_col), (0, 1))
|
65 |
+
plt.imshow(Image.open(os.path.join(mask_image_dir, "{}_project.png".format(view_idx))))
|
66 |
+
plt.text(0, 0, "Project Mask", fontsize=16, color='black', backgroundcolor='white')
|
67 |
+
plt.set_cmap(cm.Greys_r)
|
68 |
+
plt.axis('off')
|
69 |
+
|
70 |
+
# inpainted
|
71 |
+
plt.subplot2grid((num_row, num_col), (0, 2))
|
72 |
+
plt.imshow(Image.open(os.path.join(inpainted_image_dir, "{}.png".format(view_idx))))
|
73 |
+
plt.text(0, 0, "Inpainted", fontsize=16, color='black', backgroundcolor='white')
|
74 |
+
plt.axis('off')
|
75 |
+
|
76 |
+
|
77 |
+
plt.savefig(os.path.join(summary_image_dir, "{}.png".format(view_idx)), bbox_inches="tight")
|
78 |
+
fig.clf()
|
79 |
+
|
80 |
+
# generate GIF
|
81 |
+
images = [imageio.imread(os.path.join(summary_image_dir, "{}.png".format(view_idx)))for view_idx in range(num_views)]
|
82 |
+
imageio.mimsave(os.path.join(summary_image_dir, "output.gif"), images, duration=1)
|
83 |
+
|
84 |
+
print("=> done!")
|
85 |
+
|
86 |
+
|
87 |
+
def visualize_principle_viewpoints(output_dir, dist_list, elev_list, azim_list):
|
88 |
+
theta_list = [e for e in azim_list]
|
89 |
+
phi_list = [90 - e for e in elev_list]
|
90 |
+
DIST = dist_list[0]
|
91 |
+
|
92 |
+
xyz_list = [polar_to_xyz(theta, phi, DIST) for theta, phi in zip(theta_list, phi_list)]
|
93 |
+
|
94 |
+
xyz_np = np.array(xyz_list)
|
95 |
+
color_np = np.array([[0, 0, 0]]).repeat(xyz_np.shape[0], 0)
|
96 |
+
|
97 |
+
fig = plt.figure()
|
98 |
+
ax = plt.axes(projection='3d')
|
99 |
+
SCALE = 0.8
|
100 |
+
ax.set_xlim((-DIST, DIST))
|
101 |
+
ax.set_ylim((-DIST, DIST))
|
102 |
+
ax.set_zlim((-SCALE * DIST, SCALE * DIST))
|
103 |
+
|
104 |
+
ax.scatter(xyz_np[:, 0], xyz_np[:, 2], xyz_np[:, 1], s=100, c=color_np, depthshade=True, label="Principle views")
|
105 |
+
ax.scatter([0], [0], [0], c=[[1, 0, 0]], s=100, depthshade=True, label="Object center")
|
106 |
+
|
107 |
+
# draw hemisphere
|
108 |
+
# theta inclination angle
|
109 |
+
# phi azimuthal angle
|
110 |
+
n_theta = 50 # number of values for theta
|
111 |
+
n_phi = 200 # number of values for phi
|
112 |
+
r = DIST #radius of sphere
|
113 |
+
|
114 |
+
# theta, phi = np.mgrid[0.0:0.5*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
|
115 |
+
theta, phi = np.mgrid[0.0:1*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
|
116 |
+
|
117 |
+
x = r*np.sin(theta)*np.cos(phi)
|
118 |
+
y = r*np.sin(theta)*np.sin(phi)
|
119 |
+
z = r*np.cos(theta)
|
120 |
+
|
121 |
+
ax.plot_surface(x, y, z, rstride=1, cstride=1, alpha=0.25, linewidth=1)
|
122 |
+
|
123 |
+
# Make the grid
|
124 |
+
ax.quiver(
|
125 |
+
xyz_np[:, 0],
|
126 |
+
xyz_np[:, 2],
|
127 |
+
xyz_np[:, 1],
|
128 |
+
-xyz_np[:, 0],
|
129 |
+
-xyz_np[:, 2],
|
130 |
+
-xyz_np[:, 1],
|
131 |
+
normalize=True,
|
132 |
+
length=0.3
|
133 |
+
)
|
134 |
+
|
135 |
+
ax.set_xlabel('X Label')
|
136 |
+
ax.set_ylabel('Z Label')
|
137 |
+
ax.set_zlabel('Y Label')
|
138 |
+
|
139 |
+
ax.view_init(30, 35)
|
140 |
+
ax.legend()
|
141 |
+
|
142 |
+
plt.show()
|
143 |
+
|
144 |
+
plt.savefig(os.path.join(output_dir, "principle_viewpoints.png"))
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
def visualize_refinement_viewpoints(output_dir, selected_view_ids, dist_list, elev_list, azim_list):
|
149 |
+
theta_list = [azim_list[i] for i in selected_view_ids]
|
150 |
+
phi_list = [90 - elev_list[i] for i in selected_view_ids]
|
151 |
+
DIST = dist_list[0]
|
152 |
+
|
153 |
+
xyz_list = [polar_to_xyz(theta, phi, DIST) for theta, phi in zip(theta_list, phi_list)]
|
154 |
+
|
155 |
+
xyz_np = np.array(xyz_list)
|
156 |
+
color_np = np.array([[0, 0, 0]]).repeat(xyz_np.shape[0], 0)
|
157 |
+
|
158 |
+
fig = plt.figure()
|
159 |
+
ax = plt.axes(projection='3d')
|
160 |
+
SCALE = 0.8
|
161 |
+
ax.set_xlim((-DIST, DIST))
|
162 |
+
ax.set_ylim((-DIST, DIST))
|
163 |
+
ax.set_zlim((-SCALE * DIST, SCALE * DIST))
|
164 |
+
|
165 |
+
ax.scatter(xyz_np[:, 0], xyz_np[:, 2], xyz_np[:, 1], c=color_np, depthshade=True, label="Refinement views")
|
166 |
+
ax.scatter([0], [0], [0], c=[[1, 0, 0]], s=100, depthshade=True, label="Object center")
|
167 |
+
|
168 |
+
# draw hemisphere
|
169 |
+
# theta inclination angle
|
170 |
+
# phi azimuthal angle
|
171 |
+
n_theta = 50 # number of values for theta
|
172 |
+
n_phi = 200 # number of values for phi
|
173 |
+
r = DIST #radius of sphere
|
174 |
+
|
175 |
+
# theta, phi = np.mgrid[0.0:0.5*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
|
176 |
+
theta, phi = np.mgrid[0.0:1*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
|
177 |
+
|
178 |
+
x = r*np.sin(theta)*np.cos(phi)
|
179 |
+
y = r*np.sin(theta)*np.sin(phi)
|
180 |
+
z = r*np.cos(theta)
|
181 |
+
|
182 |
+
ax.plot_surface(x, y, z, rstride=1, cstride=1, alpha=0.25, linewidth=1)
|
183 |
+
|
184 |
+
# Make the grid
|
185 |
+
ax.quiver(
|
186 |
+
xyz_np[:, 0],
|
187 |
+
xyz_np[:, 2],
|
188 |
+
xyz_np[:, 1],
|
189 |
+
-xyz_np[:, 0],
|
190 |
+
-xyz_np[:, 2],
|
191 |
+
-xyz_np[:, 1],
|
192 |
+
normalize=True,
|
193 |
+
length=0.3
|
194 |
+
)
|
195 |
+
|
196 |
+
ax.set_xlabel('X Label')
|
197 |
+
ax.set_ylabel('Z Label')
|
198 |
+
ax.set_zlabel('Y Label')
|
199 |
+
|
200 |
+
ax.view_init(30, 35)
|
201 |
+
ax.legend()
|
202 |
+
|
203 |
+
plt.show()
|
204 |
+
|
205 |
+
plt.savefig(os.path.join(output_dir, "refinement_viewpoints.png"))
|
206 |
+
|
207 |
+
fig.clear()
|
208 |
+
|
209 |
+
|
text2tex/models/ControlNet/.gitignore
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.idea/
|
2 |
+
|
3 |
+
training/
|
4 |
+
lightning_logs/
|
5 |
+
image_log/
|
6 |
+
|
7 |
+
*.pth
|
8 |
+
*.pt
|
9 |
+
*.ckpt
|
10 |
+
*.safetensors
|
11 |
+
|
12 |
+
gradio_pose2image_private.py
|
13 |
+
gradio_canny2image_private.py
|
14 |
+
|
15 |
+
# Byte-compiled / optimized / DLL files
|
16 |
+
__pycache__/
|
17 |
+
*.py[cod]
|
18 |
+
*$py.class
|
19 |
+
|
20 |
+
# C extensions
|
21 |
+
*.so
|
22 |
+
|
23 |
+
# Distribution / packaging
|
24 |
+
.Python
|
25 |
+
build/
|
26 |
+
develop-eggs/
|
27 |
+
dist/
|
28 |
+
downloads/
|
29 |
+
eggs/
|
30 |
+
.eggs/
|
31 |
+
lib/
|
32 |
+
lib64/
|
33 |
+
parts/
|
34 |
+
sdist/
|
35 |
+
var/
|
36 |
+
wheels/
|
37 |
+
pip-wheel-metadata/
|
38 |
+
share/python-wheels/
|
39 |
+
*.egg-info/
|
40 |
+
.installed.cfg
|
41 |
+
*.egg
|
42 |
+
MANIFEST
|
43 |
+
|
44 |
+
# PyInstaller
|
45 |
+
# Usually these files are written by a python script from a template
|
46 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
47 |
+
*.manifest
|
48 |
+
*.spec
|
49 |
+
|
50 |
+
# Installer logs
|
51 |
+
pip-log.txt
|
52 |
+
pip-delete-this-directory.txt
|
53 |
+
|
54 |
+
# Unit test / coverage reports
|
55 |
+
htmlcov/
|
56 |
+
.tox/
|
57 |
+
.nox/
|
58 |
+
.coverage
|
59 |
+
.coverage.*
|
60 |
+
.cache
|
61 |
+
nosetests.xml
|
62 |
+
coverage.xml
|
63 |
+
*.cover
|
64 |
+
*.py,cover
|
65 |
+
.hypothesis/
|
66 |
+
.pytest_cache/
|
67 |
+
|
68 |
+
# Translations
|
69 |
+
*.mo
|
70 |
+
*.pot
|
71 |
+
|
72 |
+
# Django stuff:
|
73 |
+
*.log
|
74 |
+
local_settings.py
|
75 |
+
db.sqlite3
|
76 |
+
db.sqlite3-journal
|
77 |
+
|
78 |
+
# Flask stuff:
|
79 |
+
instance/
|
80 |
+
.webassets-cache
|
81 |
+
|
82 |
+
# Scrapy stuff:
|
83 |
+
.scrapy
|
84 |
+
|
85 |
+
# Sphinx documentation
|
86 |
+
docs/_build/
|
87 |
+
|
88 |
+
# PyBuilder
|
89 |
+
target/
|
90 |
+
|
91 |
+
# Jupyter Notebook
|
92 |
+
.ipynb_checkpoints
|
93 |
+
|
94 |
+
# IPython
|
95 |
+
profile_default/
|
96 |
+
ipython_config.py
|
97 |
+
|
98 |
+
# pyenv
|
99 |
+
.python-version
|
100 |
+
|
101 |
+
# pipenv
|
102 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
103 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
104 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
105 |
+
# install all needed dependencies.
|
106 |
+
#Pipfile.lock
|
107 |
+
|
108 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
109 |
+
__pypackages__/
|
110 |
+
|
111 |
+
# Celery stuff
|
112 |
+
celerybeat-schedule
|
113 |
+
celerybeat.pid
|
114 |
+
|
115 |
+
# SageMath parsed files
|
116 |
+
*.sage.py
|
117 |
+
|
118 |
+
# Environments
|
119 |
+
.env
|
120 |
+
.venv
|
121 |
+
env/
|
122 |
+
venv/
|
123 |
+
ENV/
|
124 |
+
env.bak/
|
125 |
+
venv.bak/
|
126 |
+
|
127 |
+
# Spyder project settings
|
128 |
+
.spyderproject
|
129 |
+
.spyproject
|
130 |
+
|
131 |
+
# Rope project settings
|
132 |
+
.ropeproject
|
133 |
+
|
134 |
+
# mkdocs documentation
|
135 |
+
/site
|
136 |
+
|
137 |
+
# mypy
|
138 |
+
.mypy_cache/
|
139 |
+
.dmypy.json
|
140 |
+
dmypy.json
|
141 |
+
|
142 |
+
# Pyre type checker
|
143 |
+
.pyre/
|
text2tex/models/ControlNet/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
text2tex/models/ControlNet/README.md
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ControlNet
|
2 |
+
|
3 |
+
Official implementation of [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543).
|
4 |
+
|
5 |
+
ControlNet is a neural network structure to control diffusion models by adding extra conditions.
|
6 |
+
|
7 |
+

|
8 |
+
|
9 |
+
It copys the weights of neural network blocks into a "locked" copy and a "trainable" copy.
|
10 |
+
|
11 |
+
The "trainable" one learns your condition. The "locked" one preserves your model.
|
12 |
+
|
13 |
+
Thanks to this, training with small dataset of image pairs will not destroy the production-ready diffusion models.
|
14 |
+
|
15 |
+
The "zero convolution" is 1×1 convolution with both weight and bias initialized as zeros.
|
16 |
+
|
17 |
+
Before training, all zero convolutions output zeros, and ControlNet will not cause any distortion.
|
18 |
+
|
19 |
+
No layer is trained from scratch. You are still fine-tuning. Your original model is safe.
|
20 |
+
|
21 |
+
This allows training on small-scale or even personal devices.
|
22 |
+
|
23 |
+
This is also friendly to merge/replacement/offsetting of models/weights/blocks/layers.
|
24 |
+
|
25 |
+
### FAQ
|
26 |
+
|
27 |
+
**Q:** But wait, if the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why "zero convolution" works?
|
28 |
+
|
29 |
+
**A:** This is not true. [See an explanation here](docs/faq.md).
|
30 |
+
|
31 |
+
# Stable Diffusion + ControlNet
|
32 |
+
|
33 |
+
By repeating the above simple structure 14 times, we can control stable diffusion in this way:
|
34 |
+
|
35 |
+

|
36 |
+
|
37 |
+
Note that the way we connect layers is computational efficient. The original SD encoder does not need to store gradients (the locked original SD Encoder Block 1234 and Middle). The required GPU memory is not much larger than original SD, although many layers are added. Great!
|
38 |
+
|
39 |
+
# Production-Ready Pretrained Models
|
40 |
+
|
41 |
+
First create a new conda environment
|
42 |
+
|
43 |
+
conda env create -f environment.yaml
|
44 |
+
conda activate control
|
45 |
+
|
46 |
+
All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Make sure that SD models are put in "ControlNet/models" and detectors are put in "ControlNet/annotator/ckpts". Make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including HED edge detection model, Midas depth estimation model, Openpose, and so on.
|
47 |
+
|
48 |
+
We provide 9 Gradio apps with these models.
|
49 |
+
|
50 |
+
All test images can be found at the folder "test_imgs".
|
51 |
+
|
52 |
+
### News
|
53 |
+
|
54 |
+
2023/02/12 - Now you can play with any community model by [Transferring the ControlNet](https://github.com/lllyasviel/ControlNet/discussions/12).
|
55 |
+
|
56 |
+
2023/02/11 - [Low VRAM mode](docs/low_vram.md) is added. Please use this mode if you are using 8GB GPU(s) or if you want larger batch size.
|
57 |
+
|
58 |
+
## ControlNet with Canny Edge
|
59 |
+
|
60 |
+
Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection)
|
61 |
+
|
62 |
+
python gradio_canny2image.py
|
63 |
+
|
64 |
+
The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details.
|
65 |
+
|
66 |
+
Prompt: "bird"
|
67 |
+

|
68 |
+
|
69 |
+
Prompt: "cute dog"
|
70 |
+

|
71 |
+
|
72 |
+
## ControlNet with M-LSD Lines
|
73 |
+
|
74 |
+
Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection)
|
75 |
+
|
76 |
+
python gradio_hough2image.py
|
77 |
+
|
78 |
+
The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details.
|
79 |
+
|
80 |
+
Prompt: "room"
|
81 |
+

|
82 |
+
|
83 |
+
Prompt: "building"
|
84 |
+

|
85 |
+
|
86 |
+
## ControlNet with HED Boundary
|
87 |
+
|
88 |
+
Stable Diffusion 1.5 + ControlNet (using soft HED Boundary)
|
89 |
+
|
90 |
+
python gradio_hed2image.py
|
91 |
+
|
92 |
+
The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details.
|
93 |
+
|
94 |
+
Prompt: "oil painting of handsome old man, masterpiece"
|
95 |
+

|
96 |
+
|
97 |
+
Prompt: "Cyberpunk robot"
|
98 |
+

|
99 |
+
|
100 |
+
## ControlNet with User Scribbles
|
101 |
+
|
102 |
+
Stable Diffusion 1.5 + ControlNet (using Scribbles)
|
103 |
+
|
104 |
+
python gradio_scribble2image.py
|
105 |
+
|
106 |
+
Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio.
|
107 |
+
|
108 |
+
Prompt: "turtle"
|
109 |
+

|
110 |
+
|
111 |
+
Prompt: "hot air balloon"
|
112 |
+

|
113 |
+
|
114 |
+
### Interactive Interface
|
115 |
+
|
116 |
+
We actually provide an interactive interface
|
117 |
+
|
118 |
+
python gradio_scribble2image_interactive.py
|
119 |
+
|
120 |
+
However, because gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now, user need to first set canvas width and heights and then click "Open drawing canvas" to get a drawing area. Please do not upload image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, gradio is really buggy.
|
121 |
+
|
122 |
+
The below dog sketch is drawn by me. Perhaps we should draw a better dog for showcase.
|
123 |
+
|
124 |
+
Prompt: "dog in a room"
|
125 |
+

|
126 |
+
|
127 |
+
## ControlNet with Fake Scribbles
|
128 |
+
|
129 |
+
Stable Diffusion 1.5 + ControlNet (using fake scribbles)
|
130 |
+
|
131 |
+
python gradio_fake_scribble2image.py
|
132 |
+
|
133 |
+
Sometimes we are lazy, and we do not want to draw scribbles. This script use the exactly same scribble-based model but use a simple algorithm to synthesize scribbles from input images.
|
134 |
+
|
135 |
+
Prompt: "bag"
|
136 |
+

|
137 |
+
|
138 |
+
Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still seems to work.)
|
139 |
+

|
140 |
+
|
141 |
+
## ControlNet with Human Pose
|
142 |
+
|
143 |
+
Stable Diffusion 1.5 + ControlNet (using human pose)
|
144 |
+
|
145 |
+
python gradio_pose2image.py
|
146 |
+
|
147 |
+
Apparently, this model deserves a better UI to directly manipulate pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then the Openpose will detect the pose for you.
|
148 |
+
|
149 |
+
Prompt: "Chief in the kitchen"
|
150 |
+

|
151 |
+
|
152 |
+
Prompt: "An astronaut on the moon"
|
153 |
+

|
154 |
+
|
155 |
+
## ControlNet with Semantic Segmentation
|
156 |
+
|
157 |
+
Stable Diffusion 1.5 + ControlNet (using semantic segmentation)
|
158 |
+
|
159 |
+
python gradio_seg2image.py
|
160 |
+
|
161 |
+
This model use ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then a model called Uniformer will detect the segmentations for you. Just try it for more details.
|
162 |
+
|
163 |
+
Prompt: "House"
|
164 |
+

|
165 |
+
|
166 |
+
Prompt: "River"
|
167 |
+

|
168 |
+
|
169 |
+
## ControlNet with Depth
|
170 |
+
|
171 |
+
Stable Diffusion 1.5 + ControlNet (using depth map)
|
172 |
+
|
173 |
+
python gradio_depth2image.py
|
174 |
+
|
175 |
+
Great! Now SD 1.5 also have a depth control. FINALLY. So many possibilities (considering SD1.5 has much more community models than SD2).
|
176 |
+
|
177 |
+
Note that different from Stability's model, the ControlNet receive the full 512×512 depth map, rather than 64×64 depth. Note that Stability's SD2 depth model use 64*64 depth maps. This means that the ControlNet will preserve more details in the depth map.
|
178 |
+
|
179 |
+
This is always a strength because if users do not want to preserve more details, they can simply use another SD to post-process an i2i. But if they want to preserve more details, ControlNet becomes their only choice. Again, SD2 uses 64×64 depth, we use 512×512.
|
180 |
+
|
181 |
+
Prompt: "Stormtrooper's lecture"
|
182 |
+

|
183 |
+
|
184 |
+
## ControlNet with Normal Map
|
185 |
+
|
186 |
+
Stable Diffusion 1.5 + ControlNet (using normal map)
|
187 |
+
|
188 |
+
python gradio_normal2image.py
|
189 |
+
|
190 |
+
This model use normal map. Rightnow in the APP, the normal is computed from the midas depth map and a user threshold (to determine how many area is background with identity normal face to viewer, tune the "Normal background threshold" in the gradio app to get a feeling).
|
191 |
+
|
192 |
+
Prompt: "Cute toy"
|
193 |
+

|
194 |
+
|
195 |
+
Prompt: "Plaster statue of Abraham Lincoln"
|
196 |
+

|
197 |
+
|
198 |
+
Compared to depth model, this model seems to be a bit better at preserving the geometry. This is intuitive: minor details are not salient in depth maps, but are salient in normal maps. Below is the depth result with same inputs. You can see that the hairstyle of the man in the input image is modified by depth model, but preserved by the normal model.
|
199 |
+
|
200 |
+
Prompt: "Plaster statue of Abraham Lincoln"
|
201 |
+

|
202 |
+
|
203 |
+
## ControlNet with Anime Line Drawing
|
204 |
+
|
205 |
+
We also trained a relatively simple ControlNet for anime line drawings. This tool may be useful for artistic creations. (Although the image details in the results is a bit modified, since it still diffuse latent images.)
|
206 |
+
|
207 |
+
This model is not available right now. We need to evaluate the potential risks before releasing this model. Nevertheless, you may be interested in [transferring the ControlNet to any community model](https://github.com/lllyasviel/ControlNet/discussions/12).
|
208 |
+
|
209 |
+

|
210 |
+
|
211 |
+
# Annotate Your Own Data
|
212 |
+
|
213 |
+
We provide simple python scripts to process images.
|
214 |
+
|
215 |
+
[See a gradio example here](docs/annotator.md).
|
216 |
+
|
217 |
+
# Train with Your Own Data
|
218 |
+
|
219 |
+
Training a ControlNet is as easy as (or even easier than) training a simple pix2pix.
|
220 |
+
|
221 |
+
[See the steps here](docs/train.md).
|
222 |
+
|
223 |
+
# Citation
|
224 |
+
|
225 |
+
@misc{zhang2023adding,
|
226 |
+
title={Adding Conditional Control to Text-to-Image Diffusion Models},
|
227 |
+
author={Lvmin Zhang and Maneesh Agrawala},
|
228 |
+
year={2023},
|
229 |
+
eprint={2302.05543},
|
230 |
+
archivePrefix={arXiv},
|
231 |
+
primaryClass={cs.CV}
|
232 |
+
}
|
233 |
+
|
234 |
+
[Arxiv Link](https://arxiv.org/abs/2302.05543)
|
text2tex/models/ControlNet/annotator/canny/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
|
4 |
+
def apply_canny(img, low_threshold, high_threshold):
|
5 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
text2tex/models/ControlNet/annotator/ckpts/ckpts.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Weights here.
|
text2tex/models/ControlNet/annotator/hed/__init__.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
class Network(torch.nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.netVggOne = torch.nn.Sequential(
|
12 |
+
torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
|
13 |
+
torch.nn.ReLU(inplace=False),
|
14 |
+
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
|
15 |
+
torch.nn.ReLU(inplace=False)
|
16 |
+
)
|
17 |
+
|
18 |
+
self.netVggTwo = torch.nn.Sequential(
|
19 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
20 |
+
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
|
21 |
+
torch.nn.ReLU(inplace=False),
|
22 |
+
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
23 |
+
torch.nn.ReLU(inplace=False)
|
24 |
+
)
|
25 |
+
|
26 |
+
self.netVggThr = torch.nn.Sequential(
|
27 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
28 |
+
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
|
29 |
+
torch.nn.ReLU(inplace=False),
|
30 |
+
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
31 |
+
torch.nn.ReLU(inplace=False),
|
32 |
+
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
33 |
+
torch.nn.ReLU(inplace=False)
|
34 |
+
)
|
35 |
+
|
36 |
+
self.netVggFou = torch.nn.Sequential(
|
37 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
38 |
+
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
|
39 |
+
torch.nn.ReLU(inplace=False),
|
40 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
41 |
+
torch.nn.ReLU(inplace=False),
|
42 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
43 |
+
torch.nn.ReLU(inplace=False)
|
44 |
+
)
|
45 |
+
|
46 |
+
self.netVggFiv = torch.nn.Sequential(
|
47 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
48 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
49 |
+
torch.nn.ReLU(inplace=False),
|
50 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
51 |
+
torch.nn.ReLU(inplace=False),
|
52 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
53 |
+
torch.nn.ReLU(inplace=False)
|
54 |
+
)
|
55 |
+
|
56 |
+
self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
|
57 |
+
self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
|
58 |
+
self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
|
59 |
+
self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
|
60 |
+
self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
|
61 |
+
|
62 |
+
self.netCombine = torch.nn.Sequential(
|
63 |
+
torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
|
64 |
+
torch.nn.Sigmoid()
|
65 |
+
)
|
66 |
+
|
67 |
+
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load('./annotator/ckpts/network-bsds500.pth').items()})
|
68 |
+
# end
|
69 |
+
|
70 |
+
def forward(self, tenInput):
|
71 |
+
tenInput = tenInput * 255.0
|
72 |
+
tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
|
73 |
+
|
74 |
+
tenVggOne = self.netVggOne(tenInput)
|
75 |
+
tenVggTwo = self.netVggTwo(tenVggOne)
|
76 |
+
tenVggThr = self.netVggThr(tenVggTwo)
|
77 |
+
tenVggFou = self.netVggFou(tenVggThr)
|
78 |
+
tenVggFiv = self.netVggFiv(tenVggFou)
|
79 |
+
|
80 |
+
tenScoreOne = self.netScoreOne(tenVggOne)
|
81 |
+
tenScoreTwo = self.netScoreTwo(tenVggTwo)
|
82 |
+
tenScoreThr = self.netScoreThr(tenVggThr)
|
83 |
+
tenScoreFou = self.netScoreFou(tenVggFou)
|
84 |
+
tenScoreFiv = self.netScoreFiv(tenVggFiv)
|
85 |
+
|
86 |
+
tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
87 |
+
tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
88 |
+
tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
89 |
+
tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
90 |
+
tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
91 |
+
|
92 |
+
return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
|
93 |
+
# end
|
94 |
+
# end
|
95 |
+
|
96 |
+
|
97 |
+
netNetwork = Network().cuda().eval()
|
98 |
+
|
99 |
+
|
100 |
+
def apply_hed(input_image):
|
101 |
+
assert input_image.ndim == 3
|
102 |
+
input_image = input_image[:, :, ::-1].copy()
|
103 |
+
with torch.no_grad():
|
104 |
+
image_hed = torch.from_numpy(input_image).float().cuda()
|
105 |
+
image_hed = image_hed / 255.0
|
106 |
+
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
107 |
+
edge = netNetwork(image_hed)[0]
|
108 |
+
edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
|
109 |
+
return edge[0]
|
110 |
+
|
111 |
+
|
112 |
+
def nms(x, t, s):
|
113 |
+
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
114 |
+
|
115 |
+
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
116 |
+
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
117 |
+
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
118 |
+
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
119 |
+
|
120 |
+
y = np.zeros_like(x)
|
121 |
+
|
122 |
+
for f in [f1, f2, f3, f4]:
|
123 |
+
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
124 |
+
|
125 |
+
z = np.zeros_like(y, dtype=np.uint8)
|
126 |
+
z[y > t] = 255
|
127 |
+
return z
|
text2tex/models/ControlNet/annotator/midas/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from einops import rearrange
|
6 |
+
from .api import MiDaSInference
|
7 |
+
|
8 |
+
model = MiDaSInference(model_type="dpt_hybrid").cuda()
|
9 |
+
|
10 |
+
|
11 |
+
def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1):
|
12 |
+
assert input_image.ndim == 3
|
13 |
+
image_depth = input_image
|
14 |
+
with torch.no_grad():
|
15 |
+
image_depth = torch.from_numpy(image_depth).float().cuda()
|
16 |
+
image_depth = image_depth / 127.5 - 1.0
|
17 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
18 |
+
depth = model(image_depth)[0]
|
19 |
+
|
20 |
+
depth_pt = depth.clone()
|
21 |
+
depth_pt -= torch.min(depth_pt)
|
22 |
+
depth_pt /= torch.max(depth_pt)
|
23 |
+
depth_pt = depth_pt.cpu().numpy()
|
24 |
+
depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
|
25 |
+
|
26 |
+
depth_np = depth.cpu().numpy()
|
27 |
+
x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
|
28 |
+
y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
|
29 |
+
z = np.ones_like(x) * a
|
30 |
+
x[depth_pt < bg_th] = 0
|
31 |
+
y[depth_pt < bg_th] = 0
|
32 |
+
normal = np.stack([x, y, z], axis=2)
|
33 |
+
normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
|
34 |
+
normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
35 |
+
|
36 |
+
return depth_image, normal_image
|
text2tex/models/ControlNet/annotator/midas/api.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/isl-org/MiDaS
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchvision.transforms import Compose
|
7 |
+
|
8 |
+
from .midas.dpt_depth import DPTDepthModel
|
9 |
+
from .midas.midas_net import MidasNet
|
10 |
+
from .midas.midas_net_custom import MidasNet_small
|
11 |
+
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
|
12 |
+
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
|
16 |
+
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../..")
|
17 |
+
|
18 |
+
ISL_PATHS = {
|
19 |
+
"dpt_large": BASE_DIR+"/annotator/ckpts/dpt_large-midas-2f21e586.pt",
|
20 |
+
"dpt_hybrid": BASE_DIR+"/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt",
|
21 |
+
"midas_v21": "",
|
22 |
+
"midas_v21_small": "",
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def disabled_train(self, mode=True):
|
27 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
28 |
+
does not change anymore."""
|
29 |
+
return self
|
30 |
+
|
31 |
+
|
32 |
+
def load_midas_transform(model_type):
|
33 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
34 |
+
# load transform only
|
35 |
+
if model_type == "dpt_large": # DPT-Large
|
36 |
+
net_w, net_h = 384, 384
|
37 |
+
resize_mode = "minimal"
|
38 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
39 |
+
|
40 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
41 |
+
net_w, net_h = 384, 384
|
42 |
+
resize_mode = "minimal"
|
43 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
44 |
+
|
45 |
+
elif model_type == "midas_v21":
|
46 |
+
net_w, net_h = 384, 384
|
47 |
+
resize_mode = "upper_bound"
|
48 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
49 |
+
|
50 |
+
elif model_type == "midas_v21_small":
|
51 |
+
net_w, net_h = 256, 256
|
52 |
+
resize_mode = "upper_bound"
|
53 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
54 |
+
|
55 |
+
else:
|
56 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
57 |
+
|
58 |
+
transform = Compose(
|
59 |
+
[
|
60 |
+
Resize(
|
61 |
+
net_w,
|
62 |
+
net_h,
|
63 |
+
resize_target=None,
|
64 |
+
keep_aspect_ratio=True,
|
65 |
+
ensure_multiple_of=32,
|
66 |
+
resize_method=resize_mode,
|
67 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
68 |
+
),
|
69 |
+
normalization,
|
70 |
+
PrepareForNet(),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
|
74 |
+
return transform
|
75 |
+
|
76 |
+
|
77 |
+
def load_model(model_type):
|
78 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
79 |
+
# load network
|
80 |
+
model_path = ISL_PATHS[model_type]
|
81 |
+
if model_type == "dpt_large": # DPT-Large
|
82 |
+
model = DPTDepthModel(
|
83 |
+
path=model_path,
|
84 |
+
backbone="vitl16_384",
|
85 |
+
non_negative=True,
|
86 |
+
)
|
87 |
+
net_w, net_h = 384, 384
|
88 |
+
resize_mode = "minimal"
|
89 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
90 |
+
|
91 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
92 |
+
model = DPTDepthModel(
|
93 |
+
path=model_path,
|
94 |
+
backbone="vitb_rn50_384",
|
95 |
+
non_negative=True,
|
96 |
+
)
|
97 |
+
net_w, net_h = 384, 384
|
98 |
+
resize_mode = "minimal"
|
99 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
100 |
+
|
101 |
+
elif model_type == "midas_v21":
|
102 |
+
model = MidasNet(model_path, non_negative=True)
|
103 |
+
net_w, net_h = 384, 384
|
104 |
+
resize_mode = "upper_bound"
|
105 |
+
normalization = NormalizeImage(
|
106 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
107 |
+
)
|
108 |
+
|
109 |
+
elif model_type == "midas_v21_small":
|
110 |
+
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
111 |
+
non_negative=True, blocks={'expand': True})
|
112 |
+
net_w, net_h = 256, 256
|
113 |
+
resize_mode = "upper_bound"
|
114 |
+
normalization = NormalizeImage(
|
115 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
116 |
+
)
|
117 |
+
|
118 |
+
else:
|
119 |
+
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
120 |
+
assert False
|
121 |
+
|
122 |
+
transform = Compose(
|
123 |
+
[
|
124 |
+
Resize(
|
125 |
+
net_w,
|
126 |
+
net_h,
|
127 |
+
resize_target=None,
|
128 |
+
keep_aspect_ratio=True,
|
129 |
+
ensure_multiple_of=32,
|
130 |
+
resize_method=resize_mode,
|
131 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
132 |
+
),
|
133 |
+
normalization,
|
134 |
+
PrepareForNet(),
|
135 |
+
]
|
136 |
+
)
|
137 |
+
|
138 |
+
return model.eval(), transform
|
139 |
+
|
140 |
+
|
141 |
+
class MiDaSInference(nn.Module):
|
142 |
+
MODEL_TYPES_TORCH_HUB = [
|
143 |
+
"DPT_Large",
|
144 |
+
"DPT_Hybrid",
|
145 |
+
"MiDaS_small"
|
146 |
+
]
|
147 |
+
MODEL_TYPES_ISL = [
|
148 |
+
"dpt_large",
|
149 |
+
"dpt_hybrid",
|
150 |
+
"midas_v21",
|
151 |
+
"midas_v21_small",
|
152 |
+
]
|
153 |
+
|
154 |
+
def __init__(self, model_type):
|
155 |
+
super().__init__()
|
156 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
157 |
+
model, _ = load_model(model_type)
|
158 |
+
self.model = model
|
159 |
+
self.model.train = disabled_train
|
160 |
+
|
161 |
+
def forward(self, x):
|
162 |
+
with torch.no_grad():
|
163 |
+
prediction = self.model(x)
|
164 |
+
return prediction
|
165 |
+
|
text2tex/models/ControlNet/annotator/midas/midas/__init__.py
ADDED
File without changes
|
text2tex/models/ControlNet/annotator/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class BaseModel(torch.nn.Module):
|
5 |
+
def load(self, path):
|
6 |
+
"""Load model from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
path (str): file path
|
10 |
+
"""
|
11 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
12 |
+
|
13 |
+
if "optimizer" in parameters:
|
14 |
+
parameters = parameters["model"]
|
15 |
+
|
16 |
+
self.load_state_dict(parameters)
|
text2tex/models/ControlNet/annotator/midas/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
342 |
+
|
text2tex/models/ControlNet/annotator/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
109 |
+
|
text2tex/models/ControlNet/annotator/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
text2tex/models/ControlNet/annotator/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|