Spaces:
Runtime error
Runtime error
# DO NOT HOST PUBLICLY - SECURITY RISKS!
# (the API is only enabled when the application is started with the --api option)
# Currently no API stability guarantees are provided - the API may break on any new commit (but hopefully won't).
import os | |
import numpy as np | |
from fastapi import FastAPI, Body | |
from fastapi.exceptions import HTTPException | |
from PIL import Image | |
import gradio as gr | |
from typing import Dict, List | |
from modules.api import api | |
from src.core import core_generation_funnel, run_makevideo | |
from src.misc import SCRIPT_VERSION | |
from src import backbone | |
from src.common_constants import GenerationOptions as go | |
def encode_to_base64(image):
    """Return *image* as a base64 string for an API response.

    Args:
        image: A string (returned unchanged), a PIL image, or a NumPy array.

    Returns:
        The base64 encoding produced by ``api.encode_pil_to_base64`` for PIL
        images and NumPy arrays, the input itself for strings, and ``""`` for
        any other type.
    """
    # isinstance() rather than an exact type() comparison: images opened from
    # files are subclasses of Image.Image (e.g. PngImageFile), and an exact
    # match would silently turn them into an empty string.
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    # Unsupported input is deliberately mapped to an empty string, not an error.
    return ""
def encode_np_to_base64(image):
    """Encode a NumPy image array as base64 by way of a PIL image."""
    return api.encode_pil_to_base64(Image.fromarray(image))
def to_base64_PIL(encoding: str):
    """Decode a base64-encoded image and return it as a PIL image.

    The decoded image is round-tripped through a NumPy array cast to
    ``uint8`` before being turned back into a PIL image.
    """
    decoded = api.decode_base64_to_image(encoding)
    pixels = np.array(decoded).astype('uint8')
    return Image.fromarray(pixels)
# Model names accepted by the video endpoint, mapped to the integer ids that
# are forwarded as options["model_type"]. Clients may send either form.
_AVAILABLE_MODELS = {
    'res101': 0,
    'dpt_beit_large_512': 1,  # midas 3.1
    'dpt_beit_large_384': 2,  # midas 3.1
    'dpt_large_384': 3,  # midas 3.0
    'dpt_hybrid_384': 4,  # midas 3.0
    'midas_v21': 5,
    'midas_v21_small': 6,
    'zoedepth_n': 7,  # indoor
    'zoedepth_k': 8,  # outdoor
    'zoedepth_nk': 9,
    'marigold_v1': 10,
    'depth_anything': 11,
    'depth_anything_v2_small': 12,
    'depth_anything_v2_base': 13,
    'depth_anything_v2_large': 14,
}

# Keys that must be present in options["video_parameters"].
_REQUIRED_VIDEO_PARAMS = (
    "vid_numframes", "vid_fps", "vid_traj", "vid_shift", "vid_border",
    "dolly", "vid_format", "vid_ssaa", "output_filename",
)


def _bad_request(detail):
    """Build (not raise) an HTTP 400 error carrying *detail*."""
    return HTTPException(status_code=400, detail=detail)


def _require_input_images(depth_input_images):
    """Reject an empty image list with HTTP 422, otherwise log the count."""
    if len(depth_input_images) == 0:
        raise HTTPException(status_code=422, detail="No images supplied")
    print(f"Processing {len(depth_input_images)} images trough the API")


def _resolve_model_id(model_type):
    """Translate a model name or integer id into the integer id.

    Raises:
        HTTPException: 400 for an unknown name or an unsupported value type.
    """
    if isinstance(model_type, str):
        try:
            return _AVAILABLE_MODELS[model_type]
        except KeyError:
            raise _bad_request({
                'error': 'Invalid model string',
                'available_models': list(_AVAILABLE_MODELS),
            }) from None
    # bool is a subclass of int; a JSON true/false is not a meaningful model id.
    if isinstance(model_type, int) and not isinstance(model_type, bool):
        return model_type
    raise _bad_request({'error': 'Invalid model parameter type'})


