Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# backend.py
|
2 |
+
from fastapi import FastAPI, UploadFile, Form
|
3 |
+
from pydantic import BaseModel
|
4 |
+
from typing import List, Literal
|
5 |
+
import os
|
6 |
+
import shutil
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image
|
10 |
+
from trellis.pipelines import TrellisImageTo3DPipeline
|
11 |
+
from trellis.utils import render_utils, postprocessing_utils
|
12 |
+
|
13 |
+
app = FastAPI()
|
14 |
+
|
15 |
+
MAX_SEED = np.iinfo(np.int32).max
|
16 |
+
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
17 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
18 |
+
|
19 |
+
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
|
20 |
+
pipeline.cuda()
|
21 |
+
|
22 |
+
|
23 |
+
@app.on_event("startup")
|
24 |
+
def preload_model():
|
25 |
+
try:
|
26 |
+
pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
|
27 |
+
except:
|
28 |
+
pass
|
29 |
+
|
30 |
+
|
31 |
+
@app.post("/image-to-3d")
|
32 |
+
async def image_to_3d(
|
33 |
+
image: UploadFile,
|
34 |
+
seed: int = Form(...),
|
35 |
+
ss_guidance_strength: float = Form(...),
|
36 |
+
ss_sampling_steps: int = Form(...),
|
37 |
+
slat_guidance_strength: float = Form(...),
|
38 |
+
slat_sampling_steps: int = Form(...),
|
39 |
+
req_session: str = Form(...)
|
40 |
+
):
|
41 |
+
user_dir = os.path.join(TMP_DIR, req_session)
|
42 |
+
os.makedirs(user_dir, exist_ok=True)
|
43 |
+
image_data = Image.open(image.file)
|
44 |
+
|
45 |
+
outputs = pipeline.run(
|
46 |
+
image_data,
|
47 |
+
seed=seed,
|
48 |
+
formats=["gaussian", "mesh"],
|
49 |
+
sparse_structure_sampler_params={
|
50 |
+
"steps": ss_sampling_steps,
|
51 |
+
"cfg_strength": ss_guidance_strength,
|
52 |
+
},
|
53 |
+
slat_sampler_params={
|
54 |
+
"steps": slat_sampling_steps,
|
55 |
+
"cfg_strength": slat_guidance_strength,
|
56 |
+
},
|
57 |
+
)
|
58 |
+
|
59 |
+
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
60 |
+
video_path = os.path.join(user_dir, 'sample.mp4')
|
61 |
+
render_utils.save_video(video, video_path)
|
62 |
+
|
63 |
+
torch.cuda.empty_cache()
|
64 |
+
return {"video_path": video_path}
|
65 |
+
|
66 |
+
|
67 |
+
@app.post("/extract-glb")
|
68 |
+
async def extract_glb(
|
69 |
+
mesh_simplify: float = Form(...),
|
70 |
+
texture_size: int = Form(...),
|
71 |
+
req_session: str = Form(...),
|
72 |
+
):
|
73 |
+
user_dir = os.path.join(TMP_DIR, req_session)
|
74 |
+
glb_path = os.path.join(user_dir, 'sample.glb')
|
75 |
+
postprocessing_utils.export_glb(glb_path, simplify=mesh_simplify, texture_size=texture_size)
|
76 |
+
|
77 |
+
torch.cuda.empty_cache()
|
78 |
+
return {"glb_path": glb_path}
|