File size: 3,579 Bytes
c6c9d33
98f4a4b
cbcafc2
 
 
 
 
 
 
 
 
 
e9d9b92
 
 
cbcafc2
 
e9d9b92
 
cbcafc2
 
 
 
 
e9d9b92
 
cbcafc2
 
 
 
e9d9b92
 
cbcafc2
 
e9d9b92
cbcafc2
 
e9d9b92
cbcafc2
 
 
 
 
 
 
e9d9b92
cbcafc2
 
e9d9b92
 
cbcafc2
e9d9b92
cbcafc2
e9d9b92
a23ce4f
 
7035650
 
 
 
 
 
 
 
 
 
 
 
 
a23ce4f
 
 
 
 
 
 
 
 
7035650
 
a23ce4f
 
 
 
 
cbcafc2
a23ce4f
 
 
 
 
 
 
 
 
 
e9d9b92
 
5078e2a
e9d9b92
5078e2a
cffe600
e9d9b92
 
5078e2a
e9d9b92
 
bcfbcce
e9d9b92
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import plotly.graph_objects as go

import torch
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

import gradio as gr

# Run on a CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def _build_model_and_diffusion(config_name, banner):
    """Print *banner*, then build the eval-mode model and diffusion for *config_name*.

    Returns:
        A ``(model, diffusion)`` pair created from ``MODEL_CONFIGS[config_name]``
        and ``DIFFUSION_CONFIGS[config_name]``.
    """
    print(banner)
    net = model_from_config(MODEL_CONFIGS[config_name], device)
    net.eval()
    return net, diffusion_from_config(DIFFUSION_CONFIGS[config_name])


def _load_weights(net, checkpoint_name, banner):
    """Print *banner*, then load the named checkpoint's state dict into *net*."""
    print(banner)
    net.load_state_dict(load_checkpoint(checkpoint_name, device))


# Text-conditioned base model and the upsampler that densifies its output.
base_name = 'base40M-textvec'
base_model, base_diffusion = _build_model_and_diffusion(base_name, 'Creating base model...')
upsampler_model, upsampler_diffusion = _build_model_and_diffusion('upsample', 'Creating upsample model...')

# Fetch and apply pretrained weights for both stages.
_load_weights(base_model, base_name, 'Downloading base checkpoint...')
_load_weights(upsampler_model, 'upsample', 'Downloading upsampler checkpoint...')

# Two-stage sampler: the base stage produces _BASE_POINTS points and the
# upsampler adds the rest, up to _TOTAL_POINTS in total.
_BASE_POINTS = 1024
_TOTAL_POINTS = 4096
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[_BASE_POINTS, _TOTAL_POINTS - _BASE_POINTS],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
    # Only the base stage sees the 'texts' kwarg; the upsampler gets no conditioning.
    model_kwargs_key_filter=('texts', ''),
)

def _marker_colors(channels):
    """Build the Plotly ``marker.color`` value for a point cloud.

    Args:
        channels: Mapping of channel name to per-point values, as stored on
            ``PointCloud.channels``.

    Returns:
        A list with one ``'rgb(r,g,b)'`` string per point when the R, G and B
        channels are all present. Otherwise the single fallback colour
        ``'blue'``.
    """
    if not all(name in channels for name in ('R', 'G', 'B')):
        # Fall back to a single colour if no RGB data is available.
        return 'blue'
    # Assumes channel values lie in [0, 1], the range Point-E's sampler
    # rescales colours into. Plotly's rgb() strings expect 0-255 integers.
    # Clipping keeps a slightly out-of-range sample from yielding an invalid
    # colour string.
    rgb = np.stack(
        [np.asarray(channels[name], dtype=np.float64) for name in ('R', 'G', 'B')],
        axis=-1,
    )
    rgb = np.clip(np.rint(rgb * 255.0), 0, 255).astype(np.int64)
    return [f'rgb({r},{g},{b})' for r, g, b in rgb]


def create_point_cloud(inp):
    """Generate a coloured 3D point cloud for a text prompt and plot it.

    Runs the module-level ``sampler`` to completion and renders the final
    sample as a Plotly 3D scatter plot. Each point is drawn in its own
    generated RGB colour.

    Args:
        inp: Text prompt passed to the sampler as ``texts=[inp]``.

    Returns:
        A ``plotly.graph_objects.Figure`` that ``gr.Plot`` can display.

    Raises:
        RuntimeError: If the sampler yields no output at all.
    """
    samples = None
    # Inference only: make sure no autograd graph is recorded while sampling.
    with torch.no_grad():
        for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inp]))):
            # Only the last value yielded by the progressive sampler is kept.
            samples = x
    if samples is None:
        raise RuntimeError('Point-E sampler produced no output')

    pc = sampler.output_to_point_clouds(samples)[0]  # batch_size=1 -> single cloud

    # Per-point colour strings (or one fallback colour). No colorscale is
    # needed because the colours are explicit, not numeric values to be mapped.
    colors = _marker_colors(pc.channels)

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=pc.coords[:, 0],
                y=pc.coords[:, 1],
                z=pc.coords[:, 2],
                mode='markers',
                marker=dict(
                    size=2,
                    color=colors,
                    opacity=0.8,
                ),
            )
        ]
    )

    fig.update_layout(
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
        ),
        margin=dict(r=0, l=0, b=0, t=0),
    )

    return fig

# Build the Gradio UI: one text prompt in, an interactive Plotly figure out.
demo = gr.Interface(
    fn=create_point_cloud,
    inputs="text",
    outputs=gr.Plot(),  # gr.Plot displays the plotly Figure returned by fn directly
    title="Point-E Demo - Convert Text to 3D Point Clouds",
    description="Generate and visualize 3D point clouds from textual descriptions using OpenAI's Point-E framework."
)

# Enable request queuing with at most 30 pending requests, since each
# generation is slow. Then launch the app. debug=True blocks the main thread,
# keeping the script alive, and surfaces errors in the console.
demo.queue(max_size=30)
demo.launch(debug=True)