from __future__ import annotations

import functools
import os
import tempfile
from collections.abc import Iterator
from pathlib import Path

import torch
import spaces
import gradio as gr
from PIL import Image
from gradio_imageslider import ImageSlider
from gradio.utils import get_cache_folder


class Examples(gr.helpers.Examples):
    """``gr.Examples`` variant whose example cache can live in a named sub-folder.

    When ``directory_name`` is given, the cache folder is redirected to
    ``<gradio cache folder>/<directory_name>`` (with ``log.csv`` inside it)
    before the examples are created.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
        # NOTE(review): the parent is built with _initiated_directly=False and
        # presumably does not call create() itself, so call it unconditionally;
        # previously examples were never created when directory_name was None.
        self.create()


# Global cache of loaded predictors, keyed by the version label (a MODEL_VERSIONS key).
predictors = {}

# Available model versions: UI label -> YOSO checkpoint id passed to torch.hub.
MODEL_VERSIONS = {
    "v0.3: Camera Ready Version": "yoso-normal-v0-3",
    "v1.0: NormalAnything Version": "yoso-normal-v1-0",
    "v1.5: Best Balance": "yoso-normal-v1-5",
    "v1.8.1: Best Sharpness": "yoso-normal-v1-8-1",
}

# Version selected by default in the UI and used when callers omit ``version``.
DEFAULT_VERSION = "v1.8.1: Best Sharpness"


def load_predictor(version: str = DEFAULT_VERSION):
    """Return the StableNormal predictor for ``version``, loading it on first use.

    Args:
        version: A key of ``MODEL_VERSIONS``.

    Returns:
        The object returned by ``torch.hub.load`` for the ``StableNormal_turbo``
        entry point. It is memoized in ``predictors`` so later calls reuse it.

    Raises:
        KeyError: If ``version`` is not a key of ``MODEL_VERSIONS``.
    """
    if version not in predictors:
        yoso_version = MODEL_VERSIONS[version]
        print(f"Loading StableNormal with {yoso_version}...")
        predictor = torch.hub.load(
            "Stable-X/StableNormal",
            "StableNormal_turbo",
            trust_repo=True,
            yoso_version=yoso_version,
        )
        predictors[version] = predictor
        print(f"Successfully loaded {version}")
    return predictors[version]


def precache_all_predictors():
    """Load every model listed in ``MODEL_VERSIONS`` up front.

    This is best effort: a failure is reported and skipped so that one broken
    checkpoint does not stop the demo from starting. A version that failed here
    is not stored in ``predictors``, so ``load_predictor`` retries it on first use.
    """
    print("Precaching all StableNormal predictors...")
    for version in MODEL_VERSIONS:
        print(f"Precaching {version}...")
        try:
            load_predictor(version)
        except Exception as e:  # deliberate best effort at startup
            print(f"✗ Failed to precache {version}: {e}")
        else:
            print(f"✓ Successfully precached {version}")
    print("Finished precaching all predictors.")


def process_image(
    path_input: str,
    version: str = DEFAULT_VERSION,
    data_type: str = "object",
) -> Iterator[list]:
    """Estimate a normal map for one image with the selected model version.

    Args:
        path_input: Filesystem path of the input image.
        version: A key of ``MODEL_VERSIONS``.
        data_type: Forwarded unchanged to the predictor as ``data_type``.

    Yields:
        A single ``[input_image, out_path]`` pair. ``input_image`` is the loaded
        PIL image and ``out_path`` is the PNG path of the predicted normal map,
        for display in the ``ImageSlider``.

    Raises:
        gr.Error: If no image path was provided.
        KeyError: If ``version`` is not a key of ``MODEL_VERSIONS``.
    """
    if not path_input:
        raise gr.Error("Please upload an image or select one from the gallery.")

    # Load (or reuse the cached) predictor for the requested version.
    predictor = load_predictor(version)

    # Name the output after the checkpoint id rather than the UI label. Labels
    # contain ':' and spaces, which are invalid or awkward in filenames.
    name_base = os.path.splitext(os.path.basename(path_input))[0]
    out_path = os.path.join(
        tempfile.mkdtemp(), f"{name_base}_normal_{MODEL_VERSIONS[version]}.png"
    )

    # Image.open is lazy and keeps the file handle open; copy() forces a full
    # read so the handle can be closed right away.
    with Image.open(path_input) as img:
        input_image = img.copy()

    normal_image = predictor(input_image, match_input_resolution=False, data_type=data_type)
    normal_image.save(out_path)

    yield [input_image, out_path]


def _require_image(path_input, _version=None):
    """Stop a Gradio event chain with a user-visible error when no image is set.

    The error must be *raised*: returning a ``gr.Error`` object is treated as a
    normal result, and the chained ``.success()`` handler would still run.
    """
    if not path_input:
        raise gr.Error("Please upload an image")


def create_demo():
    """Build the Gradio Blocks UI.

    All predictors are preloaded before the interface is constructed.
    """
    # Precache all predictors before creating the demo.
    precache_all_predictors()

    # Wrap the processing function so it runs on a GPU worker (HF Spaces).
    process_object = spaces.GPU(process_image)

    # Define markdown content.
    HEADER_MD = """
