# Stable-X — commit 7caafe1 ("Improve app")
from __future__ import annotations

import functools
import os
import re
import tempfile
from collections.abc import Iterator
from pathlib import Path

import gradio as gr
import spaces
import torch
from gradio.utils import get_cache_folder
from gradio_imageslider import ImageSlider
from PIL import Image
class Examples(gr.helpers.Examples):
    """``gr.helpers.Examples`` whose cache files can live in a named sub-folder.

    The base class is initialised with ``_initiated_directly=False`` and
    ``create()`` is then invoked explicitly at the end. That ordering lets the
    cache paths be overridden after the base initializer runs but before the
    examples are created. (Presumably ``_initiated_directly=False`` is what
    defers creation in the base class; confirm against the Gradio version in use.)

    Args:
        *args, **kwargs: Forwarded unchanged to ``gr.helpers.Examples``.
        directory_name: Optional sub-folder of Gradio's cache folder. When
            given, ``cached_folder`` points there and ``cached_file`` is
            ``<that folder>/log.csv``.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            cache_dir = get_cache_folder() / directory_name
            self.cached_folder = cache_dir
            self.cached_file = Path(cache_dir) / "log.csv"
        self.create()
# Cache of predictors that have already been loaded, keyed by the human-readable
# version label (a key of MODEL_VERSIONS). Populated lazily by load_predictor().
predictors: dict[str, object] = {}

# Human-readable version label -> ``yoso_version`` identifier passed to
# torch.hub.load. Insertion order matters: it is the order of the UI dropdown.
MODEL_VERSIONS: dict[str, str] = {
    "v0.3: Camera Ready Version": "yoso-normal-v0-3",
    "v1.0: NormalAnything Version": "yoso-normal-v1-0",
    "v1.5: Best Balance": "yoso-normal-v1-5",
    "v1.8.1: Best Sharpness": "yoso-normal-v1-8-1",
}
def load_predictor(version: str = "v1.8.1: Best Sharpness"):
    """Return the StableNormal predictor for ``version``, loading it on first use.

    Args:
        version: A key of ``MODEL_VERSIONS``.

    Returns:
        The object produced by ``torch.hub.load("Stable-X/StableNormal",
        "StableNormal_turbo", ...)`` for that version. It is memoised in the
        module-level ``predictors`` dict, so repeated calls with the same
        ``version`` return the same object without reloading.

    Raises:
        KeyError: If ``version`` is not a key of ``MODEL_VERSIONS``.
    """
    if version in predictors:
        return predictors[version]

    hub_id = MODEL_VERSIONS[version]
    print(f"Loading StableNormal with {hub_id}...")
    loaded = torch.hub.load(
        "Stable-X/StableNormal",
        "StableNormal_turbo",
        trust_repo=True,
        yoso_version=hub_id,
    )
    predictors[version] = loaded
    print(f"Successfully loaded {version}")
    return loaded
def precache_all_predictors():
    """Eagerly load every version in ``MODEL_VERSIONS`` into the predictor cache.

    This is best-effort. A version whose load raises is reported on stdout and
    skipped, so a single failing version does not abort startup. Later
    ``load_predictor`` calls for that version will retry the load.
    """
    print("Precaching all StableNormal predictors...")
    for label in MODEL_VERSIONS:
        print(f"Precaching {label}...")
        try:
            load_predictor(label)
        except Exception as err:  # deliberate top-level best-effort: report and continue
            print(f"✗ Failed to precache {label}: {err}")
        else:
            print(f"✓ Successfully precached {label}")
    print("Finished precaching all predictors.")
def process_image(
    path_input: str | None,
    version: str = "v1.8.1: Best Sharpness",
    data_type: str = "object",
) -> Iterator[list]:
    """Estimate a normal map for one image using the selected model version.

    Implemented as a generator that yields exactly one result. The result is
    consumed by the ImageSlider output wired up in ``create_demo``.

    Args:
        path_input: Filesystem path of the input image, or ``None`` when nothing
            was uploaded.
        version: A key of ``MODEL_VERSIONS`` selecting the predictor.
        data_type: Forwarded unchanged to the predictor's ``data_type`` argument.

    Yields:
        ``[input_image, out_path]``. ``input_image`` is the loaded input as a
        PIL image, and ``out_path`` is the path of the saved normal-map PNG.

    Raises:
        gr.Error: If ``path_input`` is ``None``.
    """
    if path_input is None:
        raise gr.Error("Please upload an image or select one from the gallery.")

    # Load the predictor for the specified version (cached after first use).
    predictor = load_predictor(version)

    # Version labels contain dots, a colon and spaces (e.g. "v1.8.1: Best
    # Sharpness"). Collapse every run of other characters into one underscore
    # so the output filename is portable.
    name_base = os.path.splitext(os.path.basename(path_input))[0]
    version_slug = re.sub(r"[^A-Za-z0-9_-]+", "_", version).strip("_")
    # A fresh temp dir per call keeps concurrent requests for same-named inputs apart.
    out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal_{version_slug}.png")

    # PIL opens images lazily and keeps the file handle open. Load the pixels
    # inside a context manager (copy() forces the load) so the handle is closed.
    with Image.open(path_input) as img:
        input_image = img.copy()

    normal_image = predictor(input_image, match_input_resolution=False, data_type=data_type)
    normal_image.save(out_path)
    yield [input_image, out_path]
def create_demo():
    """Build the Gradio UI for StableNormal normal-map estimation.

    First loads every model version via ``precache_all_predictors``. It then
    lays out a single "Object" tab with:

    * an image input and a model-version dropdown,
    * a before/after ``ImageSlider`` for the result,
    * an optional examples gallery read from ``files/object``,
    * Compute / Reset buttons.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    # Precache all predictors before creating the demo.
    precache_all_predictors()

    default_version = "v1.8.1: Best Sharpness"

    # Wrap the inference generator with spaces.GPU; used by the button and the examples.
    process_object = spaces.GPU(process_image)

    def _require_image(image_path, _version):
        """Fail the click event when no image is provided.

        ``gr.Error`` must be *raised*, not returned. A returned exception object
        is just an output value, so the chained ``.success()`` step would still run.
        """
        if not image_path:
            raise gr.Error("Please upload an image")

    # Define markdown content
    header_md = """
# 🎪 StableNormal Turbo
<p align="center">
<a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
</a>
<a title="Github" href="https://github.com/Stable-X/StableNormal" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/Stable-X/StableNormal?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
Select between different YOSO Normal model versions. Each version may have different performance characteristics and quality trade-offs.
"""

    # Create interface
    demo = gr.Blocks(
        title="Stable Normal Estimation",
        css="""
.slider .inner { width: 5px; background: #FFF; }
.viewport { aspect-ratio: 4/3; }
.tabs button.selected { font-size: 20px !important; color: crimson !important; }
h1, h2, h3 { text-align: center; display: block; }
.md_feedback li { margin-bottom: 0px !important; }
""",
    )

    with demo:
        gr.Markdown(header_md)
        with gr.Tabs():
            # Object Tab
            with gr.Tab("Object"):
                with gr.Row():
                    with gr.Column():
                        object_input = gr.Image(label="Input Object Image", type="filepath")
                        # Model version selector
                        version_dropdown = gr.Dropdown(
                            choices=list(MODEL_VERSIONS),
                            value=default_version,
                            label="Model Version",
                            info="Select YOSO Normal model version",
                        )
                        with gr.Row():
                            object_submit_btn = gr.Button("Compute Normal", variant="primary")
                            object_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        object_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

                # Examples section
                examples_dir = os.path.join("files", "object")
                if os.path.exists(examples_dir):
                    Examples(
                        # One input component means fn receives one argument; the
                        # version falls back to process_image's default. Pass the
                        # generator function itself, because a lambda wrapper would
                        # hide that it is a generator.
                        fn=process_object,
                        examples=sorted(
                            os.path.join(examples_dir, name)
                            for name in os.listdir(examples_dir)
                            if not name.startswith(".")  # skip .DS_Store and similar
                        ),
                        inputs=[object_input],
                        outputs=[object_output_slider],
                        cache_examples=False,
                        directory_name="examples_object",
                        examples_per_page=50,
                    )

                # Event Handlers for Object Tab: validate first, run inference only on success.
                object_submit_btn.click(
                    fn=_require_image,
                    inputs=[object_input, version_dropdown],
                    outputs=None,
                    queue=False,
                ).success(
                    fn=process_object,
                    inputs=[object_input, version_dropdown],
                    outputs=[object_output_slider],
                )
                object_reset_btn.click(
                    fn=lambda: (None, default_version, None),
                    inputs=[],
                    outputs=[object_input, version_dropdown, object_output_slider],
                    queue=False,
                )

    return demo
def main():
    """Build the demo and serve it on all interfaces (0.0.0.0), port 7860.

    The request queue is enabled with ``api_open=False``. Per Gradio's
    ``Blocks.queue`` docs, this keeps the backend REST routes closed so
    clients cannot bypass the queue.
    """
    app = create_demo()
    app.queue(api_open=False)  # Blocks.queue() returns the same Blocks object
    app.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()