SerdarHelli commited on
Commit
a7808a2
·
1 Parent(s): 42629a8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import plotly.graph_objects as go
4
+ import sys
5
+ import torch
6
+ from huggingface_hub import hf_hub_download
7
+ import numpy as np
8
+ import random
9
+
10
+ os.system("https://github.com/Zhengxinyang/SDF-StyleGAN.git")
11
+ sys.path.append("SDF-StyleGAN")
12
+
13
+ #Codes reference : https://github.com/Zhengxinyang/SDF-StyleGAN
14
+
15
+ from utils.utils import noise, evaluate_in_chunks, scale_to_unit_sphere, volume_noise, process_sdf, linear_slerp
16
+ from network.model import StyleGAN2_3D
17
+
18
+
19
+ cars=hf_hub_download("SerdarHelli/SDF-StyleGAN-3D", filename="cars.ckpt",revision="main")
20
+
21
+
22
+ ["Car","Airplane","Chair","Rifle","Table"]
23
+
24
+ #default model
25
+ device='cuda' if torch.cuda.is_available() else 'cpu'
26
+ if device=="cuda":
27
+ model = StyleGAN2_3D.load_from_checkpoint(cars).cuda(0)
28
+ else:
29
+ model = StyleGAN2_3D.load_from_checkpoint(cars)
30
+ model.eval()
31
+
32
+
33
+ models={"Car":cars,
34
+ "Airplane":"./planes.ckpt"
35
+ "Chair":"./chairs.ckpt",
36
+ "Rifle":"./rifles.ckpt",
37
+ "Table":"./tables.ckpt"
38
+ }
39
+
40
+
41
+ def seed_all(seed):
42
+
43
+ torch.manual_seed(seed)
44
+ np.random.seed(seed)
45
+ random.seed(seed)
46
+
47
+
48
+ def change_model(ckpt_path):
49
+ global model
50
+ if device=="cuda":
51
+ model = StyleGAN2_3D.load_from_checkpoint(cars).cuda(0)
52
+ else:
53
+ model = StyleGAN2_3D.load_from_checkpoint(cars)
54
+ model.eval()
55
+
56
+
57
+ def predict(seed,trunc_psi):
58
+ if seed==None:
59
+ seed=777
60
+ seed_all(seed)
61
+ if trunc_psi==None:
62
+ trunc_psi=1
63
+
64
+ z = noise(100000, model.latent_dim, device=model.device)
65
+ samples = evaluate_in_chunks(1000, model.SE, z)
66
+ model.av = torch.mean(samples, dim=0, keepdim=True)
67
+
68
+ mesh = model.generate_mesh(
69
+ ema=True, mc_vol_size=64, level=-0.015, trunc_psi=trunc_psi)
70
+ mesh = scale_to_unit_sphere(mesh)
71
+ mesh.export("/content/asdads.obj")
72
+ x=np.asarray(mesh.vertices).T[0]
73
+ y=np.asarray(mesh.vertices).T[1]
74
+ z=np.asarray(mesh.vertices).T[2]
75
+
76
+ i=np.asarray(mesh.faces).T[0]
77
+ j=np.asarray(mesh.faces).T[1]
78
+ k=np.asarray(mesh.faces).T[2]
79
+
80
+ return x,y,z,i,j,k
81
+
82
+ def generate(seed,model,trunc_psi):
83
+
84
+ global model
85
+ change_model(models[model])
86
+ x,y,z,i,j,k=predict(seed,trunc_psi)
87
+
88
+
89
+ fig = go.Figure(go.Mesh3d(x=x, y=y, z=z,
90
+ i=i, j=j, k=k,
91
+ colorscale="Viridis",
92
+ colorbar_len=0.75,
93
+ flatshading=True,
94
+ lighting=dict(ambient=0.5,
95
+ diffuse=1,
96
+ fresnel=4,
97
+ specular=0.5,
98
+ roughness=0.05,
99
+ facenormalsepsilon=0,
100
+ vertexnormalsepsilon=0),
101
+ lightposition=dict(x=100,
102
+ y=100,
103
+ z=1000)))
104
+ return fig
105
+
106
+ markdown=f'''
107
+ # SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation
108
+
109
+
110
+ [The space demo for the SGP 2022 paper "SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation".](https://arxiv.org/abs/2206.12055)
111
+
112
+ [For the official implementation.](https://github.com/Zhengxinyang/SDF-StyleGAN)
113
+ ### Future Work based on interest
114
+ - Adding new models for new type objects
115
+ - New Customization
116
+
117
+
118
+ It is running on {device}
119
+
120
+ '''
121
+ with gr.Blocks() as demo:
122
+ with gr.Column():
123
+ with gr.Row():
124
+ gr.Markdown(markdown)
125
+ with gr.Row():
126
+ seed = gr.Slider( minimum=0, maximum=2**16,label='Seed')
127
+ model=gr.Dropdown(choices=["Car","Airplane","Chair","Rifle","Table"],label="Choose Model Type")
128
+ trunc_psi = gr.Slider( minimum=0, maximum=2,label='Truncate PSI')
129
+
130
+ btn = gr.Button(value="Generate")
131
+ mesh = gr.Plot()
132
+ demo.load(generate, [seed,model,trunc_psi], mesh)
133
+ btn.click(generate, [seed,model,trunc_psi], mesh)
134
+
135
+ demo.launch(debug=True)