# 🎪 StableNormal Turbo

[![badge-github-stars](https://img.shields.io/github/stars/Stable-X/StableNormal?style=social)](https://github.com/Stable-X/StableNormal)

Select between different YOSO Normal model versions. Each version may have different performance characteristics and quality trade-offs.
"""

    # Create interface.
    demo = gr.Blocks(
        title="Stable Normal Estimation",
        css="""
            .slider .inner { width: 5px; background: #FFF; }
            .viewport { aspect-ratio: 4/3; }
            .tabs button.selected { font-size: 20px !important; color: crimson !important; }
            h1, h2, h3 { text-align: center; display: block; }
            .md_feedback li { margin-bottom: 0px !important; }
        """,
    )

    with demo:
        gr.Markdown(HEADER_MD)

        with gr.Tabs() as tabs:
            # Object tab
            with gr.Tab("Object"):
                with gr.Row():
                    with gr.Column():
                        object_input = gr.Image(label="Input Object Image", type="filepath")

                        # Model version selector
                        version_dropdown = gr.Dropdown(
                            choices=list(MODEL_VERSIONS.keys()),
                            value=DEFAULT_VERSION,
                            label="Model Version",
                            info="Select YOSO Normal model version",
                        )

                        with gr.Row():
                            object_submit_btn = gr.Button("Compute Normal", variant="primary")
                            object_reset_btn = gr.Button("Reset")

                    with gr.Column():
                        object_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

                # Examples section
                examples_dir = os.path.join("files", "object")
                if os.path.exists(examples_dir):
                    Examples(
                        # Only the image is an example input, so ``fn`` must take
                        # one argument; process_object defaults ``version``.
                        fn=process_object,
                        examples=sorted(
                            os.path.join(examples_dir, name)
                            for name in os.listdir(examples_dir)
                        ),
                        inputs=[object_input],
                        outputs=[object_output_slider],
                        cache_examples=False,
                        directory_name="examples_object",
                        examples_per_page=50,
                    )

        # Event handlers for the Object tab
        object_submit_btn.click(
            fn=_require_image,
            inputs=[object_input, version_dropdown],
            outputs=None,
            queue=False,
        ).success(
            fn=process_object,
            inputs=[object_input, version_dropdown],
            outputs=[object_output_slider],
        )

        object_reset_btn.click(
            fn=lambda: (None, DEFAULT_VERSION, None),
            inputs=[],
            outputs=[object_input, version_dropdown, object_output_slider],
            queue=False,
        )

    return demo


def main():
    """Build the demo and serve it on all interfaces, port 7860."""
    demo = create_demo()
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()