File size: 2,608 Bytes
28bead6
98f4a4b
cbcafc2
 
 
 
 
 
 
 
 
 
e9d9b92
 
 
cbcafc2
 
e9d9b92
 
cbcafc2
 
 
 
 
e9d9b92
 
cbcafc2
 
 
 
e9d9b92
 
cbcafc2
 
e9d9b92
cbcafc2
 
e9d9b92
cbcafc2
 
 
 
 
 
 
e9d9b92
cbcafc2
 
e9d9b92
 
 
cbcafc2
e9d9b92
cbcafc2
e9d9b92
 
cbcafc2
 
e9d9b92
 
 
 
 
 
 
5078e2a
e9d9b92
5078e2a
e9d9b92
 
 
5078e2a
e9d9b92
 
bcfbcce
e9d9b92
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import html

import gradio as gr
import numpy
import plotly.graph_objects as go
import torch
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint
from point_e.util.plotting import plot_point_cloud

# ---------------------------------------------------------------------------
# One-time setup, executed at import: build both Point-E stages, load their
# pretrained weights, and assemble the shared sampler used by every request.
# ---------------------------------------------------------------------------

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Stage 1: the text-conditioned base model that produces the coarse cloud.
print('Creating base model...')
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device).eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# Stage 2: the upsampler that densifies the stage-1 cloud.
print('Creating upsample model...')
upsampler_name = 'upsample'
upsampler_model = model_from_config(MODEL_CONFIGS[upsampler_name], device).eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[upsampler_name])

# Pull the pretrained weights for each stage into the freshly built modules.
print('Downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))

print('Downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint(upsampler_name, device))

# Two-stage sampler: 1,024 points from the base model plus 3,072 from the
# upsampler (4,096 total), each point carrying R/G/B auxiliary channels.
# Guidance scale 3.0 is paired with the base stage, 0.0 with the upsampler.
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 3072],
    aux_channels=list('RGB'),
    guidance_scale=[3.0, 0.0],
    # Only the base stage receives the 'texts' kwarg; the empty key for the
    # second stage means the upsampler is not conditioned at all.
    model_kwargs_key_filter=('texts', ''),
)

def _point_cloud_figure(pc, bound=0.75):
    """Build an interactive Plotly 3-D scatter plot of ``pc`` colored by its RGB channels.

    Args:
        pc: Point cloud returned by ``sampler.output_to_point_clouds``. Assumes
            ``pc.coords`` is an (N, 3) array and ``pc.channels`` maps 'R', 'G'
            and 'B' to length-N float arrays in [0, 1] -- TODO confirm.
        bound: Half-width of the fixed cubic viewing volume on every axis.

    Returns:
        A ``plotly.graph_objects.Figure`` holding one ``Scatter3d`` trace.
    """
    coords = numpy.asarray(pc.coords)
    rgb = numpy.stack([numpy.asarray(pc.channels[name]) for name in ('R', 'G', 'B')], axis=-1)
    # Map the [0, 1] color floats to 0-255 integer levels for CSS rgb() strings.
    levels = numpy.clip(numpy.rint(rgb * 255.0), 0, 255).astype(int)
    colors = [f'rgb({r},{g},{b})' for r, g, b in levels]

    # Fixed, hidden axes keep every render framed identically, so the prompt
    # changes the cloud while the camera framing stays the same.
    scene = {
        f'{axis}axis': dict(range=[-bound, bound], visible=False) for axis in 'xyz'
    }
    scene['aspectmode'] = 'cube'

    return go.Figure(
        data=[
            go.Scatter3d(
                x=coords[:, 0],
                y=coords[:, 1],
                z=coords[:, 2],
                mode='markers',
                marker=dict(size=2, color=colors),
            )
        ],
        layout=dict(scene=scene, margin=dict(l=0, r=0, t=0, b=0)),
    )


def _figure_to_html(fig, height=600):
    """Embed a Plotly figure as a self-contained ``<iframe>`` for a ``gr.HTML`` output.

    ``gr.HTML`` inserts its value as markup, so the ``<script>`` tags Plotly
    emits would not execute. Inside an iframe's ``srcdoc`` they run normally.
    plotly.js is loaded from the CDN, so the viewer's browser needs network access.
    """
    page = fig.to_html(full_html=True, include_plotlyjs='cdn')
    return (
        f'<iframe srcdoc="{html.escape(page, quote=True)}" '
        f'style="width: 100%; height: {height}px; border: none;"></iframe>'
    )


def create_point_cloud(inp):
    """Generate a colored point cloud from a text prompt and render it as HTML.

    Args:
        inp: Free-text description fed to the text-conditioned base model.

    Returns:
        An HTML string (an ``<iframe>`` containing an interactive Plotly 3-D
        scatter plot of the generated cloud), suitable for a ``gr.HTML`` output.

    Raises:
        RuntimeError: If the progressive sampler yields no samples.
    """
    samples = None
    # The progressive sampler yields every intermediate denoising state; only
    # the last one is the finished cloud. no_grad guarantees no autograd graph
    # is retained across steps, whatever the sampler does internally.
    with torch.no_grad():
        for samples in tqdm(
            sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[inp]))
        ):
            pass
    if samples is None:
        raise RuntimeError('Point-E sampler produced no samples.')

    # batch_size=1, so the single cloud is element 0.
    pc = sampler.output_to_point_clouds(samples)[0]

    # NOTE(review): point_e's plot_point_cloud builds a matplotlib figure,
    # which has no to_html(); build a Plotly figure directly instead.
    return _figure_to_html(_point_cloud_figure(pc))

# Gradio front end: one free-text prompt in, one HTML panel out (the handler
# returns an HTML string rather than a figure object).
demo = gr.Interface(
    create_point_cloud,
    "text",
    gr.HTML(),
    title="Point-E Demo - Convert Text to 3D Point Clouds",
    description=(
        "Generate and visualize 3D point clouds from textual descriptions "
        "using OpenAI's Point-E framework."
    ),
)

# Allow at most 30 pending requests in the queue, then start the server in
# debug mode; queue() hands back the same app object, so the calls chain.
demo.queue(max_size=30).launch(debug=True)