Spaces:
Runtime error
Runtime error
File size: 7,256 Bytes
e04dce3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# DO NOT HOST PUBLICLY - SECURITY RISKS!
# (the API is only enabled when the program is started with the --api option)
# Currently no API stability guarantees are provided - API may break on any new commit (but hopefully won't).
import os
import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image
import gradio as gr
from typing import Dict, List
from modules.api import api
from src.core import core_generation_funnel, run_makevideo
from src.misc import SCRIPT_VERSION
from src import backbone
from src.common_constants import GenerationOptions as go
def encode_to_base64(image):
    """Return *image* as a base64 string.

    Strings are treated as already encoded and are returned unchanged.
    PIL images are encoded with ``api.encode_pil_to_base64`` and NumPy
    arrays are routed through ``encode_np_to_base64``. Any other input
    type yields an empty string.
    """
    # isinstance() instead of an exact type() comparison: Image.open()
    # returns format-specific subclasses of Image.Image, which an exact
    # match would silently turn into "".
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""
def encode_np_to_base64(image):
    """Encode a NumPy image array as base64 by way of a PIL image."""
    return api.encode_pil_to_base64(Image.fromarray(image))
def to_base64_PIL(encoding: str):
    """Decode a base64-encoded image string into a PIL image.

    The decoded image is converted to a ``uint8`` NumPy array and then
    wrapped in a PIL image again.
    """
    decoded = api.decode_base64_to_image(encoding)
    pixels = np.array(decoded).astype('uint8')
    return Image.fromarray(pixels)
# Model names accepted by /depth/generate/video and the integer ids they are
# translated to. Callers can use either these strings, or integers.
_AVAILABLE_MODELS = {
    'res101': 0,
    'dpt_beit_large_512': 1,  # midas 3.1
    'dpt_beit_large_384': 2,  # midas 3.1
    'dpt_large_384': 3,  # midas 3.0
    'dpt_hybrid_384': 4,  # midas 3.0
    'midas_v21': 5,
    'midas_v21_small': 6,
    'zoedepth_n': 7,  # indoor
    'zoedepth_k': 8,  # outdoor
    'zoedepth_nk': 9,
    'marigold_v1': 10,
    'depth_anything': 11,
    'depth_anything_v2_small': 12,
    'depth_anything_v2_base': 13,
    'depth_anything_v2_large': 14,
}

# Keys that must be present in options["video_parameters"].
_REQUIRED_VIDEO_PARAMS = (
    "vid_numframes", "vid_fps", "vid_traj", "vid_shift", "vid_border",
    "dolly", "vid_format", "vid_ssaa", "output_filename",
)


def _decode_input_images(depth_input_images):
    """Decode every base64 input string into a PIL image."""
    return [to_base64_PIL(encoded) for encoded in depth_input_images]


def _resolve_model_id(model_type):
    """Map a model name to its integer id; integers are passed through.

    Raises:
        HTTPException: status 400 when the name is unknown or the value is
            neither a string nor an integer.
    """
    if isinstance(model_type, str):
        if model_type not in _AVAILABLE_MODELS:
            raise HTTPException(
                status_code=400,
                detail={'error': 'Invalid model string', 'available_models': list(_AVAILABLE_MODELS)},
            )
        return _AVAILABLE_MODELS[model_type]
    if isinstance(model_type, int):
        return model_type
    raise HTTPException(status_code=400, detail={'error': 'Invalid model parameter type'})


