# Point-E text-to-point-cloud Gradio demo.
# File size: 3,579 Bytes
# (File-viewer residue removed here: per-line blame commit hashes and the 1-109 line-number gutter.)
import numpy as np
import plotly.graph_objects as go
import torch
from tqdm.auto import tqdm
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud
import gradio as gr
# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Base model: built from the 'base40M-textvec' config and switched to eval mode.
print('Creating base model...')
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# Upsample model: built from the 'upsample' config, also in eval mode.
print('Creating upsample model...')
upsampler_name = 'upsample'
upsampler_model = model_from_config(MODEL_CONFIGS[upsampler_name], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[upsampler_name])

# Fetch the weights for both models and load them in place.
print('Downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))
print('Downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint(upsampler_name, device))

# Two-stage sampler: every list argument is ordered [base, upsampler].
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    # 1024 points for the first stage, the remaining 3072 of 4096 for the second.
    num_points=[1024, 4096 - 1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
    # Do not condition the upsampler at all: only the first stage gets 'texts'.
    model_kwargs_key_filter=('texts', ''),
)
def _rgb_color_strings(channels):
    """Convert per-point R/G/B channel arrays into Plotly color strings.

    Args:
        channels: Mapping with 'R', 'G' and 'B' entries, each an array-like
            holding one value per point.

    Returns:
        A list with one ``'rgb(r,g,b)'`` string per point, with integer
        components clipped to 0-255.
    """
    rgb = np.stack(
        [np.asarray(channels[name], dtype=float) for name in ('R', 'G', 'B')],
        axis=-1,
    )
    # NOTE(review): the channel scale is not visible from this file. Values
    # whose maximum is <= 1.0 are treated as unit-interval colors and scaled
    # up; anything larger is assumed to already be 0-255. Confirm against the
    # Point-E sampler output.
    if rgb.size and rgb.max() <= 1.0:
        rgb = rgb * 255.0
    rgb = np.clip(np.rint(rgb), 0, 255).astype(int)
    return [f'rgb({r},{g},{b})' for r, g, b in rgb]


def create_point_cloud(inp):
    """Generate a point cloud for a text prompt and return it as a Plotly figure.

    Args:
        inp: Text prompt passed to the sampler as the single entry of ``texts``.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one 3D scatter trace of
        the generated points, colored per point when R/G/B channels exist and
        plain blue otherwise.

    Raises:
        RuntimeError: If the sampler yields no samples at all.
    """
    samples = None
    progressive = sampler.sample_batch_progressive(
        batch_size=1, model_kwargs=dict(texts=[inp])
    )
    # Only the last yielded value is kept; earlier ones just drive the
    # progress bar.
    for x in tqdm(progressive):
        samples = x
    if samples is None:
        raise RuntimeError('Point-E sampler produced no samples.')

    pc = sampler.output_to_point_clouds(samples)[0]  # Single-item batch

    # Plotly needs one color per point. A tuple of three whole-channel arrays
    # (plus a colorscale) does not render as RGB, so build explicit
    # 'rgb(r,g,b)' strings instead.
    if all(name in pc.channels for name in ('R', 'G', 'B')):
        colors = _rgb_color_strings(pc.channels)
    else:
        # Fall back to a single color if no RGB data is available
        colors = 'blue'

    # Create a Plotly 3D scatter plot
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=pc.coords[:, 0],  # X coordinates
                y=pc.coords[:, 1],  # Y coordinates
                z=pc.coords[:, 2],  # Z coordinates
                mode='markers',
                marker=dict(
                    size=2,
                    color=colors,
                    opacity=0.8,
                ),
            )
        ]
    )
    fig.update_layout(
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
        ),
        margin=dict(r=0, l=0, b=0, t=0),
    )
    return fig
# Create Gradio interface: one text prompt in, one plot out.
demo = gr.Interface(
    fn=create_point_cloud,
    inputs="text",
    # gr.Plot takes the Plotly figure returned by create_point_cloud as-is;
    # no HTML conversion is needed.
    outputs=gr.Plot(),
    title="Point-E Demo - Convert Text to 3D Point Clouds",
    description="Generate and visualize 3D point clouds from textual descriptions using OpenAI's Point-E framework.",
)
# Enable queuing (max_size=30) and launch the Gradio app.
demo.queue(max_size=30)
demo.launch(debug=True)