Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import functools | |
import os | |
import tempfile | |
import torch | |
import spaces | |
import gradio as gr | |
from PIL import Image | |
from gradio_imageslider import ImageSlider | |
from pathlib import Path | |
from gradio.utils import get_cache_folder | |
class Examples(gr.helpers.Examples):
    """``gr.helpers.Examples`` subclass that can store its cache under a named folder.

    The extra keyword-only argument ``directory_name`` selects a sub-folder of
    gradio's cache folder (as returned by ``get_cache_folder()``). When it is
    ``None`` the locations chosen by the base class are left untouched.
    All other positional and keyword arguments are forwarded to the base class.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        # NOTE(review): `_initiated_directly` is a private gradio argument;
        # presumably False tells the base class that construction is finished
        # by the explicit `self.create()` below. Confirm against the installed
        # gradio version.
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            # Redirect the cache to <gradio cache folder>/<directory_name>
            # and keep its log file next to it.
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
        # Runs after the optional cache-path override so that create() sees
        # the overridden locations.
        self.create()
# Cache of already-loaded predictors, keyed by the version label (a key of
# MODEL_VERSIONS). Filled in by load_predictor() the first time a version
# is requested.
predictors: dict[str, object] = {}

# Available model versions: maps the human-readable label shown in the UI
# dropdown to the `yoso_version` identifier handed to torch.hub.load.
MODEL_VERSIONS: dict[str, str] = {
    "v0.3: Camera Ready Version": "yoso-normal-v0-3",
    "v1.0: NormalAnything Version": "yoso-normal-v1-0",
    "v1.5: Best Balance": "yoso-normal-v1-5",
    "v1.8.1: Best Sharpness": "yoso-normal-v1-8-1"
}
def load_predictor(version: str = "v1.8.1: Best Sharpness"):
    """Return the StableNormal predictor for ``version``, loading it if needed.

    A predictor that is already present in the module-level ``predictors``
    dict is returned as-is. Otherwise it is fetched through ``torch.hub.load``
    and stored in that dict before being returned.

    Args:
        version: Human-readable label; must be a key of ``MODEL_VERSIONS``.

    Returns:
        The object produced by ``torch.hub.load`` for the requested version.

    Raises:
        KeyError: If ``version`` is not cached and is not a key of
            ``MODEL_VERSIONS``.
    """
    # Fast path: this version has been loaded before.
    if version in predictors:
        return predictors[version]

    yoso_tag = MODEL_VERSIONS[version]
    print(f"Loading StableNormal with {yoso_tag}...")
    model = torch.hub.load(
        "Stable-X/StableNormal",
        "StableNormal_turbo",
        trust_repo=True,
        yoso_version=yoso_tag,
    )
    predictors[version] = model
    print(f"Successfully loaded {version}")
    return model
def precache_all_predictors():
    """Load every predictor listed in ``MODEL_VERSIONS``, reporting progress.

    Each version is loaded through ``load_predictor``. A failure for one
    version is printed and does not prevent the remaining versions from
    being attempted.
    """
    print("Precaching all StableNormal predictors...")
    for label in MODEL_VERSIONS:
        print(f"Precaching {label}...")
        try:
            load_predictor(label)
            print(f"✓ Successfully precached {label}")
        except Exception as err:
            # Best effort: report the problem and move on to the next version.
            print(f"✗ Failed to precache {label}: {err}")
    print("Finished precaching all predictors.")
def process_image(
    path_input: str,
    version: str = "v1.8.1: Best Sharpness",
    data_type: str = "object",
):
    """Estimate a normal map for a single image with the selected model version.

    This is a generator function: it yields exactly one result. Because of
    that, nothing (including the input check) runs until the generator is
    first advanced.

    Args:
        path_input: Filesystem path of the input image, or ``None`` when no
            image was provided.
        version: Model version label, passed to ``load_predictor``.
        data_type: Forwarded unchanged to the predictor call.

    Yields:
        A two-element list ``[input_image, out_path]``: the opened input
        ``PIL.Image`` and the path of the PNG the predictor output was
        saved to.

    Raises:
        gr.Error: If ``path_input`` is ``None``.
        KeyError: Propagated from ``load_predictor`` for an unknown version.
    """
    if path_input is None:
        raise gr.Error("Please upload an image or select one from the gallery.")

    # Load (or fetch from cache) the predictor for the requested version.
    predictor = load_predictor(version)

    name_base = os.path.splitext(os.path.basename(path_input))[0]
    # The version label contains '.', ':' and spaces. Replacing only '.'
    # left ':' and ' ' in the file name, and ':' is not a legal file-name
    # character on Windows. Keep alphanumerics, '-' and '_' and map
    # everything else to '_'.
    version_tag = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in version
    )
    # A fresh temporary directory is created per call; it is not removed
    # here because the yielded path must stay readable by the caller.
    out_path = os.path.join(
        tempfile.mkdtemp(), f"{name_base}_normal_{version_tag}.png"
    )

    # Run the predictor and persist its output.
    input_image = Image.open(path_input)
    normal_image = predictor(
        input_image, match_input_resolution=False, data_type=data_type
    )
    normal_image.save(out_path)

    yield [input_image, out_path]
def create_demo():
    """Build and return the Gradio ``Blocks`` app for StableNormal Turbo.

    All predictors are precached first, then ``process_image`` is wrapped
    with ``spaces.GPU`` and wired to a single "Object" tab containing an
    image input, a model-version dropdown, submit/reset buttons and an
    ``ImageSlider`` output. If the directory ``files/object`` exists, its
    entries are offered as clickable examples.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    default_version = "v1.8.1: Best Sharpness"
    examples_dir = os.path.join("files", "object")

    # Precache all predictors before creating the demo.
    precache_all_predictors()

    # GPU-wrapped processing function used by the event handlers.
    process_object = spaces.GPU(process_image)

    def require_image(image_path, _version):
        """Stop the click chain with a visible error when no image is set."""
        # gr.Error must be *raised*: merely returning the exception object
        # counts as a successful call, so the chained `.success` step would
        # always run and a stray value would be returned to an event that
        # declares no outputs.
        if not image_path:
            raise gr.Error("Please upload an image")

    # Header shown at the top of the page.
    HEADER_MD = """
# 🎪 StableNormal Turbo
<p align="center">
    <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
    </a>
    <a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
    </a>
    <a title="Github" href="https://github.com/Stable-X/StableNormal" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/Stable-X/StableNormal?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
    </a>
    <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
    </a>
</p>
Select between different YOSO Normal model versions. Each version may have different performance characteristics and quality trade-offs.
"""

    # Create interface.
    demo = gr.Blocks(
        title="Stable Normal Estimation",
        css="""
            .slider .inner { width: 5px; background: #FFF; }
            .viewport { aspect-ratio: 4/3; }
            .tabs button.selected { font-size: 20px !important; color: crimson !important; }
            h1, h2, h3 { text-align: center; display: block; }
            .md_feedback li { margin-bottom: 0px !important; }
        """
    )

    with demo:
        gr.Markdown(HEADER_MD)

        with gr.Tabs():
            # Object Tab
            with gr.Tab("Object"):
                with gr.Row():
                    with gr.Column():
                        object_input = gr.Image(label="Input Object Image", type="filepath")

                        # Model version selector
                        version_dropdown = gr.Dropdown(
                            choices=list(MODEL_VERSIONS.keys()),
                            value=default_version,
                            label="Model Version",
                            info="Select YOSO Normal model version"
                        )

                        with gr.Row():
                            object_submit_btn = gr.Button("Compute Normal", variant="primary")
                            object_reset_btn = gr.Button("Reset")

                    with gr.Column():
                        object_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

                # Examples section
                if os.path.exists(examples_dir):
                    Examples(
                        # Only one input component is listed below, so the
                        # callable must accept a single argument; the
                        # version falls back to process_image's default.
                        fn=process_object,
                        examples=sorted(
                            os.path.join(examples_dir, name)
                            for name in os.listdir(examples_dir)
                        ),
                        inputs=[object_input],
                        outputs=[object_output_slider],
                        cache_examples=False,
                        directory_name="examples_object",
                        examples_per_page=50,
                    )

        # Event Handlers for Object Tab
        # Validate first (unqueued); run the model only if validation passed.
        object_submit_btn.click(
            fn=require_image,
            inputs=[object_input, version_dropdown],
            outputs=None,
            queue=False,
        ).success(
            fn=process_object,
            inputs=[object_input, version_dropdown],
            outputs=[object_output_slider],
        )

        # Reset clears both images and restores the default version.
        object_reset_btn.click(
            fn=lambda: (None, default_version, None),
            inputs=[],
            outputs=[object_input, version_dropdown, object_output_slider],
            queue=False,
        )

    return demo
def main():
    """Create the demo, enable its queue with the API closed, and serve it.

    The server listens on ``0.0.0.0`` port ``7860``.
    """
    app = create_demo()
    queued_app = app.queue(api_open=False)
    queued_app.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()