def depth_api(_: gr.Blocks, app: FastAPI):
    """Register the depthmap HTTP endpoints on *app*.

    Meant to be used as an ``on_app_started`` callback. The first argument
    (the Gradio Blocks instance) is ignored.

    NOTE(review): the handlers were previously defined but never attached to
    ``app``, which made this callback a no-op. The route paths used below are
    not shown in this file and should be confirmed against the project's
    published API before release.
    """

    # Reports the script version.
    @app.get("/depth/version")
    async def version():
        return {"version": SCRIPT_VERSION}

    # Lists the generation option names, lower-cased and sorted.
    @app.get("/depth/get_options")
    async def get_options():
        return {"options": sorted(x.name.lower() for x in go)}

    # Runs the generation funnel on the supplied images and returns every
    # PIL-image result as base64.
    # TODO: some potential inputs not supported (like custom depthmaps)
    @app.post("/depth/generate")
    async def process(
        depth_input_images: List[str] = Body([], title='Input Images'),
        options: Dict[str, object] = Body("options", title='Generation options'),
    ):
        # TODO: restrict mesh options
        _require_input_images(depth_input_images)

        pil_images = [to_base64_PIL(encoded) for encoded in depth_input_images]
        outpath = backbone.get_outpath()
        gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)

        # Non-image results yielded by the funnel are skipped.
        results_based = [
            encode_to_base64(result)
            for _count, _result_type, result in gen_obj
            if isinstance(result, Image.Image)
        ]
        return {"images": results_based, "info": "Success"}

    # Renders a video from an inpainted mesh, generating the mesh first when
    # no existing mesh file is supplied.
    @app.post("/depth/generate/video")
    async def process_video(
        depth_input_images: List[str] = Body([], title='Input Images'),
        options: Dict[str, object] = Body("options", title='Generation options'),
    ):
        _require_input_images(depth_input_images)

        # The Body default is the plain string "options", so a request that
        # omits the field must be rejected before any key lookup.
        if not isinstance(options, dict):
            raise _bad_request({'error': "'options' must be an object"})
        if "model_type" not in options:
            raise _bad_request({'error': "Missing required option: model_type"})
        options["model_type"] = _resolve_model_id(options["model_type"])

        video_parameters = options.get("video_parameters")
        if not isinstance(video_parameters, dict):
            raise _bad_request({'error': "Missing or invalid option: video_parameters"})

        missing_params = [param for param in _REQUIRED_VIDEO_PARAMS if param not in video_parameters]
        if missing_params:
            raise _bad_request({'error': f"Missing required parameter(s): {', '.join(missing_params)}"})

        vid_numframes = video_parameters["vid_numframes"]
        vid_fps = video_parameters["vid_fps"]
        vid_traj = video_parameters["vid_traj"]
        vid_shift = video_parameters["vid_shift"]
        vid_border = video_parameters["vid_border"]
        dolly = video_parameters["dolly"]
        vid_format = video_parameters["vid_format"]
        try:
            vid_ssaa = int(video_parameters["vid_ssaa"])
        except (TypeError, ValueError):
            raise _bad_request({'error': "Parameter 'vid_ssaa' must be an integer"}) from None

        # NOTE(security): output_filename and mesh_fi_filename are filesystem
        # paths taken from the request and used as-is.
        output_filename = video_parameters["output_filename"]
        if not isinstance(output_filename, str):
            raise _bad_request({'error': "Parameter 'output_filename' must be a string"})
        output_path = os.path.dirname(output_filename)
        basename, extension = os.path.splitext(os.path.basename(output_filename))

        # The requested format has to agree with the file extension (sans dot).
        if vid_format != extension[1:]:
            raise _bad_request({'error': f"Video format '{vid_format}' does not match with the extension '{extension}'."})

        pil_images = [to_base64_PIL(encoded) for encoded in depth_input_images]
        outpath = backbone.get_outpath()

        mesh_fi_filename = video_parameters.get('mesh_fi_filename', None)
        if mesh_fi_filename and os.path.exists(mesh_fi_filename):
            mesh_fi = mesh_fi_filename
            print("Loaded existing mesh from: ", mesh_fi)
        else:
            # If there is no mesh file generate it.
            options["GEN_INPAINTED_MESH"] = True
            gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)
            # Take the first inpainted-mesh result and stop consuming the funnel.
            mesh_fi = next(
                (result for _count, result_type, result in gen_obj if result_type == 'inpainted_mesh'),
                None,
            )
            if mesh_fi:
                print("Created mesh in: ", mesh_fi)
            else:
                raise _bad_request({'error': "The mesh has not been created"})

        run_makevideo(mesh_fi, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa, output_path, basename)
        return {"info": "Success"}
# Hook the API into the host application, but only when it was started with
# the `api` command-line option. Failure here is non-fatal: the rest of the
# module stays usable without the API.
try:
    import modules.script_callbacks as script_callbacks
    if backbone.get_cmd_opt('api', False):
        script_callbacks.on_app_started(depth_api)
        print("Started the depthmap API. DO NOT HOST PUBLICLY - SECURITY RISKS!")
except Exception as err:
    # `except Exception` (not a bare except) so that SystemExit and
    # KeyboardInterrupt still propagate; the reason is reported for diagnosis.
    print('DepthMap API could not start')
    print(f'  reason: {err!r}')