User-2468 commited on
Commit
a23ce4f
·
verified ·
1 Parent(s): cffe600

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -8
app.py CHANGED
@@ -48,19 +48,40 @@ sampler = PointCloudSampler(
48
 
49
  # Function to create point clouds
50
  def create_point_cloud(inp):
51
- # Generate progressive samples
52
  samples = None
53
  for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inp]))):
54
  samples = x
55
 
56
- # Extract the point cloud
57
- pc = sampler.output_to_point_clouds(samples)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Generate a Plotly figure for visualization
60
- fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75), (0.75, 0.75, 0.75)))
61
-
62
- # Convert Plotly figure to HTML for Gradio compatibility
63
- return pc
 
 
 
 
 
64
 
65
  # Create Gradio interface
66
  demo = gr.Interface(
 
48
 
49
  # Function to create point clouds
50
  def create_point_cloud(inp):
 
51
  samples = None
52
  for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inp]))):
53
  samples = x
54
 
55
+ pc = sampler.output_to_point_clouds(samples)[0] # Get the point cloud
56
+
57
+ # Convert the point cloud to a Plotly figure
58
+ fig = go.Figure(
59
+ data=[
60
+ go.Scatter3d(
61
+ x=pc.coords[:, 0], # X coordinates
62
+ y=pc.coords[:, 1], # Y coordinates
63
+ z=pc.coords[:, 2], # Z coordinates
64
+ mode='markers',
65
+ marker=dict(
66
+ size=2,
67
+ color=pc.channels[0], # Assuming the first channel contains RGB values
68
+ colorscale='Viridis', # Adjust as needed
69
+ opacity=0.8
70
+ )
71
+ )
72
+ ]
73
+ )
74
 
75
+ fig.update_layout(
76
+ scene=dict(
77
+ xaxis_title="X",
78
+ yaxis_title="Y",
79
+ zaxis_title="Z",
80
+ ),
81
+ margin=dict(r=0, l=0, b=0, t=0)
82
+ )
83
+
84
+ return fig
85
 
86
  # Create Gradio interface
87
  demo = gr.Interface(