jingyangcarl commited on
Commit
05d00b7
·
1 Parent(s): ac36933
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. DockerFile → Dockerfile +2 -0
  2. README.md +4 -22
  3. app.py +10 -74
  4. app_3d.py +0 -21
  5. app_canny.py +0 -83
  6. app_matnet.py +0 -83
  7. app_sd.py +0 -154
  8. app_texnet.py +0 -259
  9. cv_utils.py +0 -17
  10. depth_estimator.py +0 -25
  11. examples/bunny/frame_0001.png +0 -3
  12. examples/bunny/mesh.obj +0 -0
  13. examples/bunny/uv_normal.png +0 -3
  14. examples/fighter/frame_0001.png +0 -3
  15. examples/fighter/mesh.obj +0 -0
  16. examples/fighter/uv_normal.png +0 -3
  17. examples/highheel/frame_0001.png +0 -3
  18. examples/highheel/mesh.obj +0 -0
  19. examples/highheel/uv_normal.png +0 -3
  20. examples/monkey/frame_0001.png +0 -3
  21. examples/monkey/mesh.obj +0 -0
  22. examples/monkey/uv_normal.png +0 -3
  23. examples/tank/frame_0001.png +0 -3
  24. examples/tank/mesh.obj +0 -3
  25. examples/tank/uv_normal.png +0 -3
  26. examples/tshirt/frame_0001.png +0 -3
  27. examples/tshirt/mesh.obj +0 -3
  28. examples/tshirt/uv_normal.png +0 -3
  29. image_segmentor.py +0 -33
  30. install.sh +0 -18
  31. model.py +0 -959
  32. pre-requirements.txt +0 -9
  33. preprocessor.py +0 -120
  34. push_dataset.py +0 -9
  35. requirements.txt +0 -9
  36. rgb2x/generate_blend.py +0 -142
  37. rgb2x/gradio_demo_rgb2x.py +0 -157
  38. rgb2x/load_image.py +0 -119
  39. rgb2x/pipeline_rgb2x.py +0 -821
  40. run.sh +5 -0
  41. settings.py +0 -23
  42. text2tex/lib/__init__.py +0 -0
  43. text2tex/lib/camera_helper.py +0 -231
  44. text2tex/lib/constants.py +0 -648
  45. text2tex/lib/diffusion_helper.py +0 -189
  46. text2tex/lib/io_helper.py +0 -78
  47. text2tex/lib/mesh_helper.py +0 -148
  48. text2tex/lib/projection_helper.py +0 -464
  49. text2tex/lib/render_helper.py +0 -108
  50. text2tex/lib/shading_helper.py +0 -45
DockerFile → Dockerfile RENAMED
@@ -8,8 +8,10 @@ RUN conda env create -f /code/environment.yml
8
 
9
  # Set up a new user named "user" with user ID 1000
10
  RUN useradd -m -u 1000 user
 
11
  # Switch to the "user" user
12
  USER user
 
13
  # Set home to the user's home directory
14
  ENV HOME=/home/user \
15
  PYTHONPATH=$HOME/app \
 
8
 
9
  # Set up a new user named "user" with user ID 1000
10
  RUN useradd -m -u 1000 user
11
+
12
  # Switch to the "user" user
13
  USER user
14
+
15
  # Set home to the user's home directory
16
  ENV HOME=/home/user \
17
  PYTHONPATH=$HOME/app \
README.md CHANGED
@@ -1,28 +1,10 @@
1
  ---
2
- title: Matgen
3
- emoji: 🖼
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: docker
7
  pinned: false
8
- license: mit
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
-
13
-
14
- ## setup locally
15
- conda create -n matgen python=3.11
16
- conda activate matgen
17
- pip install diffusers["torch"] transformers accelerate xformers
18
- pip install gradio
19
- pip install controlnet-aux
20
-
21
- ## local authen
22
- huggingface-cli login
23
-
24
- ## on using Huggingface ZeroGPU
25
- need to import spaces and the corresponding decorator
26
- https://huggingface.co/docs/hub/spaces-zerogpu
27
-
28
- also, check the usage of controlnet over zerogpu here: https://huggingface.co/spaces/radames/Enhance-This-HiDiffusion-SDXL/blob/main/app.py
 
1
  ---
2
+ title: Gradio Conda Template
3
+ emoji: 🐨
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,80 +1,16 @@
1
- #!/usr/bin/env python
2
-
3
  import gradio as gr
4
- import torch
5
-
6
- import sys
7
- pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
8
- version_str="".join([
9
- f"py3{sys.version_info.minor}_cu",
10
- torch.version.cuda.replace(".",""),
11
- f"_pyt{pyt_version_str}"
12
- ])
13
- print(f"Using version: {version_str}") # used to locate pytorch3d version in the requirements.txt for huggingface
14
-
15
-
16
- from app_canny import create_demo as create_demo_canny
17
- from app_texnet import create_demo as create_demo_texnet
18
-
19
- from model import Model
20
- from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
21
 
22
- DESCRIPTION = "# Material Authoring Demo v0.3"
23
 
24
- if not torch.cuda.is_available():
25
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p> Check if the 'CUDA_VISIBLE_DEVICES' are set incorrectly in settings.py"
26
 
27
- # model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")
28
- model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="texnet")
29
 
30
  with gr.Blocks() as demo:
31
- gr.Markdown(DESCRIPTION)
32
- gr.DuplicateButton(
33
- value="Duplicate Space for private use",
34
- elem_id="duplicate-button",
35
- visible=SHOW_DUPLICATE_BUTTON,
36
- )
37
-
38
- with gr.Tabs():
39
- with gr.Tab("Texnet+Matnet"):
40
- create_demo_texnet(model.process_texnet)
41
-
42
- with gr.Accordion(label="Base model", open=False):
43
- with gr.Row():
44
- with gr.Column(scale=5):
45
- current_base_model = gr.Text(label="Current base model")
46
- with gr.Column(scale=1):
47
- check_base_model_button = gr.Button("Check current base model")
48
- with gr.Row():
49
- with gr.Column(scale=5):
50
- new_base_model_id = gr.Text(
51
- label="New base model",
52
- max_lines=1,
53
- placeholder="stable-diffusion-v1-5/stable-diffusion-v1-5",
54
- info="The base model must be compatible with Stable Diffusion v1.5.",
55
- interactive=ALLOW_CHANGING_BASE_MODEL,
56
- )
57
- with gr.Column(scale=1):
58
- change_base_model_button = gr.Button("Change base model", interactive=ALLOW_CHANGING_BASE_MODEL)
59
- if not ALLOW_CHANGING_BASE_MODEL:
60
- gr.Markdown(
61
- """The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space."""
62
- )
63
-
64
- check_base_model_button.click(
65
- fn=lambda: model.base_model_id,
66
- outputs=current_base_model,
67
- queue=False,
68
- api_name="check_base_model",
69
- )
70
- gr.on(
71
- triggers=[new_base_model_id.submit, change_base_model_button.click],
72
- fn=model.set_base_model,
73
- inputs=new_base_model_id,
74
- outputs=current_base_model,
75
- api_name=False,
76
- concurrency_id="main",
77
- )
78
-
79
- if __name__ == "__main__":
80
- demo.queue(max_size=20).launch()
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
3
 
4
+ def update(name):
5
+ return f"Welcome to Gradio, {name}!"
6
 
 
 
7
 
8
  with gr.Blocks() as demo:
