theekshana commited on
Commit
9a1a4ec
Β·
verified Β·
1 Parent(s): 5c7e93b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # backend.py
2
+ from fastapi import FastAPI, UploadFile, Form
3
+ from pydantic import BaseModel
4
+ from typing import List, Literal
5
+ import os
6
+ import shutil
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+ from trellis.pipelines import TrellisImageTo3DPipeline
11
+ from trellis.utils import render_utils, postprocessing_utils
12
+
13
+ app = FastAPI()
14
+
15
+ MAX_SEED = np.iinfo(np.int32).max
16
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
17
+ os.makedirs(TMP_DIR, exist_ok=True)
18
+
19
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
20
+ pipeline.cuda()
21
+
22
+
23
+ @app.on_event("startup")
24
+ def preload_model():
25
+ try:
26
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
27
+ except:
28
+ pass
29
+
30
+
31
+ @app.post("/image-to-3d")
32
+ async def image_to_3d(
33
+ image: UploadFile,
34
+ seed: int = Form(...),
35
+ ss_guidance_strength: float = Form(...),
36
+ ss_sampling_steps: int = Form(...),
37
+ slat_guidance_strength: float = Form(...),
38
+ slat_sampling_steps: int = Form(...),
39
+ req_session: str = Form(...)
40
+ ):
41
+ user_dir = os.path.join(TMP_DIR, req_session)
42
+ os.makedirs(user_dir, exist_ok=True)
43
+ image_data = Image.open(image.file)
44
+
45
+ outputs = pipeline.run(
46
+ image_data,
47
+ seed=seed,
48
+ formats=["gaussian", "mesh"],
49
+ sparse_structure_sampler_params={
50
+ "steps": ss_sampling_steps,
51
+ "cfg_strength": ss_guidance_strength,
52
+ },
53
+ slat_sampler_params={
54
+ "steps": slat_sampling_steps,
55
+ "cfg_strength": slat_guidance_strength,
56
+ },
57
+ )
58
+
59
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
60
+ video_path = os.path.join(user_dir, 'sample.mp4')
61
+ render_utils.save_video(video, video_path)
62
+
63
+ torch.cuda.empty_cache()
64
+ return {"video_path": video_path}
65
+
66
+
67
+ @app.post("/extract-glb")
68
+ async def extract_glb(
69
+ mesh_simplify: float = Form(...),
70
+ texture_size: int = Form(...),
71
+ req_session: str = Form(...),
72
+ ):
73
+ user_dir = os.path.join(TMP_DIR, req_session)
74
+ glb_path = os.path.join(user_dir, 'sample.glb')
75
+ postprocessing_utils.export_glb(glb_path, simplify=mesh_simplify, texture_size=texture_size)
76
+
77
+ torch.cuda.empty_cache()
78
+ return {"glb_path": glb_path}