Update app.py
Browse files
app.py
CHANGED
@@ -48,19 +48,40 @@ sampler = PointCloudSampler(
|
|
48 |
|
49 |
# Function to create point clouds
|
50 |
def create_point_cloud(inp):
|
51 |
-
# Generate progressive samples
|
52 |
samples = None
|
53 |
for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inp]))):
|
54 |
samples = x
|
55 |
|
56 |
-
#
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
# Create Gradio interface
|
66 |
demo = gr.Interface(
|
|
|
48 |
|
49 |
# Function to create point clouds
|
50 |
def create_point_cloud(inp):
|
|
|
51 |
samples = None
|
52 |
for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inp]))):
|
53 |
samples = x
|
54 |
|
55 |
+
pc = sampler.output_to_point_clouds(samples)[0] # Get the point cloud
|
56 |
+
|
57 |
+
# Convert the point cloud to a Plotly figure
|
58 |
+
fig = go.Figure(
|
59 |
+
data=[
|
60 |
+
go.Scatter3d(
|
61 |
+
x=pc.coords[:, 0], # X coordinates
|
62 |
+
y=pc.coords[:, 1], # Y coordinates
|
63 |
+
z=pc.coords[:, 2], # Z coordinates
|
64 |
+
mode='markers',
|
65 |
+
marker=dict(
|
66 |
+
size=2,
|
67 |
+
color=pc.channels[0], # Assuming the first channel contains RGB values
|
68 |
+
colorscale='Viridis', # Adjust as needed
|
69 |
+
opacity=0.8
|
70 |
+
)
|
71 |
+
)
|
72 |
+
]
|
73 |
+
)
|
74 |
|
75 |
+
fig.update_layout(
|
76 |
+
scene=dict(
|
77 |
+
xaxis_title="X",
|
78 |
+
yaxis_title="Y",
|
79 |
+
zaxis_title="Z",
|
80 |
+
),
|
81 |
+
margin=dict(r=0, l=0, b=0, t=0)
|
82 |
+
)
|
83 |
+
|
84 |
+
return fig
|
85 |
|
86 |
# Create Gradio interface
|
87 |
demo = gr.Interface(
|