def depth_api(_: gr.Blocks, app: FastAPI):
    """Register the depthmap HTTP endpoints on *app*.

    The first argument (the Gradio blocks object) is unused; the signature
    matches what ``script_callbacks.on_app_started`` passes to its callback.
    """

    # GET /depth/version: report the script version.
    @app.get("/depth/version")
    async def version():
        return {"version": SCRIPT_VERSION}

    # GET /depth/get_options: sorted, lower-cased names of the generation options.
    @app.get("/depth/get_options")
    async def get_options():
        return {"options": sorted(x.name.lower() for x in go)}

    # POST /depth/generate: run the generation funnel on the supplied images
    # and return every PIL image it yields, base64-encoded.
    # TODO: some potential inputs not supported (like custom depthmaps)
    @app.post("/depth/generate")
    async def process(
        depth_input_images: List[str] = Body([], title='Input Images'),
        options: Dict[str, object] = Body("options", title='Generation options'),
    ):
        # TODO: restrict mesh options
        if not depth_input_images:
            raise HTTPException(status_code=422, detail="No images supplied")
        print(f"Processing {str(len(depth_input_images))} images trough the API")

        pil_images = _decode_input_images(depth_input_images)
        outpath = backbone.get_outpath()
        gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)

        # The funnel yields (count, kind, result) triples; only image results
        # are returned to the client.
        results_based = [
            encode_to_base64(result)
            for _count, _kind, result in gen_obj
            if isinstance(result, Image.Image)
        ]
        return {"images": results_based, "info": "Success"}

    # POST /depth/generate/video: render a video from an inpainted mesh,
    # generating the mesh first unless an existing mesh file is supplied.
    @app.post("/depth/generate/video")
    async def process_video(
        depth_input_images: List[str] = Body([], title='Input Images'),
        options: Dict[str, object] = Body("options", title='Generation options'),
    ):
        if not depth_input_images:
            raise HTTPException(status_code=422, detail="No images supplied")
        print(f"Processing {str(len(depth_input_images))} images trough the API")

        # Report malformed requests as client errors instead of letting a
        # KeyError/TypeError escape as an unhandled server error.
        if not isinstance(options, dict):
            raise HTTPException(status_code=400, detail={'error': "Generation options must be an object"})
        if "model_type" not in options:
            raise HTTPException(status_code=400, detail={'error': "Missing required option: model_type"})
        options["model_type"] = _resolve_model_id(options["model_type"])

        if "video_parameters" not in options:
            raise HTTPException(status_code=400, detail={'error': "Missing required option: video_parameters"})
        video_parameters = options["video_parameters"]
        if not isinstance(video_parameters, dict):
            raise HTTPException(status_code=400, detail={'error': "video_parameters must be an object"})

        missing_params = [param for param in _REQUIRED_VIDEO_PARAMS if param not in video_parameters]
        if missing_params:
            raise HTTPException(status_code=400, detail={'error': f"Missing required parameter(s): {', '.join(missing_params)}"})

        vid_numframes = video_parameters["vid_numframes"]
        vid_fps = video_parameters["vid_fps"]
        vid_traj = video_parameters["vid_traj"]
        vid_shift = video_parameters["vid_shift"]
        vid_border = video_parameters["vid_border"]
        dolly = video_parameters["dolly"]
        vid_format = video_parameters["vid_format"]
        try:
            vid_ssaa = int(video_parameters["vid_ssaa"])
        except (TypeError, ValueError):
            raise HTTPException(status_code=400, detail={'error': "vid_ssaa must be an integer"}) from None

        # NOTE(review): output_filename and mesh_fi_filename are
        # client-supplied filesystem paths used as-is; this is one of the
        # reasons the API must not be hosted publicly.
        output_filename = video_parameters["output_filename"]
        output_path = os.path.dirname(output_filename)
        basename, extension = os.path.splitext(os.path.basename(output_filename))

        # Comparing video_format with the extension (without its leading dot)
        if vid_format != extension[1:]:
            raise HTTPException(status_code=400, detail={'error': f"Video format '{vid_format}' does not match with the extension '{extension}'."})

        pil_images = _decode_input_images(depth_input_images)
        outpath = backbone.get_outpath()

        mesh_fi_filename = video_parameters.get('mesh_fi_filename', None)
        if mesh_fi_filename and os.path.exists(mesh_fi_filename):
            mesh_fi = mesh_fi_filename
            print("Loaded existing mesh from: ", mesh_fi)
        else:
            # If there is no mesh file generate it.
            options["GEN_INPAINTED_MESH"] = True
            gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)

            # Stop consuming the funnel as soon as the mesh shows up.
            mesh_fi = None
            for _count, kind, result in gen_obj:
                if kind == 'inpainted_mesh':
                    mesh_fi = result
                    break

            if mesh_fi:
                print("Created mesh in: ", mesh_fi)
            else:
                raise HTTPException(status_code=400, detail={'error': "The mesh has not been created"})

        run_makevideo(mesh_fi, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa, output_path, basename)
        return {"info": "Success"}
# Register the API routes once the web app has started, but only when the
# program was launched with the API enabled. Failure here is deliberately
# non-fatal: the rest of the extension keeps working without the API.
try:
    import modules.script_callbacks as script_callbacks
    if backbone.get_cmd_opt('api', False):
        script_callbacks.on_app_started(depth_api)
        print("Started the depthmap API. DO NOT HOST PUBLICLY - SECURITY RISKS!")
except Exception as exc:
    # `except Exception` rather than a bare `except`, so that
    # KeyboardInterrupt/SystemExit still propagate; the cause is printed so
    # the failure can be diagnosed.
    print('DepthMap API could not start')
    print(f'Reason: {exc!r}')
|