|
import torch |
|
import gradio as gr |
|
from diffusers import ShapEPipeline, ShapEImg2ImgPipeline |
|
from diffusers.utils import export_to_gif |
|
import os |
|
from huggingface_hub import HfApi, login |
|
from PIL import Image |
|
import numpy as np |
|
import gc |
|
|
|
|
|
# Inference in this app is configured for CPU execution.
device = "cpu"

# Upper bound on the intra-op threads torch may use for CPU kernels.
_CPU_THREADS = 4
torch.set_num_threads(_CPU_THREADS)

print(f"Using device: {device}")
|
|
|
def validate_token(token):
    """Return True if *token* is accepted by the Hugging Face Hub.

    The check is a ``whoami`` request made with the token, so nothing is
    written to the local credential store. This matters for a shared web app:
    a persisted login would leak one visitor's token into later requests.

    Args:
        token: Hugging Face access token as typed by the user.

    Returns:
        True when the Hub accepts the token; False when it is missing, blank,
        or the request fails for any reason.
    """
    # Reject missing/blank tokens up front: they can never be valid, and this
    # avoids a pointless network round trip.
    if not isinstance(token, str) or not token.strip():
        print("Token validation error: empty token")
        return False

    try:
        HfApi().whoami(token=token)
    except Exception as e:
        # Boundary handler: any failure (HTTP error, network, bad token) is
        # reported to the caller as "not valid".
        print(f"Token validation error: {str(e)}")
        return False
    return True
|
|
|
def generate_3d_from_text(prompt, token, guidance_scale=7.0, export_format="obj", progress=gr.Progress()):
    """Generate a preview GIF and a mesh file from a text prompt with Shap-E.

    Args:
        prompt: Text description of the object to generate.
        token: Hugging Face access token; checked with ``validate_token`` and
            forwarded to ``from_pretrained``.
        guidance_scale: Guidance strength passed to the pipeline; values above
            10 are clamped to 10.
        export_format: Mesh file format. ``"obj"`` and ``"ply"`` are supported
            (the formats diffusers can write without extra dependencies).
        progress: Gradio progress tracker used to report stage updates.

    Returns:
        A 3-tuple ``(status, preview, model_file)``. On success each element is
        a ``gr.update`` carrying the status text, the GIF path and the mesh
        path. On failure the status update carries the error message and the
        other two elements are ``None``.
    """
    pipe = output = mesh_output = None
    try:
        if not validate_token(token):
            return gr.update(value="Invalid Hugging Face token"), None, None

        if not isinstance(prompt, str) or not prompt.strip():
            return gr.update(value="Please enter a description for the 3D model"), None, None

        # Function-scope import: the mesh writers live next to export_to_gif
        # in diffusers.utils.
        from diffusers.utils import export_to_obj, export_to_ply

        mesh_exporters = {"obj": export_to_obj, "ply": export_to_ply}
        if export_format not in mesh_exporters:
            # Fail before the expensive model load rather than after it.
            return gr.update(value=f"Unsupported export format: {export_format}"), None, None

        print(f"Starting generation: {prompt}")
        progress(0.1, desc="Loading model...")

        pipe = ShapEPipeline.from_pretrained(
            "openai/shap-e",
            torch_dtype=torch.float32,
            token=token,
            revision="main",
            low_cpu_mem_usage=True,
        )

        os.makedirs("outputs", exist_ok=True)
        safe_prompt = "".join(ch for ch in prompt if ch.isalnum() or ch in (" ", "-", "_"))
        # Cap the length and fall back to a fixed stem so the file name is
        # never empty or longer than the file system allows.
        safe_prompt = safe_prompt.strip()[:64].strip() or "model"
        base_filename = f"outputs/{safe_prompt}"

        scale = min(guidance_scale, 10.0)
        # The preview and the mesh come from two separate sampling passes;
        # seeding both identically makes them depict the same object.
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

        try:
            progress(0.3, desc="Creating 3D model...")
            with torch.no_grad():
                output = pipe(
                    prompt,
                    guidance_scale=scale,
                    num_inference_steps=16,
                    generator=torch.Generator().manual_seed(seed),
                )

            progress(0.5, desc="Creating GIF...")
            gif_path = f"{base_filename}.gif"
            # .images holds one list of frames per prompt; export the frames
            # of the single prompt, not the outer list.
            export_to_gif(output.images[0], gif_path)

            progress(0.7, desc="Creating 3D mesh...")
            with torch.no_grad():
                mesh_output = pipe(
                    prompt,
                    guidance_scale=scale,
                    num_inference_steps=16,
                    output_type="mesh",
                    generator=torch.Generator().manual_seed(seed),
                )

            progress(0.9, desc="Saving files...")
            output_path = f"{base_filename}.{export_format}"
            # With output_type="mesh" the decoded meshes are returned in
            # .images as well; they are written by the diffusers exporters.
            mesh_exporters[export_format](mesh_output.images[0], output_path)

            print(f"Generation completed: {output_path}")
            progress(1.0, desc="Completed!")
            return (
                gr.update(value="Generation successful!"),
                gr.update(value=gif_path),
                gr.update(value=output_path),
            )

        except Exception as model_error:
            error_msg = f"Model execution error: {str(model_error)}"
            print(error_msg)
            return gr.update(value=error_msg), None, None

    except Exception as e:
        error_msg = f"General error: {str(e)}"
        print(error_msg)
        return gr.update(value=error_msg), None, None

    finally:
        # Drop the large objects on every exit path (not only on success) so
        # the collection below can reclaim the model weights.
        pipe = output = mesh_output = None
        gc.collect()
