# GenMM / app.py
# Header captured from the Hugging Face Space file view:
#   author: radames
#   commit fa0bdf4: "add header with title"
#   links: raw | history | blame
#   file size: 2.8 kB
import json
import time
import uvicorn
from pathlib import Path
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from dataset.tracks_motion import TracksMotion
from GPS import GPS
import gradio as gr
def _synthesis(synthesis_setting: dict, motion_data: list):
    """Run GPS match-and-blend motion synthesis with the given settings.

    A fresh ``GPS`` model (``device="cpu"``, ``silent=True``) is built on every
    call, then ``GPS.run`` is invoked in ``"match_and_blend"`` mode with a
    ``PatchCoherentLoss`` criterion and the ``"match_and_blend"`` optimizer.

    Args:
        synthesis_setting: Settings dict. Keys read here: ``"frames"``,
            ``"noise_sigma"``, ``"pyr_factor"``, ``"patch_size"``, ``"loop"``,
            ``"completeness"``, ``"alpha"``, ``"num_steps"`` and, optionally,
            ``"stride"`` (1 when absent).
        motion_data: Exemplar motion(s), forwarded unchanged to ``GPS.run``.
            The caller in this file passes a one-element list holding a
            ``TracksMotion``.

    Returns:
        Whatever ``GPS.run`` returns.
    """
    model = GPS(
        # "frames" is embedded in the init-mode string; presumably the
        # requested output length -- TODO confirm against GPS.
        init_mode=f"random_synthesis/{synthesis_setting['frames']}",
        noise_sigma=synthesis_setting["noise_sigma"],
        # NOTE(review): fixed at 0.2; the "coarse_ratio" entry that
        # synthesis() writes into the settings is never read here.
        coarse_ratio=0.2,
        pyr_factor=synthesis_setting["pyr_factor"],
        num_stages_limit=-1,
        silent=True,
        device="cpu",
    )
    return model.run(
        motion_data,
        mode="match_and_blend",
        ext={
            "criteria": {
                "type": "PatchCoherentLoss",
                "patch_size": synthesis_setting["patch_size"],
                "stride": synthesis_setting.get("stride", 1),
                "loop": synthesis_setting["loop"],
                # alpha is only forwarded when completeness is enabled.
                "coherent_alpha": (
                    synthesis_setting["alpha"]
                    if synthesis_setting["completeness"]
                    else None
                ),
            },
            "optimizer": "match_and_blend",
            "num_itrs": synthesis_setting["num_steps"],
        },
    )
def synthesis(data):
    """Decode a synthesis request, run it, and return the updated payload.

    Args:
        data: The request payload, either as a JSON string/bytes (the form the
            original code expected) or as an already-decoded dict. It must
            provide ``"tracks"``, ``"scale"`` and ``"setting"``.

    Returns:
        A dict with the payload's keys, where:
        ``"setting"`` is a copy with ``"coarse_ratio"`` set to -1;
        ``"time"`` is the elapsed time of the synthesis step, in seconds;
        ``"tracks"`` is ``motion_data.parse(...)`` of the synthesized motion.
    """
    # Accept a decoded object too: json.loads() on a dict raises TypeError.
    if isinstance(data, (str, bytes, bytearray)):
        payload = json.loads(data)
    else:
        payload = data

    # Shallow copies so a caller-supplied dict is never mutated; key order
    # matches what in-place assignment would have produced.
    result = dict(payload)
    # NOTE(review): _synthesis() does not read "coarse_ratio"; the value is
    # only echoed back in the returned settings.
    result["setting"] = {**payload["setting"], "coarse_ratio": -1}

    # create track object
    motion_data = TracksMotion(result["tracks"], scale=result["scale"])

    # perf_counter() is monotonic, unlike time.time(), so the measured
    # duration is unaffected by system clock adjustments.
    start = time.perf_counter()
    synthesized_motion = _synthesis(result["setting"], [motion_data])
    result["time"] = time.perf_counter() - start

    result["tracks"] = motion_data.parse(synthesized_motion)
    return result
# Header HTML rendered above the embedded demo: a centered title followed by
# links to the project page, the paper and the code repository.
intro = "\n".join(
    [
        "",
        '<h1 style="text-align: center;">',
        "Example-based Motion Synthesis via Generative Motion Matching",
        "</h1>",
        '<h3 style="text-align: center; margin-bottom: 7px;">',
        '<a href="http://weiyuli.xyz/GenMM" target="_blank">Project Page</a>'
        ' | <a href="https://huggingface.co/papers/2306.00378" target="_blank">Paper</a>'
        ' | <a href="https://github.com/wyysf-98/GenMM" target="_blank">Code</a>',
        "</h3>",
        "",
    ]
)
# Gradio page: header, the static demo embedded in an iframe, and hidden
# JSON-in / JSON-out components that expose synthesis() as the "predict" API.
with gr.Blocks() as demo:
    gr.HTML(intro)
    # Embed the static front-end served under /GenMM_demo/ by the FastAPI app.
    # The closing </iframe> tag was missing, leaving the element unterminated.
    gr.HTML(
        """<iframe src="/GenMM_demo/" width="100%" height="700px" style="border:none;"></iframe>"""
    )
    # Invisible components: they exist only so the click handler can be
    # registered under api_name="predict". Presumably the embedded front-end
    # calls this endpoint -- TODO confirm against the GenMM_demo JS.
    json_in = gr.JSON(visible=False)
    json_out = gr.JSON(visible=False)
    btn = gr.Button("Synthesize", visible=False)
    btn.click(synthesis, inputs=[json_in], outputs=[json_out], api_name="predict")
# Compose the ASGI app: FastAPI serves the static demo under /GenMM_demo and
# the Gradio UI is mounted at the root.
app = FastAPI()
# Resolve relative to this file rather than the current working directory, so
# startup also works when launched from elsewhere (StaticFiles checks that the
# directory exists when it is constructed).
static_dir = Path(__file__).resolve().parent / "GenMM_demo"
# html=True serves index.html for directory requests such as /GenMM_demo/.
app.mount("/GenMM_demo", StaticFiles(directory=static_dir, html=True), name="static")
# Mounted after /GenMM_demo: routes are matched in registration order, so the
# static mount is tried before the catch-all "/" Gradio mount.
app = gr.mount_gradio_app(app, demo, path="/")
# serve the app
if __name__ == "__main__":
    # Pass the app object itself. The "app:app" import string made uvicorn
    # import this file a second time as module "app" (it is already running as
    # "__main__"), rebuilding the Gradio Blocks and the FastAPI app. An import
    # string is only required for reload/workers, which are not used here.
    uvicorn.run(app, host="0.0.0.0", port=7860)