9
+ gr.Markdown("Start typing below and then click **Run** to see the output.")
10
+ with gr.Row():
11
+ inp = gr.Textbox(placeholder="What is your name?")
12
+ out = gr.Textbox()
13
+ btn = gr.Button("Run")
14
+ btn.click(fn=update, inputs=inp, outputs=out)
15
+
16
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_3d.py DELETED
@@ -1,21 +0,0 @@
1
- import gradio as gr
2
- import os
3
-
4
- def load_mesh(mesh_file_name):
5
- return mesh_file_name
6
-
7
- demo = gr.Interface(
8
- fn=load_mesh,
9
- inputs=gr.Model3D(),
10
- outputs=gr.Model3D(
11
- clear_color=(255.0, 0.0, 0.0, 0.0), label="3D Model", display_mode="wireframe"),
12
- examples=[
13
- [os.path.join(os.path.dirname(__file__), "examples/bunny/mesh.obj")],
14
- [os.path.join(os.path.dirname(__file__), "examples/monkey/mesh.obj")],
15
- [os.path.join(os.path.dirname(__file__), "examples/Bunny.obj")],
16
- ],
17
- cache_examples=True
18
- )
19
-
20
- if __name__ == "__main__":
21
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_canny.py DELETED
@@ -1,83 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
-
5
- from settings import (
6
- DEFAULT_IMAGE_RESOLUTION,
7
- DEFAULT_NUM_IMAGES,
8
- MAX_IMAGE_RESOLUTION,
9
- MAX_NUM_IMAGES,
10
- MAX_SEED,
11
- )
12
- from utils import randomize_seed_fn
13
-
14
-
15
- def create_demo(process):
16
- with gr.Blocks() as demo:
17
- with gr.Row():
18
- with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt", submit_btn=True)
21
- with gr.Accordion("Advanced options", open=False):
22
- num_samples = gr.Slider(
23
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
24
- )
25
- image_resolution = gr.Slider(
26
- label="Image resolution",
27
- minimum=256,
28
- maximum=MAX_IMAGE_RESOLUTION,
29
- value=DEFAULT_IMAGE_RESOLUTION,
30
- step=256,
31
- )
32
- canny_low_threshold = gr.Slider(
33
- label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
34
- )
35
- canny_high_threshold = gr.Slider(
36
- label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
37
- )
38
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
39
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
40
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
41
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
42
- a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
43
- n_prompt = gr.Textbox(
44
- label="Negative prompt",
45
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
46
- )
47
- with gr.Column():
48
- result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
49
- inputs = [
50
- image,
51
- prompt,
52
- a_prompt,
53
- n_prompt,
54
- num_samples,
55
- image_resolution,
56
- num_steps,
57
- guidance_scale,
58
- seed,
59
- canny_low_threshold,
60
- canny_high_threshold,
61
- ]
62
- prompt.submit(
63
- fn=randomize_seed_fn,
64
- inputs=[seed, randomize_seed],
65
- outputs=seed,
66
- queue=False,
67
- api_name=False,
68
- ).then(
69
- fn=process,
70
- inputs=inputs,
71
- outputs=result,
72
- api_name="canny",
73
- concurrency_id="main",
74
- )
75
- return demo
76
-
77
-
78
- if __name__ == "__main__":
79
- from model import Model
80
-
81
- model = Model(task_name="Canny")
82
- demo = create_demo(model.process_canny)
83
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_matnet.py DELETED
@@ -1,83 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
-
5
- from settings import (
6
- DEFAULT_IMAGE_RESOLUTION,
7
- DEFAULT_NUM_IMAGES,
8
- MAX_IMAGE_RESOLUTION,
9
- MAX_NUM_IMAGES,
10
- MAX_SEED,
11
- )
12
- from utils import randomize_seed_fn
13
-
14
-
15
- def create_demo(process):
16
- with gr.Blocks() as demo:
17
- with gr.Row():
18
- with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt", submit_btn=True)
21
- with gr.Accordion("Advanced options", open=False):
22
- num_samples = gr.Slider(
23
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
24
- )
25
- image_resolution = gr.Slider(
26
- label="Image resolution",
27
- minimum=256,
28
- maximum=MAX_IMAGE_RESOLUTION,
29
- value=DEFAULT_IMAGE_RESOLUTION,
30
- step=256,
31
- )
32
- canny_low_threshold = gr.Slider(
33
- label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
34
- )
35
- canny_high_threshold = gr.Slider(
36
- label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
37
- )
38
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
39
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
40
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
41
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
42
- a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
43
- n_prompt = gr.Textbox(
44
- label="Negative prompt",
45
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
46
- )
47
- with gr.Column():
48
- result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
49
- inputs = [
50
- image,
51
- prompt,
52
- a_prompt,
53
- n_prompt,
54
- num_samples,
55
- image_resolution,
56
- num_steps,
57
- guidance_scale,
58
- seed,
59
- canny_low_threshold,
60
- canny_high_threshold,
61
- ]
62
- prompt.submit(
63
- fn=randomize_seed_fn,
64
- inputs=[seed, randomize_seed],
65
- outputs=seed,
66
- queue=False,
67
- api_name=False,
68
- ).then(
69
- fn=process,
70
- inputs=inputs,
71
- outputs=result,
72
- api_name="canny",
73
- concurrency_id="main",
74
- )
75
- return demo
76
-
77
-
78
- if __name__ == "__main__":
79
- from model import Model
80
-
81
- model = Model(task_name="Canny")
82
- demo = create_demo(model.process_canny)
83
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_sd.py DELETED
@@ -1,154 +0,0 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
-
5
- import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
8
-
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
-
153
- if __name__ == "__main__":
154
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_texnet.py DELETED
@@ -1,259 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import os
4
- import shutil
5
- import tempfile
6
- import gradio as gr
7
- from PIL import Image
8
- import numpy as np
9
-
10
- from settings import (
11
- DEFAULT_IMAGE_RESOLUTION,
12
- DEFAULT_NUM_IMAGES,
13
- MAX_IMAGE_RESOLUTION,
14
- MAX_NUM_IMAGES,
15
- MAX_SEED,
16
- )
17
- from utils import randomize_seed_fn
18
-
19
- # ---- helper to build a quick textured copy of the mesh ---------------
20
- def apply_texture(src_mesh:str, texture:str, tag:str)->str:
21
- """
22
- Writes a copy of `src_mesh` and tiny .mtl that points to `texture`.
23
- Returns the new OBJ/GLB path for viewing.
24
- """
25
- tmp_dir = tempfile.mkdtemp()
26
- mesh_copy = os.path.join(tmp_dir, f"{tag}.obj")
27
- mtl_name = f"{tag}.mtl"
28
-
29
- # copy geometry
30
- shutil.copy(src_mesh, mesh_copy)
31
-
32
- # write minimal MTL
33
- with open(os.path.join(tmp_dir, mtl_name), "w") as f:
34
- f.write(f"newmtl material_0\nmap_Kd {os.path.basename(texture)}\n")
35
-
36
- # ensure texture lives next to OBJ
37
- shutil.copy(texture, os.path.join(tmp_dir, os.path.basename(texture)))
38
-
39
- # patch OBJ to reference our new MTL
40
- with open(mesh_copy, "r+") as f:
41
- lines = f.readlines()
42
- if not lines[0].startswith("mtllib"):
43
- lines.insert(0, f"mtllib {mtl_name}\n")
44
- f.seek(0); f.writelines(lines)
45
-
46
- return mesh_copy
47
-
48
- def image_to_temp_path(img_like, tag, out_dir=None):
49
- """
50
- Convert various image-like objects (str, PIL.Image, list, tuple) to temp PNG path.
51
- Returns the path to the saved image file.
52
- """
53
- # Handle tuple or list input
54
- if isinstance(img_like, (list, tuple)):
55
- if len(img_like) == 0:
56
- raise ValueError("Empty image list/tuple.")
57
- img_like = img_like[0]
58
-
59
- # If it's already a file path
60
- if isinstance(img_like, str):
61
- return img_like
62
-
63
- # If it's a PIL Image
64
- if isinstance(img_like, Image.Image):
65
- temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
66
- os.makedirs(os.path.dirname(temp_path), exist_ok=True)
67
- img_like.save(temp_path)
68
- return temp_path
69
-
70
- # if it's numpy array
71
- if isinstance(img_like, np.ndarray):
72
- temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
73
- os.makedirs(os.path.dirname(temp_path), exist_ok=True)
74
- img_like = Image.fromarray(img_like)
75
- img_like.save(temp_path)
76
- return temp_path
77
-
78
- raise ValueError(f"Expected PIL.Image, str, list, or tuple — got {type(img_like)}")
79
-
80
- def show_mesh(which, mesh, inp, coarse, fine):
81
- """Switch the displayed texture based on dropdown change."""
82
- print()
83
- tex_map = {
84
- "Input": image_to_temp_path(inp, "input"),
85
- "Coarse": coarse[0] if isinstance(coarse, tuple) else coarse,
86
- "Fine": fine[0] if isinstance(fine, tuple) else fine,
87
- }
88
- texture_path = tex_map[which]
89
- return apply_texture(mesh, texture_path, which.lower())
90
- # ----------------------------------------------------------------------
91
-
92
-
93
- def create_demo(process):
94
- with gr.Blocks() as demo:
95
- with gr.Row():
96
- with gr.Column():
97
- gr.Markdown("## Select preset from the example list, and modify the prompt accordingly")
98
- with gr.Row():
99
- name = gr.Textbox(label="Name", interactive=False, visible=False)
100
- representative = gr.Image(label="Geometry", interactive=False)
101
- image = gr.Image(label="UV Normal", interactive=False)
102
- prompt = gr.Textbox(label="Prompt", submit_btn=True)
103
- with gr.Accordion("Advanced options", open=False):
104
- num_samples = gr.Slider(
105
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
106
- )
107
- image_resolution = gr.Slider(
108
- label="Image resolution",
109
- minimum=256,
110
- maximum=MAX_IMAGE_RESOLUTION,
111
- value=DEFAULT_IMAGE_RESOLUTION,
112
- step=256,
113
- )
114
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=10, step=1)
115
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
116
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
117
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
118
- a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
119
- n_prompt = gr.Textbox(
120
- label="Negative prompt",
121
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
122
- )
123
- with gr.Column():
124
- # 2x2 grid of images for the output textures
125
- gr.Markdown("### Output BRDF")
126
- with gr.Row():
127
- base_color = gr.Gallery(label="Base Color", show_label=True, columns=1, object_fit="scale-down")
128
- normal = gr.Gallery(label="Displacement Map", show_label=True, columns=1, object_fit="scale-down")
129
- with gr.Row():
130
- roughness = gr.Gallery(label="Roughness Map", show_label=True, columns=1, object_fit="scale-down")
131
- metallic = gr.Gallery(label="Metallic Map", show_label=True, columns=1, object_fit="scale-down")
132
-
133
- gr.Markdown("### Download Packed Blender Files for 3D Visualization")
134
- out_blender_path = gr.File(label="Generated Blender File", file_types=[".blend"])
135
-
136
- inputs = [
137
- name, # Name of the object
138
- representative, # Geometry mesh
139
- image,
140
- prompt,
141
- a_prompt,
142
- n_prompt,
143
- num_samples,
144
- image_resolution,
145
- num_steps,
146
- guidance_scale,
147
- seed,
148
- ]
149
-
150
- # first call → run diffusion / texture network
151
- prompt.submit(
152
- fn=randomize_seed_fn,
153
- inputs=[seed, randomize_seed],
154
- outputs=seed,
155
- queue=False,
156
- api_name=False,
157
- ).then(
158
- fn=process,
159
- inputs=inputs,
160
- outputs=[base_color, normal, roughness, metallic, out_blender_path],
161
- api_name="canny",
162
- concurrency_id="main",
163
- )
164
-
165
- gr.Examples(
166
- fn=process,
167
- inputs=inputs,
168
- outputs=[base_color, normal, roughness, metallic],
169
- examples=[
170
- [
171
- "bunny",
172
- "examples/bunny/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
173
- "examples/bunny/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
174
- "feather",
175
- a_prompt.value,
176
- n_prompt.value,
177
- num_samples.value,
178
- image_resolution.value,
179
- num_steps.value,
180
- guidance_scale.value,
181
- seed.value,
182
- ],
183
- [
184
- "monkey",
185
- "examples/monkey/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
186
- "examples/monkey/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
187
- "wood",
188
- a_prompt.value,
189
- n_prompt.value,
190
- num_samples.value,
191
- image_resolution.value,
192
- num_steps.value,
193
- guidance_scale.value,
194
- seed.value,
195
- ],
196
- [
197
- "tshirt",
198
- "examples/tshirt/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
199
- "examples/tshirt/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
200
- "wood",
201
- a_prompt.value,
202
- n_prompt.value,
203
- num_samples.value,
204
- image_resolution.value,
205
- num_steps.value,
206
- guidance_scale.value,
207
- seed.value,
208
- ],
209
- # [
210
- # "highheel",
211
- # "examples/highheel/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
212
- # "examples/highheel/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
213
- # "wood",
214
- # a_prompt.value,
215
- # n_prompt.value,
216
- # num_samples.value,
217
- # image_resolution.value,
218
- # num_steps.value,
219
- # guidance_scale.value,
220
- # seed.value,
221
- # ],
222
- [
223
- "tank",
224
- "examples/tank/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
225
- "examples/tank/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
226
- "wood",
227
- a_prompt.value,
228
- n_prompt.value,
229
- num_samples.value,
230
- image_resolution.value,
231
- num_steps.value,
232
- guidance_scale.value,
233
- seed.value,
234
- ],
235
- [
236
- "fighter",
237
- "examples/fighter/frame_0001.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
238
- "examples/fighter/uv_normal.png", # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
239
- "wood",
240
- a_prompt.value,
241
- n_prompt.value,
242
- num_samples.value,
243
- image_resolution.value,
244
- num_steps.value,
245
- guidance_scale.value,
246
- seed.value,
247
- ],
248
- ],
249
- )
250
-
251
- return demo
252
-
253
-
254
- if __name__ == "__main__":
255
- from model import Model
256
-
257
- model = Model(task_name="Texnet")
258
- demo = create_demo(model.process_texnet)
259
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cv_utils.py DELETED
@@ -1,17 +0,0 @@
1
- import cv2
2
- import numpy as np
3
-
4
-
5
- def resize_image(input_image, resolution, interpolation=None):
6
- H, W, C = input_image.shape
7
- H = float(H)
8
- W = float(W)
9
- k = float(resolution) / max(H, W)
10
- H *= k
11
- W *= k
12
- H = int(np.round(H / 64.0)) * 64
13
- W = int(np.round(W / 64.0)) * 64
14
- if interpolation is None:
15
- interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
16
- img = cv2.resize(input_image, (W, H), interpolation=interpolation)
17
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
depth_estimator.py DELETED
@@ -1,25 +0,0 @@
1
- import numpy as np
2
- import PIL.Image
3
- from controlnet_aux.util import HWC3
4
- from transformers import pipeline
5
-
6
- from cv_utils import resize_image
7
-
8
-
9
- class DepthEstimator:
10
- def __init__(self):
11
- self.model = pipeline("depth-estimation")
12
-
13
- def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
14
- detect_resolution = kwargs.pop("detect_resolution", 512)
15
- image_resolution = kwargs.pop("image_resolution", 512)
16
- image = np.array(image)
17
- image = HWC3(image)
18
- image = resize_image(image, resolution=detect_resolution)
19
- image = PIL.Image.fromarray(image)
20
- image = self.model(image)
21
- image = image["depth"]
22
- image = np.array(image)
23
- image = HWC3(image)
24
- image = resize_image(image, resolution=image_resolution)
25
- return PIL.Image.fromarray(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/bunny/frame_0001.png DELETED

Git LFS Details

  • SHA256: 7fedab3e148faac233d7eba7b2ab92f02998b8b3ba6a3ab1e3b823f1fdedf51b
  • Pointer size: 131 Bytes
  • Size of remote file: 468 kB
examples/bunny/mesh.obj DELETED
The diff for this file is too large to render. See raw diff
 
examples/bunny/uv_normal.png DELETED

Git LFS Details

  • SHA256: 03e7c7aa7f14454b3b179aa4c5c30863e5c74c67fd858f8ea6c28e93630ecec0
  • Pointer size: 132 Bytes
  • Size of remote file: 2.3 MB
examples/fighter/frame_0001.png DELETED

Git LFS Details

  • SHA256: 2ffaa00d5cd340167e7b13d0dd986dc6a6680c4b91595eec8d27d384f6670df7
  • Pointer size: 131 Bytes
  • Size of remote file: 423 kB
examples/fighter/mesh.obj DELETED
The diff for this file is too large to render. See raw diff
 
examples/fighter/uv_normal.png DELETED

Git LFS Details

  • SHA256: 46d4c010107c4fa030ead5ca1b8ca66ade255bcb4194dc30e9f1195bba2da672
  • Pointer size: 131 Bytes
  • Size of remote file: 753 kB
examples/highheel/frame_0001.png DELETED

Git LFS Details

  • SHA256: 9b9b91e5f99c06dd11372aa2ccb44cb996bc103ea19eacb94ee5611478c831b8
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB
examples/highheel/mesh.obj DELETED
The diff for this file is too large to render. See raw diff
 
examples/highheel/uv_normal.png DELETED

Git LFS Details

  • SHA256: 52cf0dd687067109f160ba5015078077a9a6187c09305af5c57eec3c3d05c885
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
examples/monkey/frame_0001.png DELETED

Git LFS Details

  • SHA256: da466a3467871077c00e6ab2a6d105afb837f6372cf48867a31dd86ca6fdd157
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB
examples/monkey/mesh.obj DELETED
The diff for this file is too large to render. See raw diff
 
examples/monkey/uv_normal.png DELETED

Git LFS Details

  • SHA256: f3710880d4777042bb2838fa01988e653ad5d932ac8ba5a817eb13869902ba03
  • Pointer size: 132 Bytes
  • Size of remote file: 2.03 MB
examples/tank/frame_0001.png DELETED

Git LFS Details

  • SHA256: 58a1cd1df7b94ad52568952a46f6b9cf57d62c81290cca1c967250af7a15316b
  • Pointer size: 131 Bytes
  • Size of remote file: 512 kB
examples/tank/mesh.obj DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:301633de1a7757f78a6f67abb6e61bcc8e6a01f5a54a8582d1943ad0ad943211
3
- size 6942253
 
 
 
 
examples/tank/uv_normal.png DELETED

Git LFS Details

  • SHA256: 9a7d1d168addc29d7953ea222cfb125b2d802188747e800b8e63dd686bcf9c06
  • Pointer size: 132 Bytes
  • Size of remote file: 6.13 MB
examples/tshirt/frame_0001.png DELETED

Git LFS Details

  • SHA256: 31f2e5239afc351695d176aeeefee23359f43b4c4b4fe40a1793bb9ccb80464b
  • Pointer size: 131 Bytes
  • Size of remote file: 496 kB
examples/tshirt/mesh.obj DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b7c6c9bdec8d646a1980e5b987a1182c92af84cc945ef49c1735d4337185d3e5
3
- size 39275876
 
 
 
 
examples/tshirt/uv_normal.png DELETED

Git LFS Details

  • SHA256: ee7f1df0f853fab91acdaf0240a6bf1444d8db56c310fe30efc6c98cc18c36c9
  • Pointer size: 132 Bytes
  • Size of remote file: 2.17 MB
image_segmentor.py DELETED
@@ -1,33 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import PIL.Image
4
- import torch
5
- from controlnet_aux.util import HWC3, ade_palette
6
- from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
7
-
8
- from cv_utils import resize_image
9
-
10
-
11
- class ImageSegmentor:
12
- def __init__(self):
13
- self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
14
- self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
15
-
16
- @torch.inference_mode()
17
- def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
18
- detect_resolution = kwargs.pop("detect_resolution", 512)
19
- image_resolution = kwargs.pop("image_resolution", 512)
20
- image = HWC3(image)
21
- image = resize_image(image, resolution=detect_resolution)
22
- image = PIL.Image.fromarray(image)
23
-
24
- pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
25
- outputs = self.image_segmentor(pixel_values)
26
- seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
27
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
28
- for label, color in enumerate(ade_palette()):
29
- color_seg[seg == label, :] = color
30
- color_seg = color_seg.astype(np.uint8)
31
-
32
- color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
33
- return PIL.Image.fromarray(color_seg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
install.sh DELETED
@@ -1,18 +0,0 @@
1
- #!/bin/bash
2
- eval "$(conda shell.bash hook)"
3
- # conda activate base
4
- # conda remove -n matgen-plus --all
5
-
6
- conda create -n matgen-plus python=3.11
7
- conda activate matgen-plus
8
-
9
- pip install diffusers["torch"] transformers accelerate xformers
10
- pip install gradio
11
- pip install controlnet-aux
12
-
13
- # text2tex
14
- conda install pytorch3d -c pytorch -c conda-forge
15
- conda install -c conda-forge open-clip-torch pytorch-lightning
16
- pip install trimesh xatlas scikit-learn opencv-python omegaconf
17
-
18
- python app.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py DELETED
@@ -1,959 +0,0 @@
1
- import gc
2
-
3
- # get socket and check if the name is vgldgx01
4
- import socket
5
- if socket.gethostname() != "vgldgx01":
6
- import spaces #[uncomment to use ZeroGPU]
7
-
8
- import numpy as np
9
- import PIL.Image
10
- import torch
11
- from controlnet_aux.util import HWC3
12
- from diffusers import (
13
- ControlNetModel,
14
- DiffusionPipeline,
15
- StableDiffusionControlNetPipeline,
16
- StableDiffusionImg2ImgPipeline,
17
- UniPCMultistepScheduler,
18
- DDIMScheduler, #rgb2x
19
- )
20
- import torchvision
21
- from torchvision import transforms
22
- from cv_utils import resize_image
23
- from preprocessor import Preprocessor
24
- from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
25
- from tqdm.auto import tqdm
26
- import subprocess
27
-
28
- from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
29
- from app_texnet import image_to_temp_path
30
- import os
31
- import time
32
- import tempfile
33
- from text2tex.scripts.generate_texture import text2tex_call, init_args
34
- from glob import glob
35
-
36
- CONTROLNET_MODEL_IDS = {
37
- # "Openpose": "lllyasviel/control_v11p_sd15_openpose",
38
- # "Canny": "lllyasviel/control_v11p_sd15_canny",
39
- # "MLSD": "lllyasviel/control_v11p_sd15_mlsd",
40
- # "scribble": "lllyasviel/control_v11p_sd15_scribble",
41
- # "softedge": "lllyasviel/control_v11p_sd15_softedge",
42
- # "segmentation": "lllyasviel/control_v11p_sd15_seg",
43
- # "depth": "lllyasviel/control_v11f1p_sd15_depth",
44
- # "NormalBae": "lllyasviel/control_v11p_sd15_normalbae",
45
- # "lineart": "lllyasviel/control_v11p_sd15_lineart",
46
- # "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
47
- # "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
48
- # "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
49
- # "inpaint": "lllyasviel/control_v11e_sd15_inpaint",
50
- # "texnet": "/home/jyang/projects/ObjectReal/logs/train_texnet_deploy/checkpoint-55000/controlnet" # load and call
51
- "texnet": "jingyangcarl/texnet",
52
- }
53
-
54
-
55
- def download_all_controlnet_weights() -> None:
56
- for model_id in CONTROLNET_MODEL_IDS.values():
57
- ControlNetModel.from_pretrained(model_id)
58
-
59
-
60
- class Model:
61
- def __init__(
62
- self, base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", task_name: str = "Canny"
63
- ) -> None:
64
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
- self.base_model_id = ""
66
- self.task_name = ""
67
- self.pipe = self.load_pipe(base_model_id, task_name)
68
- self.pipe_base = StableDiffusionImg2ImgPipeline.from_pretrained(
69
- 'runwayml/stable-diffusion-v1-5', safety_checker=None, torch_dtype=torch.float16
70
- ).to(self.device)
71
- self.preprocessor = Preprocessor()
72
-
73
- # set up pipe_rgb2x
74
- self.pipe_rgb2x = StableDiffusionAOVMatEstPipeline.from_pretrained(
75
- "zheng95z/rgb-to-x",
76
- torch_dtype=torch.float16,
77
- ).to(self.device)
78
- self.pipe_rgb2x.scheduler = DDIMScheduler.from_config(
79
- self.pipe_rgb2x.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
80
- )
81
- self.pipe_rgb2x.set_progress_bar_config(disable=True)
82
-
83
- # setup blender
84
- self.blender_path = '/tmp/blender-3.2.2-linux-x64/blender'
85
- if not os.path.exists(self.blender_path):
86
- print("Downloading Blender...")
87
- subprocess.run(["wget", "https://download.blender.org/release/Blender3.2/blender-3.2.2-linux-x64.tar.xz", "-O", "/tmp/blender-3.2.2-linux-x64.tar.xz"], check=True)
88
- subprocess.run(["tar", "-xf", "/tmp/blender-3.2.2-linux-x64.tar.xz", "-C", "/tmp"], check=True)
89
- print("Blender downloaded and extracted.")
90
-
91
- def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
92
- if (
93
- base_model_id == self.base_model_id
94
- and task_name == self.task_name
95
- and hasattr(self, "pipe")
96
- and self.pipe is not None
97
- ):
98
- return self.pipe
99
- model_id = CONTROLNET_MODEL_IDS[task_name]
100
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
101
- to_upload = False
102
- if to_upload:
103
- # confirm before uploading
104
- confirm = input(f"Do you want to upload {model_id} to the hub? (y/n): ")
105
- if confirm.lower() == "y":
106
- controlnet.push_to_hub("jingyangcarl/texnet")
107
- else:
108
- print("Upload cancelled.")
109
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
110
- base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
111
- )
112
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
113
- pipe.to(self.device)
114
- if self.device.type == "cuda":
115
- import os
116
- if os.environ.get("SPACES_ZERO_GPU", "0") == "1":
117
- # when running on ZeroGPU, enable CPU offload
118
- # pipe.enable_xformers_memory_efficient_attention() doens't work
119
- # pipe.enable_model_cpu_offload()
120
- pass
121
- else:
122
- pipe.enable_xformers_memory_efficient_attention()
123
- torch.cuda.empty_cache()
124
- gc.collect()
125
- self.base_model_id = base_model_id
126
- self.task_name = task_name
127
- return pipe
128
-
129
- def set_base_model(self, base_model_id: str) -> str:
130
- if not base_model_id or base_model_id == self.base_model_id:
131
- return self.base_model_id
132
- del self.pipe
133
- torch.cuda.empty_cache()
134
- gc.collect()
135
- try:
136
- self.pipe = self.load_pipe(base_model_id, self.task_name)
137
- except Exception: # noqa: BLE001
138
- self.pipe = self.load_pipe(self.base_model_id, self.task_name)
139
- return self.base_model_id
140
-
141
- def load_controlnet_weight(self, task_name: str) -> None:
142
- if task_name == self.task_name:
143
- return
144
- if self.pipe is not None and hasattr(self.pipe, "controlnet"):
145
- del self.pipe.controlnet
146
- torch.cuda.empty_cache()
147
- gc.collect()
148
- model_id = CONTROLNET_MODEL_IDS[task_name]
149
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
150
- controlnet.to(self.device)
151
- torch.cuda.empty_cache()
152
- gc.collect()
153
- self.pipe.controlnet = controlnet
154
- self.task_name = task_name
155
-
156
- def get_prompt(self, prompt: str, additional_prompt: str) -> str:
157
- return additional_prompt if not prompt else f"{prompt}, {additional_prompt}"
158
-
159
- # @spaces.GPU #[uncomment to use ZeroGPU]
160
- @torch.autocast("cuda")
161
- def run_pipe(
162
- self,
163
- prompt: str,
164
- negative_prompt: str,
165
- control_image: PIL.Image.Image,
166
- num_images: int,
167
- num_steps: int,
168
- guidance_scale: float,
169
- seed: int,
170
- ) -> list[PIL.Image.Image]:
171
- generator = torch.Generator().manual_seed(seed)
172
- # self.pipe.to(self.device)
173
- return self.pipe(
174
- prompt=prompt,
175
- negative_prompt=negative_prompt,
176
- guidance_scale=guidance_scale,
177
- num_images_per_prompt=num_images,
178
- num_inference_steps=num_steps,
179
- generator=generator,
180
- image=control_image,
181
- ).images
182
-
183
- # @spaces.GPU #[uncomment to use ZeroGPU]
184
- @torch.inference_mode()
185
- def process_texnet(
186
- self,
187
- obj_name: str,
188
- represented_image: np.ndarray | None, # not used
189
- image: np.ndarray,
190
- prompt: str,
191
- additional_prompt: str,
192
- negative_prompt: str,
193
- num_images: int,
194
- image_resolution: int,
195
- num_steps: int,
196
- guidance_scale: float,
197
- seed: int,
198
- low_threshold: int,
199
- high_threshold: int,
200
- ) -> list[PIL.Image.Image]:
201
- if image is None:
202
- raise ValueError
203
- if image_resolution > MAX_IMAGE_RESOLUTION:
204
- raise ValueError
205
- if num_images > MAX_NUM_IMAGES:
206
- raise ValueError
207
-
208
- prompt_nospace = prompt.replace(' ', '_')
209
-
210
- # self.preprocessor.load("texnet")
211
- # control_image = self.preprocessor(
212
- # image=image, low_threshold=low_threshold, high_threshold=high_threshold, image_resolution=image_resolution, output_type="pil"
213
- # )
214
-
215
- # self.load_controlnet_weight("texnet")
216
- # tex_coarse = self.run_pipe(
217
- # prompt=self.get_prompt(prompt, additional_prompt),
218
- # negative_prompt=negative_prompt,
219
- # control_image=control_image,
220
- # num_images=num_images,
221
- # num_steps=num_steps,
222
- # guidance_scale=guidance_scale,
223
- # seed=seed,
224
- # )
225
-
226
- # # use img2img pipeline
227
- # self.pipe_backup = self.pipe
228
- # self.pipe = self.pipe_base
229
-
230
- # # refine
231
- tex_fine = []
232
- mesh_fine = []
233
- # for result_coarse in tex_coarse:
234
- # # clean up GPU cache
235
- # torch.cuda.empty_cache()
236
- # gc.collect()
237
-
238
- # # masking
239
- # mask = (np.array(control_image).sum(axis=-1) == 0)[...,None]
240
- # image_masked = PIL.Image.fromarray(np.where(mask, control_image, result_coarse))
241
- # image_blurry = transforms.GaussianBlur(kernel_size=5, sigma=1)(image_masked)
242
- # result_fine = self.run_pipe(
243
- # # prompt=prompt,
244
- # prompt=self.get_prompt(prompt, additional_prompt),
245
- # negative_prompt=negative_prompt,
246
- # control_image=image_blurry,
247
- # num_images=1,
248
- # num_steps=num_steps,
249
- # guidance_scale=guidance_scale,
250
- # seed=seed,
251
- # )[0]
252
- # result_fine = PIL.Image.fromarray(np.where(mask, control_image, result_fine))
253
- # tex_fine.append(result_fine)
254
-
255
- temp_out_path = tempfile.mkdtemp()
256
- temp_out_path = 'output'
257
-
258
- # put text2tex here,
259
- args = init_args()
260
- args.input_dir = f'examples/{obj_name}/'
261
- args.output_dir = os.path.join(temp_out_path, f'{obj_name}/{prompt_nospace}')
262
- args.obj_name = obj_name
263
- args.obj_file = 'mesh.obj'
264
- args.prompt = f'{prompt} {obj_name}'
265
- args.add_view_to_prompt = True
266
- args.ddim_steps = 5
267
- # args.ddim_steps = 50
268
- args.new_strength = 1.0
269
- args.update_strength = 0.3
270
- args.view_threshold = 0.1
271
- args.blend = 0
272
- args.dist = 1
273
- args.num_viewpoints = 2
274
- # args.num_viewpoints = 36
275
- args.viewpoint_mode = 'predefined'
276
- args.use_principle = True
277
- args.update_steps = 2
278
- # args.update_steps = 20
279
- args.update_mode = 'heuristic'
280
- args.seed = 42
281
- args.post_process = True
282
- args.device = '2080'
283
- args.uv_size = 1000
284
- args.image_size = 512
285
- # args.image_size = 768
286
- args.use_objaverse = True # assume the mesh is normalized with y-axis as up
287
- output_dir = text2tex_call(args)
288
-
289
- # get the texture and mesh with underscore '_post', which is the id of the last mesh, should be good for the visual
290
- post_idx = glob(os.path.join(output_dir, 'update', 'mesh', "*_post.png"))[0].split('/')[-1].split('_')[0]
291
-
292
- tex_fine.append(PIL.Image.open(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.png")).convert("RGB"))
293
- mesh_fine.append(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.obj"))
294
- torch.cuda.empty_cache()
295
-
296
- # restore the original pipe
297
- # self.pipe = self.pipe_backup
298
-
299
- # use rgb2x for now for generating the texture
300
- def rgb2x(
301
- pipeline,
302
- photo,
303
- inference_step = 50,
304
- num_samples = 1,
305
- ):
306
- generator = torch.Generator(device="cuda").manual_seed(seed)
307
-
308
- # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
309
- old_height = photo.shape[1]
310
- old_width = photo.shape[2]
311
- new_height = old_height
312
- new_width = old_width
313
- radio = old_height / old_width
314
- max_side = 1000
315
- if old_height > old_width:
316
- new_height = max_side
317
- new_width = int(new_height / radio)
318
- else:
319
- new_width = max_side
320
- new_height = int(new_width * radio)
321
-
322
- if new_width % 8 != 0 or new_height % 8 != 0:
323
- new_width = new_width // 8 * 8
324
- new_height = new_height // 8 * 8
325
-
326
- photo = torchvision.transforms.Resize((new_height, new_width))(photo)
327
-
328
- required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
329
- prompts = {
330
- "albedo": "Albedo (diffuse basecolor)",
331
- "normal": "Camera-space Normal",
332
- "roughness": "Roughness",
333
- "metallic": "Metallicness",
334
- "irradiance": "Irradiance (diffuse lighting)",
335
- }
336
-
337
- return_list = []
338
- for i in tqdm(range(num_samples), desc="Running Pipeline", leave=False):
339
- for aov_name in required_aovs:
340
- prompt = prompts[aov_name]
341
- generated_image = pipeline(
342
- prompt=prompt,
343
- photo=photo,
344
- num_inference_steps=inference_step,
345
- height=new_height,
346
- width=new_width,
347
- generator=generator,
348
- required_aovs=[aov_name],
349
- ).images[0][0]
350
-
351
- generated_image = torchvision.transforms.Resize(
352
- (old_height, old_width)
353
- )(generated_image)
354
-
355
- # generated_image = (generated_image, f"Generated {aov_name} {i}")
356
- # generated_image = (generated_image, f"{aov_name}")
357
- return_list.append(generated_image)
358
-
359
- return photo, return_list, prompts
360
-
361
- # Load rgb2x pipeline
362
- _, preds, prompts = rgb2x(self.pipe_rgb2x, torchvision.transforms.PILToTensor()(tex_fine[0]).to(self.pipe.device), inference_step=num_steps, num_samples=num_images)
363
-
364
- intrinsic_dir = os.path.join(output_dir, 'intrinsic')
365
- use_text2tex = True
366
- if use_text2tex:
367
- base_color_path = image_to_temp_path(tex_fine[0], "base_color", out_dir=intrinsic_dir)
368
- normal_map_path = image_to_temp_path(preds[0], "normal_map", out_dir=intrinsic_dir)
369
- roughness_path = image_to_temp_path(preds[1], "roughness", out_dir=intrinsic_dir)
370
- metallic_path = image_to_temp_path(preds[2], "metallic", out_dir=intrinsic_dir)
371
- else:
372
- base_color_path = image_to_temp_path(tex_fine[0].rotate(90), "base_color", out_dir=intrinsic_dir)
373
- normal_map_path = image_to_temp_path(preds[0].rotate(90), "normal_map", out_dir=intrinsic_dir)
374
- roughness_path = image_to_temp_path(preds[1].rotate(90), "roughness", out_dir=intrinsic_dir)
375
- metallic_path = image_to_temp_path(preds[2].rotate(90), "metallic", out_dir=intrinsic_dir)
376
- current_timecode = time.strftime("%Y%m%d_%H%M%S")
377
- # output_blend_path = os.path.join(os.getcwd(), "output", f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
378
- output_blend_path = os.path.join(tempfile.mkdtemp(), f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
379
- os.makedirs(os.path.dirname(output_blend_path), exist_ok=True)
380
-
381
- def run_blend_generation(
382
- blender_path,
383
- generate_script_path,
384
- obj_path,
385
- base_color_path,
386
- normal_map_path,
387
- roughness_path,
388
- metallic_path,
389
- output_blend
390
- ):
391
- cmd = [
392
- blender_path, "--background", "--python", generate_script_path, "--",
393
- obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend
394
- ]
395
- subprocess.run(cmd, check=True)
396
-
397
- # check if the blender_path exists, if not download
398
- run_blend_generation(
399
- blender_path=self.blender_path,
400
- generate_script_path="rgb2x/generate_blend.py",
401
- # obj_path=f"examples/{obj_name}/mesh.obj", # replace with actual mesh path
402
- obj_path=mesh_fine[0], # replace with actual mesh path
403
- base_color_path=base_color_path,
404
- normal_map_path=normal_map_path,
405
- roughness_path=roughness_path,
406
- metallic_path=metallic_path,
407
- output_blend=output_blend_path # replace with desired output path
408
- )
409
-
410
- # gallary
411
- return [*tex_fine], [preds[1]], [preds[2]], [preds[3]], [output_blend_path]
412
-
413
- # @spaces.GPU #[uncomment to use ZeroGPU]
414
- @torch.inference_mode()
415
- def process_canny(
416
- self,
417
- image: np.ndarray,
418
- prompt: str,
419
- additional_prompt: str,
420
- negative_prompt: str,
421
- num_images: int,
422
- image_resolution: int,
423
- num_steps: int,
424
- guidance_scale: float,
425
- seed: int,
426
- low_threshold: int,
427
- high_threshold: int,
428
- ) -> list[PIL.Image.Image]:
429
- if image is None:
430
- raise ValueError
431
- if image_resolution > MAX_IMAGE_RESOLUTION:
432
- raise ValueError
433
- if num_images > MAX_NUM_IMAGES:
434
- raise ValueError
435
-
436
- self.preprocessor.load("Canny")
437
- control_image = self.preprocessor(
438
- image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
439
- )
440
-
441
- self.load_controlnet_weight("Canny")
442
- results = self.run_pipe(
443
- prompt=self.get_prompt(prompt, additional_prompt),
444
- negative_prompt=negative_prompt,
445
- control_image=control_image,
446
- num_images=num_images,
447
- num_steps=num_steps,
448
- guidance_scale=guidance_scale,
449
- seed=seed,
450
- )
451
- return [control_image, *results]
452
-
453
- @torch.inference_mode()
454
- def process_mlsd(
455
- self,
456
- image: np.ndarray,
457
- prompt: str,
458
- additional_prompt: str,
459
- negative_prompt: str,
460
- num_images: int,
461
- image_resolution: int,
462
- preprocess_resolution: int,
463
- num_steps: int,
464
- guidance_scale: float,
465
- seed: int,
466
- value_threshold: float,
467
- distance_threshold: float,
468
- ) -> list[PIL.Image.Image]:
469
- if image is None:
470
- raise ValueError
471
- if image_resolution > MAX_IMAGE_RESOLUTION:
472
- raise ValueError
473
- if num_images > MAX_NUM_IMAGES:
474
- raise ValueError
475
-
476
- self.preprocessor.load("MLSD")
477
- control_image = self.preprocessor(
478
- image=image,
479
- image_resolution=image_resolution,
480
- detect_resolution=preprocess_resolution,
481
- thr_v=value_threshold,
482
- thr_d=distance_threshold,
483
- )
484
- self.load_controlnet_weight("MLSD")
485
- results = self.run_pipe(
486
- prompt=self.get_prompt(prompt, additional_prompt),
487
- negative_prompt=negative_prompt,
488
- control_image=control_image,
489
- num_images=num_images,
490
- num_steps=num_steps,
491
- guidance_scale=guidance_scale,
492
- seed=seed,
493
- )
494
- return [control_image, *results]
495
-
496
- @torch.inference_mode()
497
- def process_scribble(
498
- self,
499
- image: np.ndarray,
500
- prompt: str,
501
- additional_prompt: str,
502
- negative_prompt: str,
503
- num_images: int,
504
- image_resolution: int,
505
- preprocess_resolution: int,
506
- num_steps: int,
507
- guidance_scale: float,
508
- seed: int,
509
- preprocessor_name: str,
510
- ) -> list[PIL.Image.Image]:
511
- if image is None:
512
- raise ValueError
513
- if image_resolution > MAX_IMAGE_RESOLUTION:
514
- raise ValueError
515
- if num_images > MAX_NUM_IMAGES:
516
- raise ValueError
517
-
518
- if preprocessor_name == "None":
519
- image = HWC3(image)
520
- image = resize_image(image, resolution=image_resolution)
521
- control_image = PIL.Image.fromarray(image)
522
- elif preprocessor_name == "HED":
523
- self.preprocessor.load(preprocessor_name)
524
- control_image = self.preprocessor(
525
- image=image,
526
- image_resolution=image_resolution,
527
- detect_resolution=preprocess_resolution,
528
- scribble=False,
529
- )
530
- elif preprocessor_name == "PidiNet":
531
- self.preprocessor.load(preprocessor_name)
532
- control_image = self.preprocessor(
533
- image=image,
534
- image_resolution=image_resolution,
535
- detect_resolution=preprocess_resolution,
536
- safe=False,
537
- )
538
- self.load_controlnet_weight("scribble")
539
- results = self.run_pipe(
540
- prompt=self.get_prompt(prompt, additional_prompt),
541
- negative_prompt=negative_prompt,
542
- control_image=control_image,
543
- num_images=num_images,
544
- num_steps=num_steps,
545
- guidance_scale=guidance_scale,
546
- seed=seed,
547
- )
548
- return [control_image, *results]
549
-
550
- @torch.inference_mode()
551
- def process_scribble_interactive(
552
- self,
553
- image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
554
- prompt: str,
555
- additional_prompt: str,
556
- negative_prompt: str,
557
- num_images: int,
558
- image_resolution: int,
559
- num_steps: int,
560
- guidance_scale: float,
561
- seed: int,
562
- ) -> list[PIL.Image.Image]:
563
- if image_and_mask is None:
564
- raise ValueError
565
- if image_resolution > MAX_IMAGE_RESOLUTION:
566
- raise ValueError
567
- if num_images > MAX_NUM_IMAGES:
568
- raise ValueError
569
-
570
- image = 255 - image_and_mask["composite"] # type: ignore
571
- image = HWC3(image)
572
- image = resize_image(image, resolution=image_resolution)
573
- control_image = PIL.Image.fromarray(image)
574
-
575
- self.load_controlnet_weight("scribble")
576
- results = self.run_pipe(
577
- prompt=self.get_prompt(prompt, additional_prompt),
578
- negative_prompt=negative_prompt,
579
- control_image=control_image,
580
- num_images=num_images,
581
- num_steps=num_steps,
582
- guidance_scale=guidance_scale,
583
- seed=seed,
584
- )
585
- return [control_image, *results]
586
-
587
- @torch.inference_mode()
588
- def process_softedge(
589
- self,
590
- image: np.ndarray,
591
- prompt: str,
592
- additional_prompt: str,
593
- negative_prompt: str,
594
- num_images: int,
595
- image_resolution: int,
596
- preprocess_resolution: int,
597
- num_steps: int,
598
- guidance_scale: float,
599
- seed: int,
600
- preprocessor_name: str,
601
- ) -> list[PIL.Image.Image]:
602
- if image is None:
603
- raise ValueError
604
- if image_resolution > MAX_IMAGE_RESOLUTION:
605
- raise ValueError
606
- if num_images > MAX_NUM_IMAGES:
607
- raise ValueError
608
-
609
- if preprocessor_name == "None":
610
- image = HWC3(image)
611
- image = resize_image(image, resolution=image_resolution)
612
- control_image = PIL.Image.fromarray(image)
613
- elif preprocessor_name in ["HED", "HED safe"]:
614
- safe = "safe" in preprocessor_name
615
- self.preprocessor.load("HED")
616
- control_image = self.preprocessor(
617
- image=image,
618
- image_resolution=image_resolution,
619
- detect_resolution=preprocess_resolution,
620
- scribble=safe,
621
- )
622
- elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
623
- safe = "safe" in preprocessor_name
624
- self.preprocessor.load("PidiNet")
625
- control_image = self.preprocessor(
626
- image=image,
627
- image_resolution=image_resolution,
628
- detect_resolution=preprocess_resolution,
629
- safe=safe,
630
- )
631
- else:
632
- raise ValueError
633
- self.load_controlnet_weight("softedge")
634
- results = self.run_pipe(
635
- prompt=self.get_prompt(prompt, additional_prompt),
636
- negative_prompt=negative_prompt,
637
- control_image=control_image,
638
- num_images=num_images,
639
- num_steps=num_steps,
640
- guidance_scale=guidance_scale,
641
- seed=seed,
642
- )
643
- return [control_image, *results]
644
-
645
- @torch.inference_mode()
646
- def process_openpose(
647
- self,
648
- image: np.ndarray,
649
- prompt: str,
650
- additional_prompt: str,
651
- negative_prompt: str,
652
- num_images: int,
653
- image_resolution: int,
654
- preprocess_resolution: int,
655
- num_steps: int,
656
- guidance_scale: float,
657
- seed: int,
658
- preprocessor_name: str,
659
- ) -> list[PIL.Image.Image]:
660
- if image is None:
661
- raise ValueError
662
- if image_resolution > MAX_IMAGE_RESOLUTION:
663
- raise ValueError
664
- if num_images > MAX_NUM_IMAGES:
665
- raise ValueError
666
-
667
- if preprocessor_name == "None":
668
- image = HWC3(image)
669
- image = resize_image(image, resolution=image_resolution)
670
- control_image = PIL.Image.fromarray(image)
671
- else:
672
- self.preprocessor.load("Openpose")
673
- control_image = self.preprocessor(
674
- image=image,
675
- image_resolution=image_resolution,
676
- detect_resolution=preprocess_resolution,
677
- hand_and_face=True,
678
- )
679
- self.load_controlnet_weight("Openpose")
680
- results = self.run_pipe(
681
- prompt=self.get_prompt(prompt, additional_prompt),
682
- negative_prompt=negative_prompt,
683
- control_image=control_image,
684
- num_images=num_images,
685
- num_steps=num_steps,
686
- guidance_scale=guidance_scale,
687
- seed=seed,
688
- )
689
- return [control_image, *results]
690
-
691
- @torch.inference_mode()
692
- def process_segmentation(
693
- self,
694
- image: np.ndarray,
695
- prompt: str,
696
- additional_prompt: str,
697
- negative_prompt: str,
698
- num_images: int,
699
- image_resolution: int,
700
- preprocess_resolution: int,
701
- num_steps: int,
702
- guidance_scale: float,
703
- seed: int,
704
- preprocessor_name: str,
705
- ) -> list[PIL.Image.Image]:
706
- if image is None:
707
- raise ValueError
708
- if image_resolution > MAX_IMAGE_RESOLUTION:
709
- raise ValueError
710
- if num_images > MAX_NUM_IMAGES:
711
- raise ValueError
712
-
713
- if preprocessor_name == "None":
714
- image = HWC3(image)
715
- image = resize_image(image, resolution=image_resolution)
716
- control_image = PIL.Image.fromarray(image)
717
- else:
718
- self.preprocessor.load(preprocessor_name)
719
- control_image = self.preprocessor(
720
- image=image,
721
- image_resolution=image_resolution,
722
- detect_resolution=preprocess_resolution,
723
- )
724
- self.load_controlnet_weight("segmentation")
725
- results = self.run_pipe(
726
- prompt=self.get_prompt(prompt, additional_prompt),
727
- negative_prompt=negative_prompt,
728
- control_image=control_image,
729
- num_images=num_images,
730
- num_steps=num_steps,
731
- guidance_scale=guidance_scale,
732
- seed=seed,
733
- )
734
- return [control_image, *results]
735
-
736
- @torch.inference_mode()
737
- def process_depth(
738
- self,
739
- image: np.ndarray,
740
- prompt: str,
741
- additional_prompt: str,
742
- negative_prompt: str,
743
- num_images: int,
744
- image_resolution: int,
745
- preprocess_resolution: int,
746
- num_steps: int,
747
- guidance_scale: float,
748
- seed: int,
749
- preprocessor_name: str,
750
- ) -> list[PIL.Image.Image]:
751
- if image is None:
752
- raise ValueError
753
- if image_resolution > MAX_IMAGE_RESOLUTION:
754
- raise ValueError
755
- if num_images > MAX_NUM_IMAGES:
756
- raise ValueError
757
-
758
- if preprocessor_name == "None":
759
- image = HWC3(image)
760
- image = resize_image(image, resolution=image_resolution)
761
- control_image = PIL.Image.fromarray(image)
762
- else:
763
- self.preprocessor.load(preprocessor_name)
764
- control_image = self.preprocessor(
765
- image=image,
766
- image_resolution=image_resolution,
767
- detect_resolution=preprocess_resolution,
768
- )
769
- self.load_controlnet_weight("depth")
770
- results = self.run_pipe(
771
- prompt=self.get_prompt(prompt, additional_prompt),
772
- negative_prompt=negative_prompt,
773
- control_image=control_image,
774
- num_images=num_images,
775
- num_steps=num_steps,
776
- guidance_scale=guidance_scale,
777
- seed=seed,
778
- )
779
- return [control_image, *results]
780
-
781
- @torch.inference_mode()
782
- def process_normal(
783
- self,
784
- image: np.ndarray,
785
- prompt: str,
786
- additional_prompt: str,
787
- negative_prompt: str,
788
- num_images: int,
789
- image_resolution: int,
790
- preprocess_resolution: int,
791
- num_steps: int,
792
- guidance_scale: float,
793
- seed: int,
794
- preprocessor_name: str,
795
- ) -> list[PIL.Image.Image]:
796
- if image is None:
797
- raise ValueError
798
- if image_resolution > MAX_IMAGE_RESOLUTION:
799
- raise ValueError
800
- if num_images > MAX_NUM_IMAGES:
801
- raise ValueError
802
-
803
- if preprocessor_name == "None":
804
- image = HWC3(image)
805
- image = resize_image(image, resolution=image_resolution)
806
- control_image = PIL.Image.fromarray(image)
807
- else:
808
- self.preprocessor.load("NormalBae")
809
- control_image = self.preprocessor(
810
- image=image,
811
- image_resolution=image_resolution,
812
- detect_resolution=preprocess_resolution,
813
- )
814
- self.load_controlnet_weight("NormalBae")
815
- results = self.run_pipe(
816
- prompt=self.get_prompt(prompt, additional_prompt),
817
- negative_prompt=negative_prompt,
818
- control_image=control_image,
819
- num_images=num_images,
820
- num_steps=num_steps,
821
- guidance_scale=guidance_scale,
822
- seed=seed,
823
- )
824
- return [control_image, *results]
825
-
826
- @torch.inference_mode()
827
- def process_lineart(
828
- self,
829
- image: np.ndarray,
830
- prompt: str,
831
- additional_prompt: str,
832
- negative_prompt: str,
833
- num_images: int,
834
- image_resolution: int,
835
- preprocess_resolution: int,
836
- num_steps: int,
837
- guidance_scale: float,
838
- seed: int,
839
- preprocessor_name: str,
840
- ) -> list[PIL.Image.Image]:
841
- if image is None:
842
- raise ValueError
843
- if image_resolution > MAX_IMAGE_RESOLUTION:
844
- raise ValueError
845
- if num_images > MAX_NUM_IMAGES:
846
- raise ValueError
847
-
848
- if preprocessor_name in ["None", "None (anime)"]:
849
- image = HWC3(image)
850
- image = resize_image(image, resolution=image_resolution)
851
- control_image = PIL.Image.fromarray(image)
852
- elif preprocessor_name in ["Lineart", "Lineart coarse"]:
853
- coarse = "coarse" in preprocessor_name
854
- self.preprocessor.load("Lineart")
855
- control_image = self.preprocessor(
856
- image=image,
857
- image_resolution=image_resolution,
858
- detect_resolution=preprocess_resolution,
859
- coarse=coarse,
860
- )
861
- elif preprocessor_name == "Lineart (anime)":
862
- self.preprocessor.load("LineartAnime")
863
- control_image = self.preprocessor(
864
- image=image,
865
- image_resolution=image_resolution,
866
- detect_resolution=preprocess_resolution,
867
- )
868
- if "anime" in preprocessor_name:
869
- self.load_controlnet_weight("lineart_anime")
870
- else:
871
- self.load_controlnet_weight("lineart")
872
- results = self.run_pipe(
873
- prompt=self.get_prompt(prompt, additional_prompt),
874
- negative_prompt=negative_prompt,
875
- control_image=control_image,
876
- num_images=num_images,
877
- num_steps=num_steps,
878
- guidance_scale=guidance_scale,
879
- seed=seed,
880
- )
881
- return [control_image, *results]
882
-
883
- @torch.inference_mode()
884
- def process_shuffle(
885
- self,
886
- image: np.ndarray,
887
- prompt: str,
888
- additional_prompt: str,
889
- negative_prompt: str,
890
- num_images: int,
891
- image_resolution: int,
892
- num_steps: int,
893
- guidance_scale: float,
894
- seed: int,
895
- preprocessor_name: str,
896
- ) -> list[PIL.Image.Image]:
897
- if image is None:
898
- raise ValueError
899
- if image_resolution > MAX_IMAGE_RESOLUTION:
900
- raise ValueError
901
- if num_images > MAX_NUM_IMAGES:
902
- raise ValueError
903
-
904
- if preprocessor_name == "None":
905
- image = HWC3(image)
906
- image = resize_image(image, resolution=image_resolution)
907
- control_image = PIL.Image.fromarray(image)
908
- else:
909
- self.preprocessor.load(preprocessor_name)
910
- control_image = self.preprocessor(
911
- image=image,
912
- image_resolution=image_resolution,
913
- )
914
- self.load_controlnet_weight("shuffle")
915
- results = self.run_pipe(
916
- prompt=self.get_prompt(prompt, additional_prompt),
917
- negative_prompt=negative_prompt,
918
- control_image=control_image,
919
- num_images=num_images,
920
- num_steps=num_steps,
921
- guidance_scale=guidance_scale,
922
- seed=seed,
923
- )
924
- return [control_image, *results]
925
-
926
- @torch.inference_mode()
927
- def process_ip2p(
928
- self,
929
- image: np.ndarray,
930
- prompt: str,
931
- additional_prompt: str,
932
- negative_prompt: str,
933
- num_images: int,
934
- image_resolution: int,
935
- num_steps: int,
936
- guidance_scale: float,
937
- seed: int,
938
- ) -> list[PIL.Image.Image]:
939
- if image is None:
940
- raise ValueError
941
- if image_resolution > MAX_IMAGE_RESOLUTION:
942
- raise ValueError
943
- if num_images > MAX_NUM_IMAGES:
944
- raise ValueError
945
-
946
- image = HWC3(image)
947
- image = resize_image(image, resolution=image_resolution)
948
- control_image = PIL.Image.fromarray(image)
949
- self.load_controlnet_weight("ip2p")
950
- results = self.run_pipe(
951
- prompt=self.get_prompt(prompt, additional_prompt),
952
- negative_prompt=negative_prompt,
953
- control_image=control_image,
954
- num_images=num_images,
955
- num_steps=num_steps,
956
- guidance_scale=guidance_scale,
957
- seed=seed,
958
- )
959
- return [control_image, *results]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pre-requirements.txt DELETED
@@ -1,9 +0,0 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- torchvision
6
- transformers
7
- xformers
8
- controlnet-aux # for controlnet
9
- spaces # no need to specify here
 
 
 
 
 
 
 
 
 
 
preprocessor.py DELETED
@@ -1,120 +0,0 @@
1
- import gc
2
- from typing import TYPE_CHECKING
3
-
4
- if TYPE_CHECKING:
5
- from collections.abc import Callable
6
-
7
- import numpy as np
8
- import PIL.Image
9
- import torch
10
- from controlnet_aux import (
11
- CannyDetector,
12
- ContentShuffleDetector,
13
- HEDdetector,
14
- LineartAnimeDetector,
15
- LineartDetector,
16
- MidasDetector,
17
- MLSDdetector,
18
- NormalBaeDetector,
19
- OpenposeDetector,
20
- PidiNetDetector,
21
- )
22
- from controlnet_aux.util import HWC3
23
-
24
- from cv_utils import resize_image
25
- from depth_estimator import DepthEstimator
26
- from image_segmentor import ImageSegmentor
27
-
28
-
29
- class Preprocessor:
30
- MODEL_ID = "lllyasviel/Annotators"
31
-
32
- def __init__(self) -> None:
33
- self.model: Callable = None # type: ignore
34
- self.name = ""
35
-
36
- def load(self, name: str) -> None: # noqa: C901, PLR0912
37
- if name == self.name:
38
- return
39
- if name == "HED":
40
- self.model = HEDdetector.from_pretrained(self.MODEL_ID)
41
- elif name == "Midas":
42
- self.model = MidasDetector.from_pretrained(self.MODEL_ID)
43
- elif name == "MLSD":
44
- self.model = MLSDdetector.from_pretrained(self.MODEL_ID)
45
- elif name == "Openpose":
46
- self.model = OpenposeDetector.from_pretrained(self.MODEL_ID)
47
- elif name == "PidiNet":
48
- self.model = PidiNetDetector.from_pretrained(self.MODEL_ID)
49
- elif name == "NormalBae":
50
- self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID)
51
- elif name == "Lineart":
52
- self.model = LineartDetector.from_pretrained(self.MODEL_ID)
53
- elif name == "LineartAnime":
54
- self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
55
- elif name == "Canny":
56
- self.model = CannyDetector()
57
- elif name == "ContentShuffle":
58
- self.model = ContentShuffleDetector()
59
- elif name == "DPT":
60
- self.model = DepthEstimator()
61
- elif name == "UPerNet":
62
- self.model = ImageSegmentor()
63
- elif name == 'texnet':
64
- self.model = TexnetPreprocessor()
65
- else:
66
- raise ValueError
67
- torch.cuda.empty_cache()
68
- gc.collect()
69
- self.name = name
70
-
71
- def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image: # noqa: ANN003
72
- if self.name == "Canny":
73
- if "detect_resolution" in kwargs:
74
- detect_resolution = kwargs.pop("detect_resolution")
75
- image = np.array(image)
76
- image = HWC3(image)
77
- image = resize_image(image, resolution=detect_resolution)
78
- image = self.model(image, **kwargs)
79
- return PIL.Image.fromarray(image)
80
- if self.name == "Midas":
81
- detect_resolution = kwargs.pop("detect_resolution", 512)
82
- image_resolution = kwargs.pop("image_resolution", 512)
83
- image = np.array(image)
84
- image = HWC3(image)
85
- image = resize_image(image, resolution=detect_resolution)
86
- image = self.model(image, **kwargs)
87
- image = HWC3(image)
88
- image = resize_image(image, resolution=image_resolution)
89
- return PIL.Image.fromarray(image)
90
- return self.model(image, **kwargs)
91
-
92
-
93
- # https://github.com/huggingface/controlnet_aux/blob/master/src/controlnet_aux/canny/__init__.py
94
- class TexnetPreprocessor:
95
- def __call__(self, input_image=None, low_threshold=100, high_threshold=200, image_resolution=512, output_type=None, **kwargs):
96
- if "img" in kwargs:
97
- warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
98
- input_image = kwargs.pop("img")
99
-
100
- if input_image is None:
101
- raise ValueError("input_image must be defined.")
102
-
103
- if not isinstance(input_image, np.ndarray):
104
- input_image = np.array(input_image, dtype=np.uint8)
105
- output_type = output_type or "pil"
106
- else:
107
- output_type = output_type or "np"
108
-
109
- input_image = HWC3(input_image)
110
- input_image = resize_image(input_image, image_resolution)
111
- H, W, C = input_image.shape
112
-
113
- # detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
114
- output_image = input_image.copy()
115
-
116
- if output_type == "pil":
117
- # detected_map = Image.fromarray(detected_map)
118
- output_image = PIL.Image.fromarray(output_image)
119
-
120
- return output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
push_dataset.py DELETED
@@ -1,9 +0,0 @@
1
- from huggingface_hub import HfApi
2
- api = HfApi()
3
-
4
- api.upload_folder(
5
- folder_path="./examples",
6
- repo_id="jingyangcarl/matgen",
7
- repo_type="space",
8
- path_in_repo="examples", # Upload to a specific folder
9
- )
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1,9 +0,0 @@
1
- torch
2
- torchvision
3
- pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@stable
4
- trimesh
5
- xatlas
6
- scikit-learn
7
- opencv-python
8
- matplotlib
9
- omegaconf
 
 
 
 
 
 
 
 
 
 
rgb2x/generate_blend.py DELETED
@@ -1,142 +0,0 @@
1
- import bpy
2
- import sys
3
- import os
4
-
5
- def create_tex_node(nodes, img_path, label, color_space, location):
6
- img = bpy.data.images.load(img_path)
7
- tex = nodes.new(type='ShaderNodeTexImage')
8
- tex.image = img
9
- tex.label = label
10
- tex.location = location
11
- tex.image.colorspace_settings.name = color_space
12
- return tex
13
-
14
- def setup_environment_lighting(hdri_path):
15
- if not bpy.data.worlds:
16
- bpy.data.worlds.new(name="World")
17
- if bpy.context.scene.world is None:
18
- bpy.context.scene.world = bpy.data.worlds[0]
19
- world = bpy.context.scene.world
20
-
21
- world.use_nodes = True
22
- nodes = world.node_tree.nodes
23
- links = world.node_tree.links
24
- nodes.clear()
25
-
26
- env_tex = nodes.new(type="ShaderNodeTexEnvironment")
27
- env_tex.image = bpy.data.images.load(hdri_path)
28
- env_tex.location = (-300, 0)
29
-
30
- bg = nodes.new(type="ShaderNodeBackground")
31
- bg.location = (0, 0)
32
-
33
- output = nodes.new(type="ShaderNodeOutputWorld")
34
- output.location = (300, 0)
35
-
36
- links.new(env_tex.outputs["Color"], bg.inputs["Color"])
37
- links.new(bg.outputs["Background"], output.inputs["Surface"])
38
-
39
- def setup_gpu_rendering():
40
- bpy.context.scene.render.engine = 'CYCLES'
41
- prefs = bpy.context.preferences
42
- cprefs = prefs.addons['cycles'].preferences
43
-
44
- # Choose backend depending on GPU type: 'CUDA', 'OPTIX', 'HIP', 'METAL'
45
- cprefs.compute_device_type = 'CUDA'
46
- bpy.context.scene.cycles.device = 'GPU'
47
-
48
- def generate_blend(obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend):
49
- # Reset scene
50
- bpy.ops.wm.read_factory_settings(use_empty=True)
51
-
52
- # Import OBJ
53
- bpy.ops.import_scene.obj(filepath=obj_path)
54
- obj = bpy.context.selected_objects[0]
55
-
56
- # Create material
57
- mat = bpy.data.materials.new(name="BRDF_Material")
58
- mat.use_nodes = True
59
- nodes = mat.node_tree.nodes
60
- links = mat.node_tree.links
61
- nodes.clear()
62
-
63
- output = nodes.new(type='ShaderNodeOutputMaterial')
64
- output.location = (400, 0)
65
-
66
- principled = nodes.new(type='ShaderNodeBsdfPrincipled')
67
- principled.location = (100, 0)
68
- links.new(principled.outputs['BSDF'], output.inputs['Surface'])
69
-
70
- # Base Color
71
- base_color = create_tex_node(nodes, base_color_path, "Base Color", 'sRGB', (-600, 200))
72
- links.new(base_color.outputs['Color'], principled.inputs['Base Color'])
73
-
74
- # Roughness
75
- rough = create_tex_node(nodes, roughness_path, "Roughness", 'Non-Color', (-600, 0))
76
- links.new(rough.outputs['Color'], principled.inputs['Roughness'])
77
-
78
- # Metallic
79
- metal = create_tex_node(nodes, metallic_path, "Metallic", 'Non-Color', (-600, -200))
80
- links.new(metal.outputs['Color'], principled.inputs['Metallic'])
81
-
82
- # Normal Map
83
- normal_tex = create_tex_node(nodes, normal_map_path, "Normal Map", 'Non-Color', (-800, -400))
84
- normal_map = nodes.new(type='ShaderNodeNormalMap')
85
- normal_map.location = (-400, -400)
86
- links.new(normal_tex.outputs['Color'], normal_map.inputs['Color'])
87
- links.new(normal_map.outputs['Normal'], principled.inputs['Normal'])
88
-
89
- # Assign material
90
- if obj.data.materials:
91
- obj.data.materials[0] = mat
92
- else:
93
- obj.data.materials.append(mat)
94
-
95
- # Global Illumination using Blender's default forest HDRI
96
- blender_data_path = bpy.utils.resource_path('LOCAL')
97
- forest_hdri_path = os.path.join(blender_data_path, "datafiles", "studiolights", "world", "forest.exr")
98
- print(f"Using HDRI: {forest_hdri_path}")
99
- setup_environment_lighting(forest_hdri_path)
100
-
101
- # GPU rendering setup
102
- setup_gpu_rendering()
103
-
104
- # Pack textures into .blend
105
- bpy.ops.file.pack_all()
106
-
107
- # Set the 3D View to Rendered mode and focus on object
108
- for area in bpy.context.screen.areas:
109
- if area.type == 'VIEW_3D':
110
- for space in area.spaces:
111
- if space.type == 'VIEW_3D':
112
- space.shading.type = 'RENDERED' # Set viewport shading to Rendered
113
- for region in area.regions:
114
- if region.type == 'WINDOW':
115
- override = {'area': area, 'region': region, 'scene': bpy.context.scene}
116
- bpy.ops.view3d.view_all(override, center=True)
117
-
118
- elif area.type == 'NODE_EDITOR':
119
- for space in area.spaces:
120
- if space.type == 'NODE_EDITOR':
121
- space.tree_type = 'ShaderNodeTree' # Switch to Shader Editor
122
- space.shader_type = 'OBJECT'
123
-
124
- # Optional: Switch active workspace to Shading (if it exists)
125
- for workspace in bpy.data.workspaces:
126
- if workspace.name == 'Shading':
127
- bpy.context.window.workspace = workspace
128
- break
129
-
130
- # Save the .blend file
131
- bpy.ops.wm.save_as_mainfile(filepath=output_blend)
132
- print(f"✅ Saved .blend file with BRDF, HDRI, GPU: {output_blend}")
133
-
134
- if __name__ == "__main__":
135
- argv = sys.argv
136
- argv = argv[argv.index("--") + 1:] # Only use args after "--"
137
-
138
- if len(argv) != 6:
139
- print("Usage:\n blender --background --python generate_blend.py -- obj base_color normal roughness metallic output.blend")
140
- sys.exit(1)
141
-
142
- generate_blend(*argv)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rgb2x/gradio_demo_rgb2x.py DELETED
@@ -1,157 +0,0 @@
1
- import os
2
-
3
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
4
-
5
- import gradio as gr
6
- import torch
7
- import torchvision
8
- from diffusers import DDIMScheduler
9
- from load_image import load_exr_image, load_ldr_image
10
- from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
11
-
12
- current_directory = os.path.dirname(os.path.abspath(__file__))
13
-
14
-
15
- def get_rgb2x_demo():
16
- # Load pipeline
17
- pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
18
- "zheng95z/rgb-to-x",
19
- torch_dtype=torch.float16,
20
- cache_dir=os.path.join(current_directory, "model_cache"),
21
- ).to("cuda")
22
- pipe.scheduler = DDIMScheduler.from_config(
23
- pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
24
- )
25
- pipe.set_progress_bar_config(disable=True)
26
- pipe.to("cuda")
27
-
28
- # Augmentation
29
- def callback(
30
- photo,
31
- seed,
32
- inference_step,
33
- num_samples,
34
- ):
35
- generator = torch.Generator(device="cuda").manual_seed(seed)
36
-
37
- if photo.name.endswith(".exr"):
38
- photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
39
- elif (
40
- photo.name.endswith(".png")
41
- or photo.name.endswith(".jpg")
42
- or photo.name.endswith(".jpeg")
43
- ):
44
- photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
45
-
46
- # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
47
- old_height = photo.shape[1]
48
- old_width = photo.shape[2]
49
- new_height = old_height
50
- new_width = old_width
51
- radio = old_height / old_width
52
- max_side = 1000
53
- if old_height > old_width:
54
- new_height = max_side
55
- new_width = int(new_height / radio)
56
- else:
57
- new_width = max_side
58
- new_height = int(new_width * radio)
59
-
60
- if new_width % 8 != 0 or new_height % 8 != 0:
61
- new_width = new_width // 8 * 8
62
- new_height = new_height // 8 * 8
63
-
64
- photo = torchvision.transforms.Resize((new_height, new_width))(photo)
65
-
66
- required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
67
- prompts = {
68
- "albedo": "Albedo (diffuse basecolor)",
69
- "normal": "Camera-space Normal",
70
- "roughness": "Roughness",
71
- "metallic": "Metallicness",
72
- "irradiance": "Irradiance (diffuse lighting)",
73
- }
74
-
75
- return_list = []
76
- for i in range(num_samples):
77
- for aov_name in required_aovs:
78
- prompt = prompts[aov_name]
79
- generated_image = pipe(
80
- prompt=prompt,
81
- photo=photo,
82
- num_inference_steps=inference_step,
83
- height=new_height,
84
- width=new_width,
85
- generator=generator,
86
- required_aovs=[aov_name],
87
- ).images[0][0]
88
-
89
- generated_image = torchvision.transforms.Resize(
90
- (old_height, old_width)
91
- )(generated_image)
92
-
93
- generated_image = (generated_image, f"Generated {aov_name} {i}")
94
- return_list.append(generated_image)
95
-
96
- return return_list
97
-
98
- block = gr.Blocks()
99
- with block:
100
- with gr.Row():
101
- gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
102
- with gr.Row():
103
- # Input side
104
- with gr.Column():
105
- gr.Markdown("### Given Image")
106
- photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
107
-
108
- gr.Markdown("### Parameters")
109
- run_button = gr.Button(value="Run")
110
- with gr.Accordion("Advanced options", open=False):
111
- seed = gr.Slider(
112
- label="Seed",
113
- minimum=-1,
114
- maximum=2147483647,
115
- step=1,
116
- randomize=True,
117
- )
118
- inference_step = gr.Slider(
119
- label="Inference Step",
120
- minimum=1,
121
- maximum=100,
122
- step=1,
123
- value=50,
124
- )
125
- num_samples = gr.Slider(
126
- label="Samples",
127
- minimum=1,
128
- maximum=100,
129
- step=1,
130
- value=1,
131
- )
132
-
133
- # Output side
134
- with gr.Column():
135
- gr.Markdown("### Output Gallery")
136
- result_gallery = gr.Gallery(
137
- label="Output",
138
- show_label=False,
139
- elem_id="gallery",
140
- columns=2,
141
- )
142
-
143
- inputs = [
144
- photo,
145
- seed,
146
- inference_step,
147
- num_samples,
148
- ]
149
- run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)
150
-
151
- return block
152
-
153
-
154
- if __name__ == "__main__":
155
- demo = get_rgb2x_demo()
156
- demo.queue(max_size=1)
157
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rgb2x/load_image.py DELETED
@@ -1,119 +0,0 @@
1
- import os
2
-
3
- import cv2
4
- import torch
5
-
6
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
7
- import numpy as np
8
-
9
-
10
- def convert_rgb_2_XYZ(rgb):
11
- # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
12
- # rgb: (h, w, 3)
13
- # XYZ: (h, w, 3)
14
- XYZ = torch.ones_like(rgb)
15
- XYZ[:, :, 0] = (
16
- 0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
17
- )
18
- XYZ[:, :, 1] = (
19
- 0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
20
- )
21
- XYZ[:, :, 2] = (
22
- 0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
23
- )
24
- return XYZ
25
-
26
-
27
- def convert_XYZ_2_Yxy(XYZ):
28
- # XYZ: (h, w, 3)
29
- # Yxy: (h, w, 3)
30
- Yxy = torch.ones_like(XYZ)
31
- Yxy[:, :, 0] = XYZ[:, :, 1]
32
- sum = torch.sum(XYZ, dim=2)
33
- inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
34
- Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
35
- Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
36
- return Yxy
37
-
38
-
39
- def convert_rgb_2_Yxy(rgb):
40
- # rgb: (h, w, 3)
41
- # Yxy: (h, w, 3)
42
- return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))
43
-
44
-
45
- def convert_XYZ_2_rgb(XYZ):
46
- # XYZ: (h, w, 3)
47
- # rgb: (h, w, 3)
48
- rgb = torch.ones_like(XYZ)
49
- rgb[:, :, 0] = (
50
- 3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
51
- )
52
- rgb[:, :, 1] = (
53
- -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
54
- )
55
- rgb[:, :, 2] = (
56
- 0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
57
- )
58
- return rgb
59
-
60
-
61
- def convert_Yxy_2_XYZ(Yxy):
62
- # Yxy: (h, w, 3)
63
- # XYZ: (h, w, 3)
64
- XYZ = torch.ones_like(Yxy)
65
- XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
66
- XYZ[:, :, 1] = Yxy[:, :, 0]
67
- XYZ[:, :, 2] = (
68
- (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
69
- / torch.clamp(Yxy[:, :, 2], min=1e-4)
70
- * Yxy[:, :, 0]
71
- )
72
- return XYZ
73
-
74
-
75
- def convert_Yxy_2_rgb(Yxy):
76
- # Yxy: (h, w, 3)
77
- # rgb: (h, w, 3)
78
- return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))
79
-
80
-
81
- def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
82
- # Load png or jpg image
83
- image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
84
- image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c)
85
- image[~torch.isfinite(image)] = 0
86
- if from_srgb:
87
- # Convert from sRGB to linear RGB
88
- image = image**2.2
89
- if clamp:
90
- image = torch.clamp(image, min=0.0, max=1.0)
91
- if normalize:
92
- # Normalize to [-1, 1]
93
- image = image * 2.0 - 1.0
94
- image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
95
- return image.permute(2, 0, 1) # returns (c, h, w)
96
-
97
-
98
- def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
99
- image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
100
- image = torch.from_numpy(image.astype("float32")) # (h, w, c)
101
- image[~torch.isfinite(image)] = 0
102
- if tonemaping:
103
- # Exposure adjuestment
104
- image_Yxy = convert_rgb_2_Yxy(image)
105
- lum = (
106
- image[:, :, 0:1] * 0.2125
107
- + image[:, :, 1:2] * 0.7154
108
- + image[:, :, 2:3] * 0.0721
109
- )
110
- lum = torch.log(torch.clamp(lum, min=1e-6))
111
- lum_mean = torch.exp(torch.mean(lum))
112
- lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
113
- image_Yxy[:, :, 0:1] = lp
114
- image = convert_Yxy_2_rgb(image_Yxy)
115
- if clamp:
116
- image = torch.clamp(image, min=0.0, max=1.0)
117
- if normalize:
118
- image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
119
- return image.permute(2, 0, 1) # returns (c, h, w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rgb2x/pipeline_rgb2x.py DELETED
@@ -1,821 +0,0 @@
1
- import inspect
2
- from dataclasses import dataclass
3
- from typing import Callable, List, Optional, Union
4
-
5
- import numpy as np
6
- import PIL
7
- import torch
8
- from diffusers.configuration_utils import register_to_config
9
- from diffusers.image_processor import VaeImageProcessor
10
- from diffusers.loaders import (
11
- LoraLoaderMixin,
12
- TextualInversionLoaderMixin,
13
- )
14
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
15
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
17
- rescale_noise_cfg,
18
- )
19
- from diffusers.schedulers import KarrasDiffusionSchedulers
20
- from diffusers.utils import (
21
- CONFIG_NAME,
22
- BaseOutput,
23
- deprecate,
24
- logging,
25
- )
26
- from diffusers.utils.torch_utils import randn_tensor
27
- from transformers import CLIPTextModel, CLIPTokenizer
28
-
29
- logger = logging.get_logger(__name__)
30
-
31
-
32
- class VaeImageProcrssorAOV(VaeImageProcessor):
33
- """
34
- Image processor for VAE AOV.
35
-
36
- Args:
37
- do_resize (`bool`, *optional*, defaults to `True`):
38
- Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
39
- vae_scale_factor (`int`, *optional*, defaults to `8`):
40
- VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
41
- resample (`str`, *optional*, defaults to `lanczos`):
42
- Resampling filter to use when resizing the image.
43
- do_normalize (`bool`, *optional*, defaults to `True`):
44
- Whether to normalize the image to [-1,1].
45
- """
46
-
47
- config_name = CONFIG_NAME
48
-
49
- @register_to_config
50
- def __init__(
51
- self,
52
- do_resize: bool = True,
53
- vae_scale_factor: int = 8,
54
- resample: str = "lanczos",
55
- do_normalize: bool = True,
56
- ):
57
- super().__init__()
58
-
59
- def postprocess(
60
- self,
61
- image: torch.FloatTensor,
62
- output_type: str = "pil",
63
- do_denormalize: Optional[List[bool]] = None,
64
- do_gamma_correction: bool = True,
65
- ):
66
- if not isinstance(image, torch.Tensor):
67
- raise ValueError(
68
- f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
69
- )
70
- if output_type not in ["latent", "pt", "np", "pil"]:
71
- deprecation_message = (
72
- f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
73
- "`pil`, `np`, `pt`, `latent`"
74
- )
75
- deprecate(
76
- "Unsupported output_type",
77
- "1.0.0",
78
- deprecation_message,
79
- standard_warn=False,
80
- )
81
- output_type = "np"
82
-
83
- if output_type == "latent":
84
- return image
85
-
86
- if do_denormalize is None:
87
- do_denormalize = [self.config.do_normalize] * image.shape[0]
88
-
89
- image = torch.stack(
90
- [
91
- self.denormalize(image[i]) if do_denormalize[i] else image[i]
92
- for i in range(image.shape[0])
93
- ]
94
- )
95
-
96
- # Gamma correction
97
- if do_gamma_correction:
98
- image = torch.pow(image, 1.0 / 2.2)
99
-
100
- if output_type == "pt":
101
- return image
102
-
103
- image = self.pt_to_numpy(image)
104
-
105
- if output_type == "np":
106
- return image
107
-
108
- if output_type == "pil":
109
- return self.numpy_to_pil(image)
110
-
111
- def preprocess_normal(
112
- self,
113
- image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
114
- height: Optional[int] = None,
115
- width: Optional[int] = None,
116
- ) -> torch.Tensor:
117
- image = torch.stack([image], axis=0)
118
- return image
119
-
120
-
121
- @dataclass
122
- class StableDiffusionAOVPipelineOutput(BaseOutput):
123
- """
124
- Output class for Stable Diffusion AOV pipelines.
125
-
126
- Args:
127
- images (`List[PIL.Image.Image]` or `np.ndarray`)
128
- List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
129
- num_channels)`.
130
- nsfw_content_detected (`List[bool]`)
131
- List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
132
- `None` if safety checking could not be performed.
133
- """
134
-
135
- images: Union[List[PIL.Image.Image], np.ndarray]
136
-
137
-
138
- class StableDiffusionAOVMatEstPipeline(
139
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
140
- ):
141
- r"""
142
- Pipeline for AOVs.
143
-
144
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
145
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
146
-
147
- The pipeline also inherits the following loading methods:
148
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
149
- - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
150
- - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
151
-
152
- Args:
153
- vae ([`AutoencoderKL`]):
154
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
155
- text_encoder ([`~transformers.CLIPTextModel`]):
156
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
157
- tokenizer ([`~transformers.CLIPTokenizer`]):
158
- A `CLIPTokenizer` to tokenize text.
159
- unet ([`UNet2DConditionModel`]):
160
- A `UNet2DConditionModel` to denoise the encoded image latents.
161
- scheduler ([`SchedulerMixin`]):
162
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
163
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
164
- """
165
-
166
- def __init__(
167
- self,
168
- vae: AutoencoderKL,
169
- text_encoder: CLIPTextModel,
170
- tokenizer: CLIPTokenizer,
171
- unet: UNet2DConditionModel,
172
- scheduler: KarrasDiffusionSchedulers,
173
- ):
174
- super().__init__()
175
-
176
- self.register_modules(
177
- vae=vae,
178
- text_encoder=text_encoder,
179
- tokenizer=tokenizer,
180
- unet=unet,
181
- scheduler=scheduler,
182
- )
183
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
184
- self.image_processor = VaeImageProcrssorAOV(
185
- vae_scale_factor=self.vae_scale_factor
186
- )
187
- self.register_to_config()
188
-
189
- def _encode_prompt(
190
- self,
191
- prompt,
192
- device,
193
- num_images_per_prompt,
194
- do_classifier_free_guidance,
195
- negative_prompt=None,
196
- prompt_embeds: Optional[torch.FloatTensor] = None,
197
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
198
- ):
199
- r"""
200
- Encodes the prompt into text encoder hidden states.
201
-
202
- Args:
203
- prompt (`str` or `List[str]`, *optional*):
204
- prompt to be encoded
205
- device: (`torch.device`):
206
- torch device
207
- num_images_per_prompt (`int`):
208
- number of images that should be generated per prompt
209
- do_classifier_free_guidance (`bool`):
210
- whether to use classifier free guidance or not
211
- negative_ prompt (`str` or `List[str]`, *optional*):
212
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
213
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
214
- less than `1`).
215
- prompt_embeds (`torch.FloatTensor`, *optional*):
216
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
217
- provided, text embeddings will be generated from `prompt` input argument.
218
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
219
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
220
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
221
- argument.
222
- """
223
- if prompt is not None and isinstance(prompt, str):
224
- batch_size = 1
225
- elif prompt is not None and isinstance(prompt, list):
226
- batch_size = len(prompt)
227
- else:
228
- batch_size = prompt_embeds.shape[0]
229
-
230
- if prompt_embeds is None:
231
- # textual inversion: procecss multi-vector tokens if necessary
232
- if isinstance(self, TextualInversionLoaderMixin):
233
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
234
-
235
- text_inputs = self.tokenizer(
236
- prompt,
237
- padding="max_length",
238
- max_length=self.tokenizer.model_max_length,
239
- truncation=True,
240
- return_tensors="pt",
241
- )
242
- text_input_ids = text_inputs.input_ids
243
- untruncated_ids = self.tokenizer(
244
- prompt, padding="longest", return_tensors="pt"
245
- ).input_ids
246
-
247
- if untruncated_ids.shape[-1] >= text_input_ids.shape[
248
- -1
249
- ] and not torch.equal(text_input_ids, untruncated_ids):
250
- removed_text = self.tokenizer.batch_decode(
251
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
252
- )
253
- logger.warning(
254
- "The following part of your input was truncated because CLIP can only handle sequences up to"
255
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
256
- )
257
-
258
- if (
259
- hasattr(self.text_encoder.config, "use_attention_mask")
260
- and self.text_encoder.config.use_attention_mask
261
- ):
262
- attention_mask = text_inputs.attention_mask.to(device)
263
- else:
264
- attention_mask = None
265
-
266
- prompt_embeds = self.text_encoder(
267
- text_input_ids.to(device),
268
- attention_mask=attention_mask,
269
- )
270
- prompt_embeds = prompt_embeds[0]
271
-
272
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
273
-
274
- bs_embed, seq_len, _ = prompt_embeds.shape
275
- # duplicate text embeddings for each generation per prompt, using mps friendly method
276
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
277
- prompt_embeds = prompt_embeds.view(
278
- bs_embed * num_images_per_prompt, seq_len, -1
279
- )
280
-
281
- # get unconditional embeddings for classifier free guidance
282
- if do_classifier_free_guidance and negative_prompt_embeds is None:
283
- uncond_tokens: List[str]
284
- if negative_prompt is None:
285
- uncond_tokens = [""] * batch_size
286
- elif type(prompt) is not type(negative_prompt):
287
- raise TypeError(
288
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
289
- f" {type(prompt)}."
290
- )
291
- elif isinstance(negative_prompt, str):
292
- uncond_tokens = [negative_prompt]
293
- elif batch_size != len(negative_prompt):
294
- raise ValueError(
295
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
296
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
297
- " the batch size of `prompt`."
298
- )
299
- else:
300
- uncond_tokens = negative_prompt
301
-
302
- # textual inversion: procecss multi-vector tokens if necessary
303
- if isinstance(self, TextualInversionLoaderMixin):
304
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
305
-
306
- max_length = prompt_embeds.shape[1]
307
- uncond_input = self.tokenizer(
308
- uncond_tokens,
309
- padding="max_length",
310
- max_length=max_length,
311
- truncation=True,
312
- return_tensors="pt",
313
- )
314
-
315
- if (
316
- hasattr(self.text_encoder.config, "use_attention_mask")
317
- and self.text_encoder.config.use_attention_mask
318
- ):
319
- attention_mask = uncond_input.attention_mask.to(device)
320
- else:
321
- attention_mask = None
322
-
323
- negative_prompt_embeds = self.text_encoder(
324
- uncond_input.input_ids.to(device),
325
- attention_mask=attention_mask,
326
- )
327
- negative_prompt_embeds = negative_prompt_embeds[0]
328
-
329
- if do_classifier_free_guidance:
330
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
331
- seq_len = negative_prompt_embeds.shape[1]
332
-
333
- negative_prompt_embeds = negative_prompt_embeds.to(
334
- dtype=self.text_encoder.dtype, device=device
335
- )
336
-
337
- negative_prompt_embeds = negative_prompt_embeds.repeat(
338
- 1, num_images_per_prompt, 1
339
- )
340
- negative_prompt_embeds = negative_prompt_embeds.view(
341
- batch_size * num_images_per_prompt, seq_len, -1
342
- )
343
-
344
- # For classifier free guidance, we need to do two forward passes.
345
- # Here we concatenate the unconditional and text embeddings into a single batch
346
- # to avoid doing two forward passes
347
- # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
348
- prompt_embeds = torch.cat(
349
- [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
350
- )
351
-
352
- return prompt_embeds
353
-
354
- def prepare_extra_step_kwargs(self, generator, eta):
355
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
356
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
357
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
358
- # and should be between [0, 1]
359
-
360
- accepts_eta = "eta" in set(
361
- inspect.signature(self.scheduler.step).parameters.keys()
362
- )
363
- extra_step_kwargs = {}
364
- if accepts_eta:
365
- extra_step_kwargs["eta"] = eta
366
-
367
- # check if the scheduler accepts generator
368
- accepts_generator = "generator" in set(
369
- inspect.signature(self.scheduler.step).parameters.keys()
370
- )
371
- if accepts_generator:
372
- extra_step_kwargs["generator"] = generator
373
- return extra_step_kwargs
374
-
375
- def check_inputs(
376
- self,
377
- prompt,
378
- callback_steps,
379
- negative_prompt=None,
380
- prompt_embeds=None,
381
- negative_prompt_embeds=None,
382
- ):
383
- if (callback_steps is None) or (
384
- callback_steps is not None
385
- and (not isinstance(callback_steps, int) or callback_steps <= 0)
386
- ):
387
- raise ValueError(
388
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
389
- f" {type(callback_steps)}."
390
- )
391
-
392
- if prompt is not None and prompt_embeds is not None:
393
- raise ValueError(
394
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
395
- " only forward one of the two."
396
- )
397
- elif prompt is None and prompt_embeds is None:
398
- raise ValueError(
399
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
400
- )
401
- elif prompt is not None and (
402
- not isinstance(prompt, str) and not isinstance(prompt, list)
403
- ):
404
- raise ValueError(
405
- f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
406
- )
407
-
408
- if negative_prompt is not None and negative_prompt_embeds is not None:
409
- raise ValueError(
410
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
411
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
412
- )
413
-
414
- if prompt_embeds is not None and negative_prompt_embeds is not None:
415
- if prompt_embeds.shape != negative_prompt_embeds.shape:
416
- raise ValueError(
417
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
418
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
419
- f" {negative_prompt_embeds.shape}."
420
- )
421
-
422
- def prepare_latents(
423
- self,
424
- batch_size,
425
- num_channels_latents,
426
- height,
427
- width,
428
- dtype,
429
- device,
430
- generator,
431
- latents=None,
432
- ):
433
- shape = (
434
- batch_size,
435
- num_channels_latents,
436
- height // self.vae_scale_factor,
437
- width // self.vae_scale_factor,
438
- )
439
- if isinstance(generator, list) and len(generator) != batch_size:
440
- raise ValueError(
441
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
442
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
443
- )
444
-
445
- if latents is None:
446
- latents = randn_tensor(
447
- shape, generator=generator, device=device, dtype=dtype
448
- )
449
- else:
450
- latents = latents.to(device)
451
-
452
- # scale the initial noise by the standard deviation required by the scheduler
453
- latents = latents * self.scheduler.init_noise_sigma
454
- return latents
455
-
456
- def prepare_image_latents(
457
- self,
458
- image,
459
- batch_size,
460
- num_images_per_prompt,
461
- dtype,
462
- device,
463
- do_classifier_free_guidance,
464
- generator=None,
465
- ):
466
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
467
- raise ValueError(
468
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
469
- )
470
-
471
- image = image.to(device=device, dtype=dtype)
472
-
473
- batch_size = batch_size * num_images_per_prompt
474
-
475
- if image.shape[1] == 4:
476
- image_latents = image
477
- else:
478
- if isinstance(generator, list) and len(generator) != batch_size:
479
- raise ValueError(
480
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
481
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
482
- )
483
-
484
- if isinstance(generator, list):
485
- image_latents = [
486
- self.vae.encode(image[i : i + 1]).latent_dist.mode()
487
- for i in range(batch_size)
488
- ]
489
- image_latents = torch.cat(image_latents, dim=0)
490
- else:
491
- image_latents = self.vae.encode(image).latent_dist.mode()
492
-
493
- if (
494
- batch_size > image_latents.shape[0]
495
- and batch_size % image_latents.shape[0] == 0
496
- ):
497
- # expand image_latents for batch_size
498
- deprecation_message = (
499
- f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
500
- " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
501
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
502
- " your script to pass as many initial images as text prompts to suppress this warning."
503
- )
504
- deprecate(
505
- "len(prompt) != len(image)",
506
- "1.0.0",
507
- deprecation_message,
508
- standard_warn=False,
509
- )
510
- additional_image_per_prompt = batch_size // image_latents.shape[0]
511
- image_latents = torch.cat(
512
- [image_latents] * additional_image_per_prompt, dim=0
513
- )
514
- elif (
515
- batch_size > image_latents.shape[0]
516
- and batch_size % image_latents.shape[0] != 0
517
- ):
518
- raise ValueError(
519
- f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
520
- )
521
- else:
522
- image_latents = torch.cat([image_latents], dim=0)
523
-
524
- if do_classifier_free_guidance:
525
- uncond_image_latents = torch.zeros_like(image_latents)
526
- image_latents = torch.cat(
527
- [image_latents, image_latents, uncond_image_latents], dim=0
528
- )
529
-
530
- return image_latents
531
-
532
- @torch.no_grad()
533
- def __call__(
534
- self,
535
- prompt: Union[str, List[str]] = None,
536
- photo: Union[
537
- torch.FloatTensor,
538
- PIL.Image.Image,
539
- np.ndarray,
540
- List[torch.FloatTensor],
541
- List[PIL.Image.Image],
542
- List[np.ndarray],
543
- ] = None,
544
- height: Optional[int] = None,
545
- width: Optional[int] = None,
546
- num_inference_steps: int = 100,
547
- required_aovs: List[str] = ["albedo"],
548
- negative_prompt: Optional[Union[str, List[str]]] = None,
549
- num_images_per_prompt: Optional[int] = 1,
550
- use_default_scaling_factor: Optional[bool] = False,
551
- guidance_scale: float = 0.0,
552
- image_guidance_scale: float = 0.0,
553
- guidance_rescale: float = 0.0,
554
- eta: float = 0.0,
555
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
556
- latents: Optional[torch.FloatTensor] = None,
557
- prompt_embeds: Optional[torch.FloatTensor] = None,
558
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
559
- output_type: Optional[str] = "pil",
560
- return_dict: bool = True,
561
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
562
- callback_steps: int = 1,
563
- ):
564
- r"""
565
- The call function to the pipeline for generation.
566
-
567
- Args:
568
- prompt (`str` or `List[str]`, *optional*):
569
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
570
- image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
571
- `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
572
- image latents as `image`, but if passing latents directly it is not encoded again.
573
- num_inference_steps (`int`, *optional*, defaults to 100):
574
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
575
- expense of slower inference.
576
- guidance_scale (`float`, *optional*, defaults to 7.5):
577
- A higher guidance scale value encourages the model to generate images closely linked to the text
578
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
579
- image_guidance_scale (`float`, *optional*, defaults to 1.5):
580
- Push the generated image towards the inital `image`. Image guidance scale is enabled by setting
581
- `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
582
- linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
583
- value of at least `1`.
584
- negative_prompt (`str` or `List[str]`, *optional*):
585
- The prompt or prompts to guide what to not include in image generation. If not defined, you need to
586
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
587
- num_images_per_prompt (`int`, *optional*, defaults to 1):
588
- The number of images to generate per prompt.
589
- eta (`float`, *optional*, defaults to 0.0):
590
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
591
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
592
- generator (`torch.Generator`, *optional*):
593
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
594
- generation deterministic.
595
- latents (`torch.FloatTensor`, *optional*):
596
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
597
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
598
- tensor is generated by sampling using the supplied random `generator`.
599
- prompt_embeds (`torch.FloatTensor`, *optional*):
600
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
601
- provided, text embeddings are generated from the `prompt` input argument.
602
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
603
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
604
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
605
- output_type (`str`, *optional*, defaults to `"pil"`):
606
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
607
- return_dict (`bool`, *optional*, defaults to `True`):
608
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
609
- plain tuple.
610
- callback (`Callable`, *optional*):
611
- A function that calls every `callback_steps` steps during inference. The function is called with the
612
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
613
- callback_steps (`int`, *optional*, defaults to 1):
614
- The frequency at which the `callback` function is called. If not specified, the callback is called at
615
- every step.
616
-
617
- Examples:
618
-
619
- ```py
620
- >>> import PIL
621
- >>> import requests
622
- >>> import torch
623
- >>> from io import BytesIO
624
-
625
- >>> from diffusers import StableDiffusionInstructPix2PixPipeline
626
-
627
-
628
- >>> def download_image(url):
629
- ... response = requests.get(url)
630
- ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
631
-
632
-
633
- >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
634
-
635
- >>> image = download_image(img_url).resize((512, 512))
636
-
637
- >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
638
- ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
639
- ... )
640
- >>> pipe = pipe.to("cuda")
641
-
642
- >>> prompt = "make the mountains snowy"
643
- >>> image = pipe(prompt=prompt, image=image).images[0]
644
- ```
645
-
646
- Returns:
647
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
648
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
649
- otherwise a `tuple` is returned where the first element is a list with the generated images and the
650
- second element is a list of `bool`s indicating whether the corresponding generated image contains
651
- "not-safe-for-work" (nsfw) content.
652
- """
653
- # 0. Check inputs
654
- self.check_inputs(
655
- prompt,
656
- callback_steps,
657
- negative_prompt,
658
- prompt_embeds,
659
- negative_prompt_embeds,
660
- )
661
-
662
- # 1. Define call parameters
663
- if prompt is not None and isinstance(prompt, str):
664
- batch_size = 1
665
- elif prompt is not None and isinstance(prompt, list):
666
- batch_size = len(prompt)
667
- else:
668
- batch_size = prompt_embeds.shape[0]
669
-
670
- device = self._execution_device
671
- do_classifier_free_guidance = (
672
- guidance_scale > 1.0 and image_guidance_scale >= 1.0
673
- )
674
- # check if scheduler is in sigmas space
675
- scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
676
-
677
- # 2. Encode input prompt
678
- prompt_embeds = self._encode_prompt(
679
- prompt,
680
- device,
681
- num_images_per_prompt,
682
- do_classifier_free_guidance,
683
- negative_prompt,
684
- prompt_embeds=prompt_embeds,
685
- negative_prompt_embeds=negative_prompt_embeds,
686
- )
687
-
688
- # 3. Preprocess image
689
- # Normalize image to [-1,1]
690
- preprocessed_photo = self.image_processor.preprocess(photo)
691
-
692
- # 4. set timesteps
693
- self.scheduler.set_timesteps(num_inference_steps, device=device)
694
- timesteps = self.scheduler.timesteps
695
-
696
- # 5. Prepare Image latents
697
- image_latents = self.prepare_image_latents(
698
- preprocessed_photo,
699
- batch_size,
700
- num_images_per_prompt,
701
- prompt_embeds.dtype,
702
- device,
703
- do_classifier_free_guidance,
704
- generator,
705
- )
706
- image_latents = image_latents * self.vae.config.scaling_factor
707
-
708
- height, width = image_latents.shape[-2:]
709
- height = height * self.vae_scale_factor
710
- width = width * self.vae_scale_factor
711
-
712
- # 6. Prepare latent variables
713
- num_channels_latents = self.unet.config.out_channels
714
- latents = self.prepare_latents(
715
- batch_size * num_images_per_prompt,
716
- num_channels_latents,
717
- height,
718
- width,
719
- prompt_embeds.dtype,
720
- device,
721
- generator,
722
- latents,
723
- )
724
-
725
- # 7. Check that shapes of latents and image match the UNet channels
726
- num_channels_image = image_latents.shape[1]
727
- if num_channels_latents + num_channels_image != self.unet.config.in_channels:
728
- raise ValueError(
729
- f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
730
- f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
731
- f" `num_channels_image`: {num_channels_image} "
732
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
733
- " `pipeline.unet` or your `image` input."
734
- )
735
-
736
- # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
737
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
738
-
739
- # 9. Denoising loop
740
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
741
- with self.progress_bar(total=num_inference_steps) as progress_bar:
742
- for i, t in enumerate(timesteps):
743
- # Expand the latents if we are doing classifier free guidance.
744
- # The latents are expanded 3 times because for pix2pix the guidance\
745
- # is applied for both the text and the input image.
746
- latent_model_input = (
747
- torch.cat([latents] * 3) if do_classifier_free_guidance else latents
748
- )
749
-
750
- # concat latents, image_latents in the channel dimension
751
- scaled_latent_model_input = self.scheduler.scale_model_input(
752
- latent_model_input, t
753
- )
754
- scaled_latent_model_input = torch.cat(
755
- [scaled_latent_model_input, image_latents], dim=1
756
- )
757
-
758
- # predict the noise residual
759
- noise_pred = self.unet(
760
- scaled_latent_model_input,
761
- t,
762
- encoder_hidden_states=prompt_embeds,
763
- return_dict=False,
764
- )[0]
765
-
766
- # perform guidance
767
- if do_classifier_free_guidance:
768
- (
769
- noise_pred_text,
770
- noise_pred_image,
771
- noise_pred_uncond,
772
- ) = noise_pred.chunk(3)
773
- noise_pred = (
774
- noise_pred_uncond
775
- + guidance_scale * (noise_pred_text - noise_pred_image)
776
- + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
777
- )
778
-
779
- if do_classifier_free_guidance and guidance_rescale > 0.0:
780
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
781
- noise_pred = rescale_noise_cfg(
782
- noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
783
- )
784
-
785
- # compute the previous noisy sample x_t -> x_t-1
786
- latents = self.scheduler.step(
787
- noise_pred, t, latents, **extra_step_kwargs, return_dict=False
788
- )[0]
789
-
790
- # call the callback, if provided
791
- if i == len(timesteps) - 1 or (
792
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
793
- ):
794
- progress_bar.update()
795
- if callback is not None and i % callback_steps == 0:
796
- callback(i, t, latents)
797
-
798
- aov_latents = latents / self.vae.config.scaling_factor
799
- aov = self.vae.decode(aov_latents, return_dict=False)[0]
800
- do_denormalize = [True] * aov.shape[0]
801
- aov_name = required_aovs[0]
802
- if aov_name == "albedo" or aov_name == "irradiance":
803
- do_gamma_correction = True
804
- else:
805
- do_gamma_correction = False
806
-
807
- if aov_name == "roughness" or aov_name == "metallic":
808
- aov = aov[:, 0:1].repeat(1, 3, 1, 1)
809
-
810
- aov = self.image_processor.postprocess(
811
- aov,
812
- output_type=output_type,
813
- do_denormalize=do_denormalize,
814
- do_gamma_correction=do_gamma_correction,
815
- )
816
- aovs = [aov]
817
-
818
- # Offload last model to CPU
819
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
820
- self.final_offload_hook.offload()
821
- return StableDiffusionAOVPipelineOutput(images=aovs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ CONDA_ENV=$(head -1 /code/environment.yml | cut -d" " -f2)
3
+ eval "$(conda shell.bash hook)"
4
+ conda activate $CONDA_ENV
5
+ python app.py
settings.py DELETED
@@ -1,23 +0,0 @@
1
- import os
2
-
3
- import numpy as np
4
-
5
- DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
6
-
7
- MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))
8
- DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "1")))
9
- MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "2048"))
10
- DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "1024")))
11
-
12
- ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
13
- SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
14
-
15
- MAX_SEED = np.iinfo(np.int32).max
16
-
17
- # Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
18
-
19
- # setup CUDA
20
- # disable the following when deployting to hugging face
21
- # if os.getenv("CUDA_VISIBLE_DEVICES") is None:
22
- # os.environ["CUDA_VISIBLE_DEVICES"] = "7"
23
- # os.environ["GRADIO_SERVER_PORT"] = "7864"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/__init__.py DELETED
File without changes
text2tex/lib/camera_helper.py DELETED
@@ -1,231 +0,0 @@
1
- import torch
2
-
3
- import numpy as np
4
-
5
- from sklearn.metrics.pairwise import cosine_similarity
6
-
7
- from pytorch3d.renderer import (
8
- PerspectiveCameras,
9
- look_at_view_transform
10
- )
11
-
12
- # customized
13
- import sys
14
- sys.path.append(".")
15
-
16
- from lib.constants import VIEWPOINTS
17
-
18
- # ---------------- UTILS ----------------------
19
-
20
- def degree_to_radian(d):
21
- return d * np.pi / 180
22
-
23
- def radian_to_degree(r):
24
- return 180 * r / np.pi
25
-
26
- def xyz_to_polar(xyz):
27
- """ assume y-axis is the up axis """
28
-
29
- x, y, z = xyz
30
-
31
- theta = 180 * np.arccos(z) / np.pi
32
- phi = 180 * np.arccos(y) / np.pi
33
-
34
- return theta, phi
35
-
36
- def polar_to_xyz(theta, phi, dist):
37
- """ assume y-axis is the up axis """
38
-
39
- theta = degree_to_radian(theta)
40
- phi = degree_to_radian(phi)
41
-
42
- x = np.sin(phi) * np.sin(theta) * dist
43
- y = np.cos(phi) * dist
44
- z = np.sin(phi) * np.cos(theta) * dist
45
-
46
- return [x, y, z]
47
-
48
-
49
- # ---------------- VIEWPOINTS ----------------------
50
-
51
-
52
- def filter_viewpoints(pre_viewpoints: dict, viewpoints: dict):
53
- """ return the binary mask of viewpoints to be filtered """
54
-
55
- filter_mask = [0 for _ in viewpoints.keys()]
56
- for i, v in viewpoints.items():
57
- x_v, y_v, z_v = polar_to_xyz(v["azim"], 90 - v["elev"], v["dist"])
58
-
59
- for _, pv in pre_viewpoints.items():
60
- x_pv, y_pv, z_pv = polar_to_xyz(pv["azim"], 90 - pv["elev"], pv["dist"])
61
- sim = cosine_similarity(
62
- np.array([[x_v, y_v, z_v]]),
63
- np.array([[x_pv, y_pv, z_pv]])
64
- )[0, 0]
65
-
66
- if sim > 0.9:
67
- filter_mask[i] = 1
68
-
69
- return filter_mask
70
-
71
-
72
- def init_viewpoints(mode, sample_space, init_dist, init_elev, principle_directions,
73
- use_principle=True, use_shapenet=False, use_objaverse=False):
74
-
75
- if mode == "predefined":
76
-
77
- (
78
- dist_list,
79
- elev_list,
80
- azim_list,
81
- sector_list
82
- ) = init_predefined_viewpoints(sample_space, init_dist, init_elev)
83
-
84
- elif mode == "hemisphere":
85
-
86
- (
87
- dist_list,
88
- elev_list,
89
- azim_list,
90
- sector_list
91
- ) = init_hemisphere_viewpoints(sample_space, init_dist)
92
-
93
- else:
94
- raise NotImplementedError()
95
-
96
- # punishments for views -> in case always selecting the same view
97
- view_punishments = [1 for _ in range(len(dist_list))]
98
-
99
- if use_principle:
100
-
101
- (
102
- dist_list,
103
- elev_list,
104
- azim_list,
105
- sector_list,
106
- view_punishments
107
- ) = init_principle_viewpoints(
108
- principle_directions,
109
- dist_list,
110
- elev_list,
111
- azim_list,
112
- sector_list,
113
- view_punishments,
114
- use_shapenet,
115
- use_objaverse
116
- )
117
-
118
- return dist_list, elev_list, azim_list, sector_list, view_punishments
119
-
120
-
121
- def init_principle_viewpoints(
122
- principle_directions,
123
- dist_list,
124
- elev_list,
125
- azim_list,
126
- sector_list,
127
- view_punishments,
128
- use_shapenet=False,
129
- use_objaverse=False
130
- ):
131
-
132
- if use_shapenet:
133
- key = "shapenet"
134
-
135
- pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
136
- pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
137
- pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]
138
-
139
- num_principle = 10
140
- pre_dist_list = [dist_list[0] for _ in range(num_principle)]
141
- pre_view_punishments = [0 for _ in range(num_principle)]
142
-
143
- elif use_objaverse:
144
- key = "objaverse"
145
-
146
- pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
147
- pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
148
- pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]
149
-
150
- num_principle = 10
151
- pre_dist_list = [dist_list[0] for _ in range(num_principle)]
152
- pre_view_punishments = [0 for _ in range(num_principle)]
153
- else:
154
- num_principle = 6
155
- pre_elev_list = [v for v in VIEWPOINTS[num_principle]["elev"]]
156
- pre_azim_list = [v for v in VIEWPOINTS[num_principle]["azim"]]
157
- pre_sector_list = [v for v in VIEWPOINTS[num_principle]["sector"]]
158
- pre_dist_list = [dist_list[0] for _ in range(num_principle)]
159
- pre_view_punishments = [0 for _ in range(num_principle)]
160
-
161
- dist_list = pre_dist_list + dist_list
162
- elev_list = pre_elev_list + elev_list
163
- azim_list = pre_azim_list + azim_list
164
- sector_list = pre_sector_list + sector_list
165
- view_punishments = pre_view_punishments + view_punishments
166
-
167
- return dist_list, elev_list, azim_list, sector_list, view_punishments
168
-
169
-
170
- def init_predefined_viewpoints(sample_space, init_dist, init_elev):
171
-
172
- viewpoints = VIEWPOINTS[sample_space]
173
-
174
- assert sample_space == len(viewpoints["sector"])
175
-
176
- dist_list = [init_dist for _ in range(sample_space)] # always the same dist
177
- elev_list = [viewpoints["elev"][i] for i in range(sample_space)]
178
- azim_list = [viewpoints["azim"][i] for i in range(sample_space)]
179
- sector_list = [viewpoints["sector"][i] for i in range(sample_space)]
180
-
181
- return dist_list, elev_list, azim_list, sector_list
182
-
183
-
184
- def init_hemisphere_viewpoints(sample_space, init_dist):
185
- """
186
- y is up-axis
187
- """
188
-
189
- num_points = 2 * sample_space
190
- ga = np.pi * (3. - np.sqrt(5.)) # golden angle in radians
191
-
192
- flags = []
193
- elev_list = [] # degree
194
- azim_list = [] # degree
195
-
196
- for i in range(num_points):
197
- y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1
198
-
199
- # only take the north hemisphere
200
- if y >= 0:
201
- flags.append(True)
202
- else:
203
- flags.append(False)
204
-
205
- theta = ga * i # golden angle increment
206
-
207
- elev_list.append(radian_to_degree(np.arcsin(y)))
208
- azim_list.append(radian_to_degree(theta))
209
-
210
- radius = np.sqrt(1 - y * y) # radius at y
211
- x = np.cos(theta) * radius
212
- z = np.sin(theta) * radius
213
-
214
- elev_list = [elev_list[i] for i in range(len(elev_list)) if flags[i]]
215
- azim_list = [azim_list[i] for i in range(len(azim_list)) if flags[i]]
216
-
217
- dist_list = [init_dist for _ in elev_list]
218
- sector_list = ["good" for _ in elev_list] # HACK don't define sector names for now
219
-
220
- return dist_list, elev_list, azim_list, sector_list
221
-
222
-
223
- # ---------------- CAMERAS ----------------------
224
-
225
-
226
- def init_camera(dist, elev, azim, image_size, device):
227
- R, T = look_at_view_transform(dist, elev, azim)
228
- image_size = torch.tensor([image_size, image_size]).unsqueeze(0)
229
- cameras = PerspectiveCameras(R=R, T=T, device=device, image_size=image_size)
230
-
231
- return cameras
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/constants.py DELETED
@@ -1,648 +0,0 @@
1
- PALETTE = {
2
- 0: [255, 255, 255], # white - background
3
- 1: [204, 50, 50], # red - old
4
- 2: [231, 180, 22], # yellow - update
5
- 3: [45, 201, 55] # green - new
6
- }
7
-
8
- QUAD_WEIGHTS = {
9
- 0: 0, # background
10
- 1: 0.1, # old
11
- 2: 0.5, # update
12
- 3: 1 # new
13
- }
14
-
15
- VIEWPOINTS = {
16
- 1: {
17
- "azim": [
18
- 0
19
- ],
20
- "elev": [
21
- 0
22
- ],
23
- "sector": [
24
- "front"
25
- ]
26
- },
27
- 2: {
28
- "azim": [
29
- 0,
30
- 30
31
- ],
32
- "elev": [
33
- 0,
34
- 0
35
- ],
36
- "sector": [
37
- "front",
38
- "front"
39
- ]
40
- },
41
- 4: {
42
- "azim": [
43
- 45,
44
- 315,
45
- 135,
46
- 225,
47
- ],
48
- "elev": [
49
- 0,
50
- 0,
51
- 0,
52
- 0,
53
- ],
54
- "sector": [
55
- "front right",
56
- "front left",
57
- "back right",
58
- "back left",
59
- ]
60
- },
61
- 6: {
62
- "azim": [
63
- 0,
64
- 90,
65
- 270,
66
- 0,
67
- 180,
68
- 0
69
- ],
70
- "elev": [
71
- 0,
72
- 0,
73
- 0,
74
- 90,
75
- 0,
76
- -90
77
- ],
78
- "sector": [
79
- "front",
80
- "right",
81
- "left",
82
- "top",
83
- "back",
84
- "bottom",
85
- ]
86
- },
87
- "shapenet": {
88
- "azim": [
89
- 270,
90
- 315,
91
- 225,
92
- 0,
93
- 180,
94
- 45,
95
- 135,
96
- 90,
97
- 270,
98
- 270
99
- ],
100
- "elev": [
101
- 15,
102
- 15,
103
- 15,
104
- 15,
105
- 15,
106
- 15,
107
- 15,
108
- 15,
109
- 90,
110
- -90
111
- ],
112
- "sector": [
113
- "front",
114
- "front right",
115
- "front left",
116
- "right",
117
- "left",
118
- "back right",
119
- "back left",
120
- "back",
121
- "top",
122
- "bottom",
123
- ]
124
- },
125
- "objaverse": {
126
- "azim": [
127
- 0,
128
- 45,
129
- 315,
130
- 90,
131
- 270,
132
- 135,
133
- 225,
134
- 180,
135
- 0,
136
- 0
137
- ],
138
- "elev": [
139
- 15,
140
- 15,
141
- 15,
142
- 15,
143
- 15,
144
- 15,
145
- 15,
146
- 15,
147
- 90,
148
- -90
149
- ],
150
- "sector": [
151
- "front",
152
- "front right",
153
- "front left",
154
- "right",
155
- "left",
156
- "back right",
157
- "back left",
158
- "back",
159
- "top",
160
- "bottom",
161
- ]
162
- },
163
- 12: {
164
- "azim": [
165
- 45,
166
- 315,
167
- 135,
168
- 225,
169
-
170
- 0,
171
- 45,
172
- 315,
173
- 90,
174
- 270,
175
- 135,
176
- 225,
177
- 180,
178
- ],
179
- "elev": [
180
- 0,
181
- 0,
182
- 0,
183
- 0,
184
-
185
- 45,
186
- 45,
187
- 45,
188
- 45,
189
- 45,
190
- 45,
191
- 45,
192
- 45,
193
- ],
194
- "sector": [
195
- "front right",
196
- "front left",
197
- "back right",
198
- "back left",
199
-
200
- "front",
201
- "front right",
202
- "front left",
203
- "right",
204
- "left",
205
- "back right",
206
- "back left",
207
- "back",
208
- ]
209
- },
210
- 20: {
211
- "azim": [
212
- 45,
213
- 315,
214
- 135,
215
- 225,
216
-
217
- 0,
218
- 45,
219
- 315,
220
- 90,
221
- 270,
222
- 135,
223
- 225,
224
- 180,
225
-
226
- 0,
227
- 45,
228
- 315,
229
- 90,
230
- 270,
231
- 135,
232
- 225,
233
- 180,
234
- ],
235
- "elev": [
236
- 0,
237
- 0,
238
- 0,
239
- 0,
240
-
241
- 30,
242
- 30,
243
- 30,
244
- 30,
245
- 30,
246
- 30,
247
- 30,
248
- 30,
249
-
250
- 60,
251
- 60,
252
- 60,
253
- 60,
254
- 60,
255
- 60,
256
- 60,
257
- 60,
258
- ],
259
- "sector": [
260
- "front right",
261
- "front left",
262
- "back right",
263
- "back left",
264
-
265
- "front",
266
- "front right",
267
- "front left",
268
- "right",
269
- "left",
270
- "back right",
271
- "back left",
272
- "back",
273
-
274
- "front",
275
- "front right",
276
- "front left",
277
- "right",
278
- "left",
279
- "back right",
280
- "back left",
281
- "back",
282
- ]
283
- },
284
- 36: {
285
- "azim": [
286
- 45,
287
- 315,
288
- 135,
289
- 225,
290
-
291
- 0,
292
- 45,
293
- 315,
294
- 90,
295
- 270,
296
- 135,
297
- 225,
298
- 180,
299
-
300
- 0,
301
- 45,
302
- 315,
303
- 90,
304
- 270,
305
- 135,
306
- 225,
307
- 180,
308
-
309
- 22.5,
310
- 337.5,
311
- 67.5,
312
- 292.5,
313
- 112.5,
314
- 247.5,
315
- 157.5,
316
- 202.5,
317
-
318
- 22.5,
319
- 337.5,
320
- 67.5,
321
- 292.5,
322
- 112.5,
323
- 247.5,
324
- 157.5,
325
- 202.5,
326
- ],
327
- "elev": [
328
- 0,
329
- 0,
330
- 0,
331
- 0,
332
-
333
- 30,
334
- 30,
335
- 30,
336
- 30,
337
- 30,
338
- 30,
339
- 30,
340
- 30,
341
-
342
- 60,
343
- 60,
344
- 60,
345
- 60,
346
- 60,
347
- 60,
348
- 60,
349
- 60,
350
-
351
- 15,
352
- 15,
353
- 15,
354
- 15,
355
- 15,
356
- 15,
357
- 15,
358
- 15,
359
-
360
- 45,
361
- 45,
362
- 45,
363
- 45,
364
- 45,
365
- 45,
366
- 45,
367
- 45,
368
- ],
369
- "sector": [
370
- "front right",
371
- "front left",
372
- "back right",
373
- "back left",
374
-
375
- "front",
376
- "front right",
377
- "front left",
378
- "right",
379
- "left",
380
- "back right",
381
- "back left",
382
- "back",
383
-
384
- "top front",
385
- "top right",
386
- "top left",
387
- "top right",
388
- "top left",
389
- "top right",
390
- "top left",
391
- "top back",
392
-
393
- "front right",
394
- "front left",
395
- "front right",
396
- "front left",
397
- "back right",
398
- "back left",
399
- "back right",
400
- "back left",
401
-
402
- "front right",
403
- "front left",
404
- "front right",
405
- "front left",
406
- "back right",
407
- "back left",
408
- "back right",
409
- "back left",
410
- ]
411
- },
412
- 68: {
413
- "azim": [
414
- 45,
415
- 315,
416
- 135,
417
- 225,
418
-
419
- 0,
420
- 45,
421
- 315,
422
- 90,
423
- 270,
424
- 135,
425
- 225,
426
- 180,
427
-
428
- 0,
429
- 45,
430
- 315,
431
- 90,
432
- 270,
433
- 135,
434
- 225,
435
- 180,
436
-
437
- 22.5,
438
- 337.5,
439
- 67.5,
440
- 292.5,
441
- 112.5,
442
- 247.5,
443
- 157.5,
444
- 202.5,
445
-
446
- 22.5,
447
- 337.5,
448
- 67.5,
449
- 292.5,
450
- 112.5,
451
- 247.5,
452
- 157.5,
453
- 202.5,
454
-
455
- 0,
456
- 45,
457
- 315,
458
- 90,
459
- 270,
460
- 135,
461
- 225,
462
- 180,
463
-
464
- 0,
465
- 45,
466
- 315,
467
- 90,
468
- 270,
469
- 135,
470
- 225,
471
- 180,
472
-
473
- 22.5,
474
- 337.5,
475
- 67.5,
476
- 292.5,
477
- 112.5,
478
- 247.5,
479
- 157.5,
480
- 202.5,
481
-
482
- 22.5,
483
- 337.5,
484
- 67.5,
485
- 292.5,
486
- 112.5,
487
- 247.5,
488
- 157.5,
489
- 202.5
490
- ],
491
- "elev": [
492
- 0,
493
- 0,
494
- 0,
495
- 0,
496
-
497
- 30,
498
- 30,
499
- 30,
500
- 30,
501
- 30,
502
- 30,
503
- 30,
504
- 30,
505
-
506
- 60,
507
- 60,
508
- 60,
509
- 60,
510
- 60,
511
- 60,
512
- 60,
513
- 60,
514
-
515
- 15,
516
- 15,
517
- 15,
518
- 15,
519
- 15,
520
- 15,
521
- 15,
522
- 15,
523
-
524
- 45,
525
- 45,
526
- 45,
527
- 45,
528
- 45,
529
- 45,
530
- 45,
531
- 45,
532
-
533
- -30,
534
- -30,
535
- -30,
536
- -30,
537
- -30,
538
- -30,
539
- -30,
540
- -30,
541
-
542
- -60,
543
- -60,
544
- -60,
545
- -60,
546
- -60,
547
- -60,
548
- -60,
549
- -60,
550
-
551
- -15,
552
- -15,
553
- -15,
554
- -15,
555
- -15,
556
- -15,
557
- -15,
558
- -15,
559
-
560
- -45,
561
- -45,
562
- -45,
563
- -45,
564
- -45,
565
- -45,
566
- -45,
567
- -45,
568
- ],
569
- "sector": [
570
- "front right",
571
- "front left",
572
- "back right",
573
- "back left",
574
-
575
- "front",
576
- "front right",
577
- "front left",
578
- "right",
579
- "left",
580
- "back right",
581
- "back left",
582
- "back",
583
-
584
- "top front",
585
- "top right",
586
- "top left",
587
- "top right",
588
- "top left",
589
- "top right",
590
- "top left",
591
- "top back",
592
-
593
- "front right",
594
- "front left",
595
- "front right",
596
- "front left",
597
- "back right",
598
- "back left",
599
- "back right",
600
- "back left",
601
-
602
- "front right",
603
- "front left",
604
- "front right",
605
- "front left",
606
- "back right",
607
- "back left",
608
- "back right",
609
- "back left",
610
-
611
- "front",
612
- "front right",
613
- "front left",
614
- "right",
615
- "left",
616
- "back right",
617
- "back left",
618
- "back",
619
-
620
- "bottom front",
621
- "bottom right",
622
- "bottom left",
623
- "bottom right",
624
- "bottom left",
625
- "bottom right",
626
- "bottom left",
627
- "bottom back",
628
-
629
- "bottom front right",
630
- "bottom front left",
631
- "bottom front right",
632
- "bottom front left",
633
- "bottom back right",
634
- "bottom back left",
635
- "bottom back right",
636
- "bottom back left",
637
-
638
- "bottom front right",
639
- "bottom front left",
640
- "bottom front right",
641
- "bottom front left",
642
- "bottom back right",
643
- "bottom back left",
644
- "bottom back right",
645
- "bottom back left",
646
- ]
647
- }
648
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/diffusion_helper.py DELETED
@@ -1,189 +0,0 @@
1
- import torch
2
-
3
- import cv2
4
- import numpy as np
5
-
6
- from PIL import Image
7
- from torchvision import transforms
8
-
9
- # Stable Diffusion 2
10
- from diffusers import (
11
- StableDiffusionInpaintPipeline,
12
- StableDiffusionPipeline,
13
- EulerDiscreteScheduler
14
- )
15
-
16
- # customized
17
- import sys
18
- sys.path.append(".")
19
-
20
- from models.ControlNet.gradio_depth2image import init_model, process
21
-
22
-
23
- def get_controlnet_depth():
24
- print("=> initializing ControlNet Depth...")
25
- model, ddim_sampler = init_model()
26
-
27
- return model, ddim_sampler
28
-
29
-
30
- def get_inpainting(device):
31
- print("=> initializing Inpainting...")
32
-
33
- model = StableDiffusionInpaintPipeline.from_pretrained(
34
- "stabilityai/stable-diffusion-2-inpainting",
35
- torch_dtype=torch.float16,
36
- ).to(device)
37
-
38
- return model
39
-
40
- def get_text2image(device):
41
- print("=> initializing Inpainting...")
42
-
43
- model_id = "stabilityai/stable-diffusion-2"
44
- scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
45
- model = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16).to(device)
46
-
47
- return model
48
-
49
-
50
- @torch.no_grad()
51
- def apply_controlnet_depth(model, ddim_sampler,
52
- init_image, prompt, strength, ddim_steps,
53
- generate_mask_image, keep_mask_image, depth_map_np,
54
- a_prompt, n_prompt, guidance_scale, seed, eta, num_samples,
55
- device, blend=0, save_memory=False):
56
- """
57
- Use Stable Diffusion 2 to generate image
58
-
59
- Arguments:
60
- args: input arguments
61
- model: Stable Diffusion 2 model
62
- init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
63
- mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
64
- depth_map_np: depth map of the input image, torch.FloatTensor of shape (1, H, W)
65
- """
66
-
67
- print("=> generating ControlNet Depth RePaint image...")
68
-
69
-
70
- # Stable Diffusion 2 receives PIL.Image
71
- # NOTE Stable Diffusion 2 returns a PIL.Image object
72
- # image and mask_image should be PIL images.
73
- # The mask structure is white for inpainting and black for keeping as is
74
- diffused_image_np = process(
75
- model, ddim_sampler,
76
- np.array(init_image), prompt, a_prompt, n_prompt, num_samples,
77
- ddim_steps, guidance_scale, seed, eta,
78
- strength=strength, detected_map=depth_map_np, unknown_mask=np.array(generate_mask_image), save_memory=save_memory
79
- )[0]
80
-
81
- init_image = init_image.convert("RGB")
82
- diffused_image = Image.fromarray(diffused_image_np).convert("RGB")
83
-
84
- if blend > 0 and transforms.ToTensor()(keep_mask_image).sum() > 0:
85
- print("=> blending the generated region...")
86
- kernel_size = 3
87
- kernel = np.ones((kernel_size, kernel_size), np.uint8)
88
-
89
- keep_image_np = np.array(init_image).astype(np.uint8)
90
- keep_image_np_dilate = cv2.dilate(keep_image_np, kernel, iterations=1)
91
-
92
- keep_mask_np = np.array(keep_mask_image).astype(np.uint8)
93
- keep_mask_np_dilate = cv2.dilate(keep_mask_np, kernel, iterations=1)
94
-
95
- generate_image_np = np.array(diffused_image).astype(np.uint8)
96
-
97
- overlap_mask_np = np.array(generate_mask_image).astype(np.uint8)
98
- overlap_mask_np *= keep_mask_np_dilate
99
- print("=> blending {} pixels...".format(np.sum(overlap_mask_np)))
100
-
101
- overlap_keep = keep_image_np_dilate[overlap_mask_np == 1]
102
- overlap_generate = generate_image_np[overlap_mask_np == 1]
103
-
104
- overlap_np = overlap_keep * blend + overlap_generate * (1 - blend)
105
-
106
- generate_image_np[overlap_mask_np == 1] = overlap_np
107
-
108
- diffused_image = Image.fromarray(generate_image_np.astype(np.uint8)).convert("RGB")
109
-
110
- init_image_masked = init_image
111
- diffused_image_masked = diffused_image
112
-
113
- return diffused_image, init_image_masked, diffused_image_masked
114
-
115
-
116
- @torch.no_grad()
117
- def apply_inpainting(model,
118
- init_image, mask_image_tensor, prompt, height, width, device):
119
- """
120
- Use Stable Diffusion 2 to generate image
121
-
122
- Arguments:
123
- args: input arguments
124
- model: Stable Diffusion 2 model
125
- init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
126
- mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
127
- depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
128
- """
129
-
130
- print("=> generating Inpainting image...")
131
-
132
- mask_image = mask_image_tensor[0].cpu()
133
- mask_image = mask_image.permute(2, 0, 1)
134
- mask_image = transforms.ToPILImage()(mask_image).convert("L")
135
-
136
- # NOTE Stable Diffusion 2 returns a PIL.Image object
137
- # image and mask_image should be PIL images.
138
- # The mask structure is white for inpainting and black for keeping as is
139
- diffused_image = model(
140
- prompt=prompt,
141
- image=init_image.resize((512, 512)),
142
- mask_image=mask_image.resize((512, 512)),
143
- height=512,
144
- width=512
145
- ).images[0].resize((height, width))
146
-
147
- return diffused_image
148
-
149
-
150
- @torch.no_grad()
151
- def apply_inpainting_postprocess(model,
152
- init_image, mask_image_tensor, prompt, height, width, device):
153
- """
154
- Use Stable Diffusion 2 to generate image
155
-
156
- Arguments:
157
- args: input arguments
158
- model: Stable Diffusion 2 model
159
- init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
160
- mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
161
- depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
162
- """
163
-
164
- print("=> generating Inpainting image...")
165
-
166
- mask_image = mask_image_tensor[0].cpu()
167
- mask_image = mask_image.permute(2, 0, 1)
168
- mask_image = transforms.ToPILImage()(mask_image).convert("L")
169
-
170
- # NOTE Stable Diffusion 2 returns a PIL.Image object
171
- # image and mask_image should be PIL images.
172
- # The mask structure is white for inpainting and black for keeping as is
173
- diffused_image = model(
174
- prompt=prompt,
175
- image=init_image.resize((512, 512)),
176
- mask_image=mask_image.resize((512, 512)),
177
- height=512,
178
- width=512
179
- ).images[0].resize((height, width))
180
-
181
- diffused_image_tensor = torch.from_numpy(np.array(diffused_image)).to(device)
182
-
183
- init_images_tensor = torch.from_numpy(np.array(init_image)).to(device)
184
-
185
- init_images_tensor = diffused_image_tensor * mask_image_tensor[0] + init_images_tensor * (1 - mask_image_tensor[0])
186
- init_image = Image.fromarray(init_images_tensor.cpu().numpy().astype(np.uint8)).convert("RGB")
187
-
188
- return init_image
189
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/io_helper.py DELETED
@@ -1,78 +0,0 @@
1
- # common utils
2
- import os
3
- import json
4
-
5
- # numpy
6
- import numpy as np
7
-
8
- # visualization
9
- import matplotlib
10
- import matplotlib.cm as cm
11
- import matplotlib.pyplot as plt
12
-
13
- matplotlib.use("Agg")
14
-
15
- from pytorch3d.io import save_obj
16
-
17
- from torchvision import transforms
18
-
19
-
20
- def save_depth(fragments, output_dir, init_image, view_idx):
21
- print("=> saving depth...")
22
- width, height = init_image.size
23
- dpi = 100
24
- figsize = width / float(dpi), height / float(dpi)
25
-
26
- depth_np = fragments.zbuf[0].cpu().numpy()
27
-
28
- fig = plt.figure(figsize=figsize)
29
- ax = fig.add_axes([0, 0, 1, 1])
30
- # Hide spines, ticks, etc.
31
- ax.axis('off')
32
- # Display the image.
33
- ax.imshow(depth_np, cmap='gray')
34
-
35
- plt.savefig(os.path.join(output_dir, "{}.png".format(view_idx)), bbox_inches='tight', pad_inches=0)
36
- np.save(os.path.join(output_dir, "{}.npy".format(view_idx)), depth_np[..., 0])
37
-
38
-
39
- def save_backproject_obj(output_dir, obj_name,
40
- verts, faces, verts_uvs, faces_uvs, projected_texture,
41
- device):
42
- print("=> saving OBJ file...")
43
- texture_map = transforms.ToTensor()(projected_texture).to(device)
44
- texture_map = texture_map.permute(1, 2, 0)
45
- obj_path = os.path.join(output_dir, obj_name)
46
-
47
- save_obj(
48
- obj_path,
49
- verts=verts,
50
- faces=faces,
51
- decimal_places=5,
52
- verts_uvs=verts_uvs,
53
- faces_uvs=faces_uvs,
54
- texture_map=texture_map
55
- )
56
-
57
-
58
- def save_args(args, output_dir):
59
- with open(os.path.join(output_dir, "args.json"), "w") as f:
60
- json.dump(
61
- {k: v for k, v in vars(args).items()},
62
- f,
63
- indent=4
64
- )
65
-
66
-
67
- def save_viewpoints(args, output_dir, dist_list, elev_list, azim_list, view_list):
68
- with open(os.path.join(output_dir, "viewpoints.json"), "w") as f:
69
- json.dump(
70
- {
71
- "dist": dist_list,
72
- "elev": elev_list,
73
- "azim": azim_list,
74
- "view": view_list
75
- },
76
- f,
77
- indent=4
78
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/mesh_helper.py DELETED
@@ -1,148 +0,0 @@
1
- import os
2
- import torch
3
- import trimesh
4
- import xatlas
5
-
6
- import numpy as np
7
-
8
- from sklearn.decomposition import PCA
9
-
10
- from torchvision import transforms
11
-
12
- from tqdm import tqdm
13
-
14
- from pytorch3d.io import (
15
- load_obj,
16
- load_objs_as_meshes
17
- )
18
-
19
-
20
- def compute_principle_directions(model_path, num_points=20000):
21
- mesh = trimesh.load_mesh(model_path, force="mesh")
22
- pc, _ = trimesh.sample.sample_surface_even(mesh, num_points)
23
-
24
- pc -= np.mean(pc, axis=0, keepdims=True)
25
-
26
- principle_directions = PCA(n_components=3).fit(pc).components_
27
-
28
- return principle_directions
29
-
30
-
31
- def init_mesh(input_path, cache_path, device):
32
- print("=> parameterizing target mesh...")
33
-
34
- mesh = trimesh.load_mesh(input_path, force='mesh')
35
- try:
36
- vertices, faces = mesh.vertices, mesh.faces
37
- except AttributeError:
38
- print("multiple materials in {} are not supported".format(input_path))
39
- exit()
40
-
41
- vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
42
- xatlas.export(str(cache_path), vertices[vmapping], indices, uvs)
43
-
44
- print("=> loading target mesh...")
45
-
46
- # principle_directions = compute_principle_directions(cache_path)
47
- principle_directions = None
48
-
49
- _, faces, aux = load_obj(cache_path, device=device)
50
- mesh = load_objs_as_meshes([cache_path], device=device)
51
-
52
- num_verts = mesh.verts_packed().shape[0]
53
-
54
- # make sure mesh center is at origin
55
- bbox = mesh.get_bounding_boxes()
56
- mesh_center = bbox.mean(dim=2).repeat(num_verts, 1)
57
- mesh = apply_offsets_to_mesh(mesh, -mesh_center)
58
-
59
- # make sure mesh size is normalized
60
- box_size = bbox[..., 1] - bbox[..., 0]
61
- box_max = box_size.max(dim=1, keepdim=True)[0].repeat(num_verts, 3)
62
- mesh = apply_scale_to_mesh(mesh, 1 / box_max)
63
-
64
- return mesh, mesh.verts_packed(), faces, aux, principle_directions, mesh_center, box_max
65
-
66
-
67
- def apply_offsets_to_mesh(mesh, offsets):
68
- new_mesh = mesh.offset_verts(offsets)
69
-
70
- return new_mesh
71
-
72
- def apply_scale_to_mesh(mesh, scale):
73
- new_mesh = mesh.scale_verts(scale)
74
-
75
- return new_mesh
76
-
77
-
78
- def adjust_uv_map(faces, aux, init_texture, uv_size):
79
- """
80
- adjust UV map to be compatiable with multiple textures.
81
- UVs for different materials will be decomposed and placed horizontally
82
-
83
- +-----+-----+-----+--
84
- | 1 | 2 | 3 |
85
- +-----+-----+-----+--
86
-
87
- """
88
-
89
- textures_ids = faces.textures_idx
90
- materials_idx = faces.materials_idx
91
- verts_uvs = aux.verts_uvs
92
-
93
- num_materials = torch.unique(materials_idx).shape[0]
94
-
95
- new_verts_uvs = verts_uvs.clone()
96
- for material_id in range(num_materials):
97
- # apply offsets to horizontal axis
98
- faces_ids = textures_ids[materials_idx == material_id].unique()
99
- new_verts_uvs[faces_ids, 0] += material_id
100
-
101
- new_verts_uvs[:, 0] /= num_materials
102
-
103
- init_texture_tensor = transforms.ToTensor()(init_texture)
104
- init_texture_tensor = torch.cat([init_texture_tensor for _ in range(num_materials)], dim=-1)
105
- init_texture = transforms.ToPILImage()(init_texture_tensor).resize((uv_size, uv_size))
106
-
107
- return new_verts_uvs, init_texture
108
-
109
-
110
- @torch.no_grad()
111
- def update_face_angles(mesh, cameras, fragments):
112
- def get_angle(x, y):
113
- x = torch.nn.functional.normalize(x)
114
- y = torch.nn.functional.normalize(y)
115
- inner_product = (x * y).sum(dim=1)
116
- x_norm = x.pow(2).sum(dim=1).pow(0.5)
117
- y_norm = y.pow(2).sum(dim=1).pow(0.5)
118
- cos = inner_product / (x_norm * y_norm)
119
- angle = torch.acos(cos)
120
- angle = angle * 180 / 3.14159
121
-
122
- return angle
123
-
124
- # face normals
125
- face_normals = mesh.faces_normals_padded()[0]
126
-
127
- # view vector (object center -> camera center)
128
- camera_center = cameras.get_camera_center()
129
-
130
- face_angles = get_angle(
131
- face_normals,
132
- camera_center.repeat(face_normals.shape[0], 1)
133
- ) # (F)
134
-
135
- face_angles_rev = get_angle(
136
- face_normals,
137
- -camera_center.repeat(face_normals.shape[0], 1)
138
- ) # (F)
139
-
140
- face_angles = torch.minimum(face_angles, face_angles_rev)
141
-
142
- # Indices of unique visible faces
143
- visible_map = fragments.pix_to_face.unique() # (num_visible_faces)
144
- invisible_mask = torch.ones_like(face_angles)
145
- invisible_mask[visible_map] = 0
146
- face_angles[invisible_mask == 1] = 10000. # angles of invisible faces are ignored
147
-
148
- return face_angles
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/projection_helper.py DELETED
@@ -1,464 +0,0 @@
1
- import os
2
- import torch
3
-
4
- import cv2
5
- import random
6
-
7
- import numpy as np
8
-
9
- from torchvision import transforms
10
-
11
- from pytorch3d.renderer import TexturesUV
12
- from pytorch3d.ops import interpolate_face_attributes
13
-
14
- from PIL import Image
15
-
16
- from tqdm import tqdm
17
-
18
- # customized
19
- import sys
20
- sys.path.append(".")
21
-
22
- from lib.camera_helper import init_camera
23
- from lib.render_helper import init_renderer, render
24
- from lib.shading_helper import (
25
- BlendParams,
26
- init_soft_phong_shader,
27
- init_flat_texel_shader,
28
- )
29
- from lib.vis_helper import visualize_outputs, visualize_quad_mask
30
- from lib.constants import *
31
-
32
-
33
- def get_all_4_locations(values_y, values_x):
34
- y_0 = torch.floor(values_y)
35
- y_1 = torch.ceil(values_y)
36
- x_0 = torch.floor(values_x)
37
- x_1 = torch.ceil(values_x)
38
-
39
- return torch.cat([y_0, y_0, y_1, y_1], 0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()
40
-
41
-
42
- def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device):
43
- """
44
- compose quad mask:
45
- -> 0: background
46
- -> 1: old
47
- -> 2: update
48
- -> 3: new
49
- """
50
-
51
- new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
52
- update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
53
- old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)
54
-
55
- all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor
56
-
57
- quad_mask_tensor = torch.zeros_like(all_mask_tensor)
58
- quad_mask_tensor[old_mask_tensor == 1] = 1
59
- quad_mask_tensor[update_mask_tensor == 1] = 2
60
- quad_mask_tensor[new_mask_tensor == 1] = 3
61
-
62
- return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
63
-
64
-
65
- def compute_view_heat(similarity_tensor, quad_mask_tensor):
66
- num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]
67
- heat = 0
68
- for idx in QUAD_WEIGHTS:
69
- heat += (quad_mask_tensor == idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels
70
-
71
- return heat
72
-
73
-
74
- def select_viewpoint(selected_view_ids, view_punishments,
75
- mode, dist_list, elev_list, azim_list, sector_list, view_idx,
76
- similarity_texture_cache, exist_texture,
77
- mesh, faces, verts_uvs,
78
- image_size, faces_per_pixel,
79
- init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
80
- device, use_principle=False
81
- ):
82
- if mode == "sequential":
83
-
84
- num_views = len(dist_list)
85
-
86
- dist = dist_list[view_idx % num_views]
87
- elev = elev_list[view_idx % num_views]
88
- azim = azim_list[view_idx % num_views]
89
- sector = sector_list[view_idx % num_views]
90
-
91
- selected_view_ids.append(view_idx % num_views)
92
-
93
- elif mode == "heuristic":
94
-
95
- if use_principle and view_idx < 6:
96
-
97
- selected_view_idx = view_idx
98
-
99
- else:
100
-
101
- selected_view_idx = None
102
- max_heat = 0
103
-
104
- print("=> selecting next view...")
105
- view_heat_list = []
106
- for sample_idx in tqdm(range(len(dist_list))):
107
-
108
- view_heat, *_ = render_one_view_and_build_masks(dist_list[sample_idx], elev_list[sample_idx], azim_list[sample_idx],
109
- sample_idx, sample_idx, view_punishments,
110
- similarity_texture_cache, exist_texture,
111
- mesh, faces, verts_uvs,
112
- image_size, faces_per_pixel,
113
- init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
114
- device)
115
-
116
- if view_heat > max_heat:
117
- selected_view_idx = sample_idx
118
- max_heat = view_heat
119
-
120
- view_heat_list.append(view_heat.item())
121
-
122
- print(view_heat_list)
123
- print("select view {} with heat {}".format(selected_view_idx, max_heat))
124
-
125
-
126
- dist = dist_list[selected_view_idx]
127
- elev = elev_list[selected_view_idx]
128
- azim = azim_list[selected_view_idx]
129
- sector = sector_list[selected_view_idx]
130
-
131
- selected_view_ids.append(selected_view_idx)
132
-
133
- view_punishments[selected_view_idx] *= 0.01
134
-
135
- elif mode == "random":
136
-
137
- selected_view_idx = random.choice(range(len(dist_list)))
138
-
139
- dist = dist_list[selected_view_idx]
140
- elev = elev_list[selected_view_idx]
141
- azim = azim_list[selected_view_idx]
142
- sector = sector_list[selected_view_idx]
143
-
144
- selected_view_ids.append(selected_view_idx)
145
-
146
- else:
147
- raise NotImplementedError()
148
-
149
- return dist, elev, azim, sector, selected_view_ids, view_punishments
150
-
151
-
152
- @torch.no_grad()
153
- def build_backproject_mask(mesh, faces, verts_uvs,
154
- cameras, reference_image, faces_per_pixel,
155
- image_size, uv_size, device):
156
- # construct pixel UVs
157
- renderer_scaled = init_renderer(cameras,
158
- shader=init_soft_phong_shader(
159
- camera=cameras,
160
- blend_params=BlendParams(),
161
- device=device),
162
- image_size=image_size,
163
- faces_per_pixel=faces_per_pixel
164
- )
165
- fragments_scaled = renderer_scaled.rasterizer(mesh)
166
-
167
- # get UV coordinates for each pixel
168
- faces_verts_uvs = verts_uvs[faces.textures_idx]
169
-
170
- pixel_uvs = interpolate_face_attributes(
171
- fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
172
- ) # NxHsxWsxKx2
173
- pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)
174
-
175
- texture_locations_y, texture_locations_x = get_all_4_locations(
176
- (1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
177
- pixel_uvs[:, 0].reshape(-1) * (uv_size - 1)
178
- )
179
-
180
- K = faces_per_pixel
181
-
182
- texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size)))).float() / 255.
183
- texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
184
-
185
- # texture
186
- texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
187
- texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values.reshape(-1, 3)
188
-
189
- return texture_tensor[:, :, 0]
190
-
191
-
192
- @torch.no_grad()
193
- def build_diffusion_mask(mesh_stuff,
194
- renderer, exist_texture, similarity_texture_cache, target_value, device, image_size,
195
- smooth_mask=False, view_threshold=0.01):
196
-
197
- mesh, faces, verts_uvs = mesh_stuff
198
- mask_mesh = mesh.clone() # NOTE in-place operation - DANGER!!!
199
-
200
- # visible mask => the whole region
201
- exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
202
- mask_mesh.textures = TexturesUV(
203
- maps=torch.ones_like(exist_texture_expand),
204
- faces_uvs=faces.textures_idx[None, ...],
205
- verts_uvs=verts_uvs[None, ...],
206
- sampling_mode="nearest"
207
- )
208
- # visible_mask_tensor, *_ = render(mask_mesh, renderer)
209
- visible_mask_tensor, _, similarity_map_tensor, *_ = render(mask_mesh, renderer)
210
- # faces that are too rotated away from the viewpoint will be treated as invisible
211
- valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
212
- visible_mask_tensor *= valid_mask_tensor
213
-
214
- # nonexist mask <=> new mask
215
- exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
216
- mask_mesh.textures = TexturesUV(
217
- maps=1 - exist_texture_expand,
218
- faces_uvs=faces.textures_idx[None, ...],
219
- verts_uvs=verts_uvs[None, ...],
220
- sampling_mode="nearest"
221
- )
222
- new_mask_tensor, *_ = render(mask_mesh, renderer)
223
- new_mask_tensor *= valid_mask_tensor
224
-
225
- # exist mask => visible mask - new mask
226
- exist_mask_tensor = visible_mask_tensor - new_mask_tensor
227
- exist_mask_tensor[exist_mask_tensor < 0] = 0 # NOTE dilate can lead to overflow
228
-
229
- # all update mask
230
- mask_mesh.textures = TexturesUV(
231
- maps=(
232
- similarity_texture_cache.argmax(0) == target_value
233
- # # only consider the views that have already appeared before
234
- # similarity_texture_cache[0:target_value+1].argmax(0) == target_value
235
- ).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
236
- faces_uvs=faces.textures_idx[None, ...],
237
- verts_uvs=verts_uvs[None, ...],
238
- sampling_mode="nearest"
239
- )
240
- all_update_mask_tensor, *_ = render(mask_mesh, renderer)
241
-
242
- # current update mask => intersection between all update mask and exist mask
243
- update_mask_tensor = exist_mask_tensor * all_update_mask_tensor
244
-
245
- # keep mask => exist mask - update mask
246
- old_mask_tensor = exist_mask_tensor - update_mask_tensor
247
-
248
- # convert
249
- new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
250
- new_mask = transforms.ToPILImage()(new_mask).convert("L")
251
-
252
- update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
253
- update_mask = transforms.ToPILImage()(update_mask).convert("L")
254
-
255
- old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
256
- old_mask = transforms.ToPILImage()(old_mask).convert("L")
257
-
258
- exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
259
- exist_mask = transforms.ToPILImage()(exist_mask).convert("L")
260
-
261
- return new_mask, update_mask, old_mask, exist_mask
262
-
263
-
264
- @torch.no_grad()
265
- def render_one_view(mesh,
266
- dist, elev, azim,
267
- image_size, faces_per_pixel,
268
- device):
269
-
270
- # render the view
271
- cameras = init_camera(
272
- dist, elev, azim,
273
- image_size, device
274
- )
275
- renderer = init_renderer(cameras,
276
- shader=init_soft_phong_shader(
277
- camera=cameras,
278
- blend_params=BlendParams(),
279
- device=device),
280
- image_size=image_size,
281
- faces_per_pixel=faces_per_pixel
282
- )
283
-
284
- init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(mesh, renderer)
285
-
286
- return (
287
- cameras, renderer,
288
- init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
289
- )
290
-
291
-
292
- @torch.no_grad()
293
- def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
294
- dist_list, elev_list, azim_list,
295
- image_size, image_size_scaled, uv_size, faces_per_pixel,
296
- device):
297
-
298
- num_candidate_views = len(dist_list)
299
- similarity_texture_cache = torch.zeros(num_candidate_views, uv_size, uv_size).to(device)
300
-
301
- print("=> building similarity texture cache for all views...")
302
- for i in tqdm(range(num_candidate_views)):
303
- cameras, _, _, _, similarity_tensor, _, _ = render_one_view(mesh,
304
- dist_list[i], elev_list[i], azim_list[i],
305
- image_size, faces_per_pixel, device)
306
-
307
- similarity_texture_cache[i] = build_backproject_mask(mesh, faces, verts_uvs,
308
- cameras, transforms.ToPILImage()(similarity_tensor[0, :, :, 0]).convert("RGB"), faces_per_pixel,
309
- image_size_scaled, uv_size, device)
310
-
311
- return similarity_texture_cache
312
-
313
-
314
- @torch.no_grad()
315
- def render_one_view_and_build_masks(dist, elev, azim,
316
- selected_view_idx, view_idx, view_punishments,
317
- similarity_texture_cache, exist_texture,
318
- mesh, faces, verts_uvs,
319
- image_size, faces_per_pixel,
320
- init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
321
- device, save_intermediate=False, smooth_mask=False, view_threshold=0.01):
322
-
323
- # render the view
324
- (
325
- cameras, renderer,
326
- init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
327
- ) = render_one_view(mesh,
328
- dist, elev, azim,
329
- image_size, faces_per_pixel,
330
- device
331
- )
332
-
333
- init_image = init_images_tensor[0].cpu()
334
- init_image = init_image.permute(2, 0, 1)
335
- init_image = transforms.ToPILImage()(init_image).convert("RGB")
336
-
337
- normal_map = normal_maps_tensor[0].cpu()
338
- normal_map = normal_map.permute(2, 0, 1)
339
- normal_map = transforms.ToPILImage()(normal_map).convert("RGB")
340
-
341
- depth_map = depth_maps_tensor[0].cpu().numpy()
342
- depth_map = Image.fromarray(depth_map).convert("L")
343
-
344
- similarity_map = similarity_tensor[0, :, :, 0].cpu()
345
- similarity_map = transforms.ToPILImage()(similarity_map).convert("L")
346
-
347
-
348
- flat_renderer = init_renderer(cameras,
349
- shader=init_flat_texel_shader(
350
- camera=cameras,
351
- device=device),
352
- image_size=image_size,
353
- faces_per_pixel=faces_per_pixel
354
- )
355
- new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
356
- (mesh, faces, verts_uvs),
357
- flat_renderer, exist_texture, similarity_texture_cache, selected_view_idx, device, image_size,
358
- smooth_mask=smooth_mask, view_threshold=view_threshold
359
- )
360
- # NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
361
- # it should match with `similarity_texture_cache`
362
-
363
- (
364
- old_mask_tensor,
365
- update_mask_tensor,
366
- new_mask_tensor,
367
- all_mask_tensor,
368
- quad_mask_tensor
369
- ) = compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device)
370
-
371
- view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
372
- view_heat *= view_punishments[selected_view_idx]
373
-
374
- # save intermediate results
375
- if save_intermediate:
376
- init_image.save(os.path.join(init_image_dir, "{}.png".format(view_idx)))
377
- normal_map.save(os.path.join(normal_map_dir, "{}.png".format(view_idx)))
378
- depth_map.save(os.path.join(depth_map_dir, "{}.png".format(view_idx)))
379
- similarity_map.save(os.path.join(similarity_map_dir, "{}.png".format(view_idx)))
380
-
381
- new_mask_image.save(os.path.join(mask_image_dir, "{}_new.png".format(view_idx)))
382
- update_mask_image.save(os.path.join(mask_image_dir, "{}_update.png".format(view_idx)))
383
- old_mask_image.save(os.path.join(mask_image_dir, "{}_old.png".format(view_idx)))
384
- exist_mask_image.save(os.path.join(mask_image_dir, "{}_exist.png".format(view_idx)))
385
-
386
- visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_heat, device)
387
-
388
- return (
389
- view_heat,
390
- renderer, cameras, fragments,
391
- init_image, normal_map, depth_map,
392
- init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
393
- old_mask_image, update_mask_image, new_mask_image,
394
- old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
395
- )
396
-
397
-
398
-
399
- @torch.no_grad()
400
- def backproject_from_image(mesh, faces, verts_uvs, cameras,
401
- reference_image, new_mask_image, update_mask_image,
402
- init_texture, exist_texture,
403
- image_size, uv_size, faces_per_pixel,
404
- device):
405
-
406
- # construct pixel UVs
407
- renderer_scaled = init_renderer(cameras,
408
- shader=init_soft_phong_shader(
409
- camera=cameras,
410
- blend_params=BlendParams(),
411
- device=device),
412
- image_size=image_size,
413
- faces_per_pixel=faces_per_pixel
414
- )
415
- fragments_scaled = renderer_scaled.rasterizer(mesh)
416
-
417
- # get UV coordinates for each pixel
418
- faces_verts_uvs = verts_uvs[faces.textures_idx]
419
-
420
- pixel_uvs = interpolate_face_attributes(
421
- fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
422
- ) # NxHsxWsxKx2
423
- pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(pixel_uvs.shape[-2], pixel_uvs.shape[1], pixel_uvs.shape[2], 2)
424
-
425
- # the update mask has to be on top of the diffusion mask
426
- new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(device).unsqueeze(-1)
427
- update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(device).unsqueeze(-1)
428
-
429
- project_mask_image_tensor = torch.logical_or(update_mask_image_tensor, new_mask_image_tensor).float()
430
- project_mask_image = project_mask_image_tensor * 255.
431
- project_mask_image = Image.fromarray(project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))
432
-
433
- project_mask_image_scaled = project_mask_image.resize(
434
- (image_size, image_size),
435
- Image.Resampling.NEAREST
436
- )
437
- project_mask_image_tensor_scaled = transforms.ToTensor()(project_mask_image_scaled).to(device)
438
-
439
- pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]
440
-
441
- texture_locations_y, texture_locations_x = get_all_4_locations(
442
- (1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
443
- pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1)
444
- )
445
-
446
- K = pixel_uvs.shape[0]
447
- project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:, None, :, :, None].repeat(1, 4, 1, 1, 3)
448
-
449
- texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size))))
450
- texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
451
-
452
- texture_values_masked = texture_values.reshape(-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(-1, 3)
453
-
454
- # texture
455
- texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
456
- texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values_masked
457
-
458
- init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(np.uint8))
459
-
460
- # update texture cache
461
- exist_texture[texture_locations_y, texture_locations_x] = 1
462
-
463
- return init_texture, project_mask_image, exist_texture
464
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/render_helper.py DELETED
@@ -1,108 +0,0 @@
1
- import os
2
- import torch
3
-
4
- import cv2
5
-
6
- import numpy as np
7
-
8
- from PIL import Image
9
-
10
- from torchvision import transforms
11
- from pytorch3d.ops import interpolate_face_attributes
12
- from pytorch3d.renderer import (
13
- RasterizationSettings,
14
- MeshRendererWithFragments,
15
- MeshRasterizer,
16
- )
17
-
18
- # customized
19
- import sys
20
- sys.path.append(".")
21
-
22
-
23
- def init_renderer(camera, shader, image_size, faces_per_pixel):
24
- raster_settings = RasterizationSettings(image_size=image_size, faces_per_pixel=faces_per_pixel)
25
- renderer = MeshRendererWithFragments(
26
- rasterizer=MeshRasterizer(
27
- cameras=camera,
28
- raster_settings=raster_settings
29
- ),
30
- shader=shader
31
- )
32
-
33
- return renderer
34
-
35
-
36
- @torch.no_grad()
37
- def render(mesh, renderer, pad_value=10):
38
- def phong_normal_shading(meshes, fragments) -> torch.Tensor:
39
- faces = meshes.faces_packed() # (F, 3)
40
- vertex_normals = meshes.verts_normals_packed() # (V, 3)
41
- faces_normals = vertex_normals[faces]
42
- pixel_normals = interpolate_face_attributes(
43
- fragments.pix_to_face, fragments.bary_coords, faces_normals
44
- )
45
-
46
- return pixel_normals
47
-
48
- def similarity_shading(meshes, fragments):
49
- faces = meshes.faces_packed() # (F, 3)
50
- vertex_normals = meshes.verts_normals_packed() # (V, 3)
51
- faces_normals = vertex_normals[faces]
52
- vertices = meshes.verts_packed() # (V, 3)
53
- face_positions = vertices[faces]
54
- view_directions = torch.nn.functional.normalize((renderer.shader.cameras.get_camera_center().reshape(1, 1, 3) - face_positions), p=2, dim=2)
55
- cosine_similarity = torch.nn.CosineSimilarity(dim=2)(faces_normals, view_directions)
56
- pixel_similarity = interpolate_face_attributes(
57
- fragments.pix_to_face, fragments.bary_coords, cosine_similarity.unsqueeze(-1)
58
- )
59
-
60
- return pixel_similarity
61
-
62
- def get_relative_depth_map(fragments, pad_value=pad_value):
63
- absolute_depth = fragments.zbuf[..., 0] # B, H, W
64
- no_depth = -1
65
-
66
- depth_min, depth_max = absolute_depth[absolute_depth != no_depth].min(), absolute_depth[absolute_depth != no_depth].max()
67
- target_min, target_max = 50, 255
68
-
69
- depth_value = absolute_depth[absolute_depth != no_depth]
70
- depth_value = depth_max - depth_value # reverse values
71
-
72
- depth_value /= (depth_max - depth_min)
73
- depth_value = depth_value * (target_max - target_min) + target_min
74
-
75
- relative_depth = absolute_depth.clone()
76
- relative_depth[absolute_depth != no_depth] = depth_value
77
- relative_depth[absolute_depth == no_depth] = pad_value # not completely black
78
-
79
- return relative_depth
80
-
81
-
82
- images, fragments = renderer(mesh)
83
- normal_maps = phong_normal_shading(mesh, fragments).squeeze(-2)
84
- similarity_maps = similarity_shading(mesh, fragments).squeeze(-2) # -1 - 1
85
- depth_maps = get_relative_depth_map(fragments)
86
-
87
- # normalize similarity mask to 0 - 1
88
- similarity_maps = torch.abs(similarity_maps) # 0 - 1
89
-
90
- # HACK erode, eliminate isolated dots
91
- non_zero_similarity = (similarity_maps > 0).float()
92
- non_zero_similarity = (non_zero_similarity * 255.).cpu().numpy().astype(np.uint8)[0]
93
- non_zero_similarity = cv2.erode(non_zero_similarity, kernel=np.ones((3, 3), np.uint8), iterations=2)
94
- non_zero_similarity = torch.from_numpy(non_zero_similarity).to(similarity_maps.device).unsqueeze(0) / 255.
95
- similarity_maps = non_zero_similarity.unsqueeze(-1) * similarity_maps
96
-
97
- return images, normal_maps, similarity_maps, depth_maps, fragments
98
-
99
-
100
- @torch.no_grad()
101
- def check_visible_faces(mesh, fragments):
102
- pix_to_face = fragments.pix_to_face
103
-
104
- # Indices of unique visible faces
105
- visible_map = pix_to_face.unique() # (num_visible_faces)
106
-
107
- return visible_map
108
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
text2tex/lib/shading_helper.py DELETED
@@ -1,45 +0,0 @@
1
- from typing import NamedTuple, Sequence
2
-
3
- from pytorch3d.renderer.mesh.shader import ShaderBase
4
- from pytorch3d.renderer import (
5
- AmbientLights,
6
- SoftPhongShader
7
- )
8
-
9
-
10
- class BlendParams(NamedTuple):
11
- sigma: float = 1e-4
12
- gamma: float = 1e-4
13
- background_color: Sequence = (1, 1, 1)
14
-
15
-
16
- class FlatTexelShader(ShaderBase):
17
-
18
- def __init__(self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None):
19
- super().__init__(device, cameras, lights, materials, blend_params)
20
-
21
- def forward(self, fragments, meshes, **_kwargs):
22
- texels = meshes.sample_textures(fragments)
23
- texels[(fragments.pix_to_face == -1), :] = 0
24
- return texels.squeeze(-2)
25
-
26
-
27
- def init_soft_phong_shader(camera, blend_params, device):
28
- lights = AmbientLights(device=device)
29
- shader = SoftPhongShader(
30
- cameras=camera,
31
- lights=lights,
32
- device=device,
33
- blend_params=blend_params
34
- )
35
-
36
- return shader
37
-
38
-
39
- def init_flat_texel_shader(camera, device):
40
- shader=FlatTexelShader(
41
- cameras=camera,
42
- device=device
43
- )
44
-
45
- return shader