File size: 2,608 Bytes
28bead6 98f4a4b cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 cbcafc2 e9d9b92 5078e2a e9d9b92 5078e2a e9d9b92 5078e2a e9d9b92 bcfbcce e9d9b92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import base64
import io

import gradio as gr
import numpy
import plotly.graph_objects as go
import torch
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint
from point_e.util.plotting import plot_point_cloud
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Base model: built from its named config, switched to eval mode, and paired
# with the diffusion process configured under the same name.
print('Creating base model...')
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# Upsampler model: same construction steps, using the 'upsample' config.
print('Creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

# Load the checkpoint weights into both models (base first, then upsampler).
print('Downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))
print('Downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint('upsample', device))

# Two-stage sampler: every list argument has one entry per stage, in the
# order [base, upsampler].
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    # 1024 points for the first stage, 4096 - 1024 for the second.
    num_points=[1024, 4096 - 1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
    # Do not condition the upsampler at all.
    model_kwargs_key_filter=('texts', ''),
)
def create_point_cloud(inp):
    """Sample a point cloud for a text prompt and return it as HTML.

    Args:
        inp: Text prompt; passed to the sampler as the only entry of
            ``texts`` (batch size 1).

    Returns:
        An HTML ``<img>`` snippet whose source is the rendered figure encoded
        as a base64 PNG data URI.

    Raises:
        RuntimeError: If the sampler yields no samples at all.
    """
    # The sampler yields intermediate results; only the last one is kept.
    samples = None
    progressive = sampler.sample_batch_progressive(
        batch_size=1,
        model_kwargs=dict(texts=[inp]),
    )
    for samples in tqdm(progressive):
        pass
    if samples is None:
        raise RuntimeError('Sampler produced no samples.')

    # Batch size is 1, so the first (only) point cloud is the result.
    pc = sampler.output_to_point_clouds(samples)[0]

    fig = plot_point_cloud(
        pc,
        grid_size=3,
        fixed_bounds=((-0.75, -0.75, -0.75), (0.75, 0.75, 0.75)),
    )

    # plot_point_cloud returns a Matplotlib figure, not a Plotly one, so it
    # has no to_html(). Render it to PNG in memory and embed it as a data URI
    # so the result is still a self-contained HTML string.
    # NOTE(review): the figure is not closed here because pyplot is not
    # imported in this file; confirm whether figures accumulate per request.
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight')
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode('ascii')
    return (
        f'<img src="data:image/png;base64,{encoded_png}" '
        'alt="Generated point cloud" style="max-width: 100%;">'
    )
# Gradio UI: a single text prompt in, the HTML returned by
# create_point_cloud out.
demo = gr.Interface(
    fn=create_point_cloud,
    inputs="text",
    outputs=gr.HTML(),  # create_point_cloud returns an HTML string
    title="Point-E Demo - Convert Text to 3D Point Clouds",
    description="Generate and visualize 3D point clouds from textual descriptions using OpenAI's Point-E framework.",
)

# Enable request queuing (max_size=30), then start the app with debug
# output enabled.
demo.queue(max_size=30)
demo.launch(debug=True)