File size: 3,121 Bytes
c6c9d33
98f4a4b
cbcafc2
 
 
 
 
 
 
 
 
 
e9d9b92
 
 
cbcafc2
 
e9d9b92
 
cbcafc2
 
 
 
 
e9d9b92
 
cbcafc2
 
 
 
e9d9b92
 
cbcafc2
 
e9d9b92
cbcafc2
 
e9d9b92
cbcafc2
 
 
 
 
 
 
e9d9b92
cbcafc2
 
e9d9b92
 
cbcafc2
e9d9b92
cbcafc2
e9d9b92
a23ce4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbcafc2
a23ce4f
 
 
 
 
 
 
 
 
 
e9d9b92
 
5078e2a
e9d9b92
5078e2a
cffe600
e9d9b92
 
5078e2a
e9d9b92
 
bcfbcce
e9d9b92
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import plotly.graph_objects as go

import torch
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

import gradio as gr

# Use CUDA when the runtime reports it is available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Stage 1: text-conditioned base model ------------------------------------
base_name = 'base40M-textvec'
print('Creating base model...')
# nn.Module.eval() returns the module itself, so it can be chained here.
base_model = model_from_config(MODEL_CONFIGS[base_name], device).eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# --- Stage 2: upsampler ------------------------------------------------------
print('Creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device).eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

# --- Pretrained weights --------------------------------------------------------
print('Downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))
print('Downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint('upsample', device))

# Two-stage sampler: the base stage produces 1024 points and the upsampler
# adds the remaining 3072 (4096 in total). Classifier-free guidance (3.0)
# applies to the base stage only.
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 3072],
    aux_channels=list('RGB'),
    guidance_scale=[3.0, 0.0],
    # Only the base stage sees the 'texts' kwarg; the empty filter leaves the
    # upsampler entirely unconditioned.
    model_kwargs_key_filter=('texts', ''),
)

# Function to create point clouds
def create_point_cloud(inp):
    """Generate a point cloud for a text prompt and return it as a Plotly figure.

    Runs the module-level ``sampler`` (base model + upsampler) on ``inp`` and
    plots every generated point as a marker coloured by its R/G/B channels.

    Args:
        inp: Text prompt passed to the sampler as ``texts=[inp]``.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single ``Scatter3d`` trace.

    Raises:
        RuntimeError: If the sampler yields no samples at all.
    """
    samples = None
    # Sampling is inference only; disable autograd so no graph is retained.
    with torch.no_grad():
        for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inp]))):
            samples = x  # the progressive sampler yields every step; keep the last one
    if samples is None:
        # Otherwise output_to_point_clouds(None) would fail with an obscure error.
        raise RuntimeError('Point-E sampler produced no samples')

    pc = sampler.output_to_point_clouds(samples)[0]  # batch_size=1 -> a single cloud

    # Point-E keeps colour channels in a dict keyed by channel name ('R', 'G',
    # 'B', matching aux_channels), so integer indexing such as channels[0]
    # raises KeyError. Build one explicit colour per point instead.
    # NOTE(review): channel values are assumed to lie in [0, 1] (as Point-E's own
    # plotting / PLY export treat them); they are scaled and clipped to 0-255.
    rgb = np.stack([pc.channels[name] for name in ('R', 'G', 'B')], axis=-1)
    rgb = np.clip(np.round(rgb * 255.0), 0, 255).astype(np.uint8)
    colors = [f'rgb({r},{g},{b})' for r, g, b in rgb]

    # Convert the point cloud to a Plotly figure (columns of coords are X, Y, Z).
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=pc.coords[:, 0],
                y=pc.coords[:, 1],
                z=pc.coords[:, 2],
                mode='markers',
                marker=dict(
                    size=2,
                    color=colors,
                    opacity=0.8,
                ),
            )
        ]
    )

    fig.update_layout(
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
        ),
        margin=dict(r=0, l=0, b=0, t=0),  # let the plot fill the output panel
    )

    return fig

# Gradio UI: a free-text prompt goes in and an interactive 3D scatter comes out.
# gr.Plot renders the Plotly Figure returned by create_point_cloud directly.
demo = gr.Interface(
    create_point_cloud,
    "text",
    gr.Plot(),
    title="Point-E Demo - Convert Text to 3D Point Clouds",
    description="Generate and visualize 3D point clouds from textual descriptions using OpenAI's Point-E framework.",
)

# Turn on Gradio's request queue (at most 30 pending events), then start the
# server in debug mode. queue() returns the interface, so the calls chain.
demo.queue(max_size=30).launch(debug=True)