|
|
|
def generate_3d_from_image(image, token, guidance_scale=7.0, export_format="obj", progress=gr.Progress()):
    """Generate a preview GIF and a mesh file from an input image with Shap-E.

    Args:
        image: Source picture as a PIL image, a NumPy array, or a file path.
        token: Hugging Face access token; checked with ``validate_token`` and
            forwarded to ``from_pretrained``.
        guidance_scale: Guidance strength passed to the pipeline; values above
            10 are clamped to 10.
        export_format: Mesh file format. ``"obj"`` and ``"ply"`` are supported
            (the formats diffusers can write without extra dependencies).
        progress: Gradio progress tracker used to report stage updates.

    Returns:
        A 3-tuple ``(status, preview, model_file)``. On success each element is
        a ``gr.update`` carrying the status text, the GIF path and the mesh
        path. On failure the status update carries the error message and the
        other two elements are ``None``.
    """
    pipe = output = mesh_output = None
    try:
        if not validate_token(token):
            return gr.update(value="Invalid Hugging Face token"), None, None

        if image is None:
            # Nothing was uploaded; report it before paying for a model load.
            return gr.update(value="Please provide an input image"), None, None

        import time

        # Function-scope import: the mesh writers live next to export_to_gif
        # in diffusers.utils.
        from diffusers.utils import export_to_obj, export_to_ply

        mesh_exporters = {"obj": export_to_obj, "ply": export_to_ply}
        if export_format not in mesh_exporters:
            return gr.update(value=f"Unsupported export format: {export_format}"), None, None

        print("Starting image to 3D generation")
        progress(0.1, desc="Loading model...")

        pipe = ShapEImg2ImgPipeline.from_pretrained(
            "openai/shap-e-img2img",
            torch_dtype=torch.float32,
            token=token,
            revision="main",
            low_cpu_mem_usage=True,
        )

        os.makedirs("outputs", exist_ok=True)
        # Nanosecond resolution keeps requests arriving within the same second
        # from overwriting each other's files.
        base_filename = f"outputs/image_to_3d_{time.time_ns()}"

        scale = min(guidance_scale, 10.0)
        # The preview and the mesh come from two separate sampling passes;
        # seeding both identically makes them depict the same object.
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

        try:
            progress(0.3, desc="Preparing image...")
            if isinstance(image, str):
                # Copy the pixels out so the file handle is closed right away.
                with Image.open(image) as opened:
                    image = opened.copy()
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            image = image.resize((128, 128))

            progress(0.5, desc="Creating 3D model...")
            with torch.no_grad():
                output = pipe(
                    image=image,
                    guidance_scale=scale,
                    num_inference_steps=16,
                    generator=torch.Generator().manual_seed(seed),
                )

            progress(0.7, desc="Creating GIF...")
            gif_path = f"{base_filename}.gif"
            # .images holds one list of frames per input; export the frames
            # of the single input, not the outer list.
            export_to_gif(output.images[0], gif_path)

            progress(0.8, desc="Creating 3D mesh...")
            with torch.no_grad():
                mesh_output = pipe(
                    image=image,
                    guidance_scale=scale,
                    num_inference_steps=16,
                    output_type="mesh",
                    generator=torch.Generator().manual_seed(seed),
                )

            progress(0.9, desc="Saving files...")
            output_path = f"{base_filename}.{export_format}"
            # With output_type="mesh" the decoded meshes are returned in
            # .images as well; they are written by the diffusers exporters.
            mesh_exporters[export_format](mesh_output.images[0], output_path)

            print(f"Generation completed: {output_path}")
            progress(1.0, desc="Completed!")
            return (
                gr.update(value="Generation successful!"),
                gr.update(value=gif_path),
                gr.update(value=output_path),
            )

        except Exception as model_error:
            error_msg = f"Model execution error: {str(model_error)}"
            print(error_msg)
            return gr.update(value=error_msg), None, None

    except Exception as e:
        error_msg = f"General error: {str(e)}"
        print(error_msg)
        return gr.update(value=error_msg), None, None

    finally:
        # Drop the large objects on every exit path (not only on success) so
        # the collection below can reclaim the model weights.
        pipe = output = mesh_output = None
        gc.collect()
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: one tab per input modality, each with its own inputs on the left
# and status / preview / download widgets on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as interface:
    gr.Markdown("# SORA-3D - Text/Image to 3D Model Generator")
    gr.Markdown("Create 3D models from text or image input. You need a Hugging Face token to use this app.")
    gr.Markdown("""
    > **Important Notes**:
    > - Processing time may be longer on CPU
    > - Keep guidance scale under 10 for faster results
    > - Number of steps is fixed at 16
    > - Image size is optimized for quality/speed
    """)

    # --- Text prompt -> 3D -------------------------------------------------
    with gr.Tab("Text → 3D"):
        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                text_input = gr.Textbox(label="Enter description for 3D model", scale=2)
                text_token = gr.Textbox(label="Hugging Face Token", type="password", scale=2)
                with gr.Row():
                    text_guidance = gr.Slider(minimum=1, maximum=10, value=7, label="Guidance Scale", scale=1)
                    text_format = gr.Radio(["obj", "glb"], label="Export Format", value="obj", scale=1)
                text_button = gr.Button("Generate", variant="primary")

            # Right column: results.
            with gr.Column():
                text_status = gr.Textbox(label="Status", interactive=False)
                text_preview = gr.Image(label="3D Preview (GIF)", interactive=False)
                text_file = gr.File(label="3D Model File")

    # --- Image -> 3D -------------------------------------------------------
    with gr.Tab("Image → 3D"):
        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                image_input = gr.Image(label="Image to convert to 3D", type="pil", scale=2)
                image_token = gr.Textbox(label="Hugging Face Token", type="password", scale=2)
                with gr.Row():
                    image_guidance = gr.Slider(minimum=1, maximum=10, value=7, label="Guidance Scale", scale=1)
                    image_format = gr.Radio(["obj", "glb"], label="Export Format", value="obj", scale=1)
                image_button = gr.Button("Generate", variant="primary")

            # Right column: results.
            with gr.Column():
                image_status = gr.Textbox(label="Status", interactive=False)
                image_preview = gr.Image(label="3D Preview (GIF)", interactive=False)
                image_file = gr.File(label="3D Model File")

    # Each button feeds its tab's inputs to the matching generator and routes
    # the returned (status, preview, file) triple to that tab's outputs.
    text_button.click(
        fn=generate_3d_from_text,
        inputs=[text_input, text_token, text_guidance, text_format],
        outputs=[text_status, text_preview, text_file],
    )

    image_button.click(
        fn=generate_3d_from_image,
        inputs=[image_input, image_token, image_guidance, image_format],
        outputs=[image_status, image_preview, image_file],
    )


# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    interface.launch()