from __future__ import annotations

import functools
import os
import tempfile
from pathlib import Path

import gradio as gr
import spaces
import torch
from gradio.utils import get_cache_folder
from gradio_imageslider import ImageSlider
from PIL import Image

# Constants
DEFAULT_SHARPNESS = 2


class Examples(gr.helpers.Examples):
    """gr.Examples subclass that caches results under a named per-tab directory."""

    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


def load_predictor():
    """Load the StableNormal predictor via torch.hub."""
    predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal", trust_repo=True)
    return predictor


def process_image(
    predictor,
    path_input: str,
    sharpness: int = DEFAULT_SHARPNESS,
    data_type: str = "object",
):
    """Estimate the normal map for a single image and yield [input, output] for the slider."""
    if path_input is None:
        raise gr.Error("Please upload an image or select one from the gallery.")

    name_base = os.path.splitext(os.path.basename(path_input))[0]
    out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal.png")

    # Load and process image
    input_image = Image.open(path_input)
    normal_image = predictor(
        input_image,
        num_inference_steps=sharpness,
        match_input_resolution=False,
        data_type=data_type,
    )
    normal_image.save(out_path)

    yield [input_image, out_path]


def create_demo():
    # Load model
    predictor = load_predictor()

    # Create processing functions for each data type
    # (the Human tab reuses the predictor's "object" mode)
    process_object = spaces.GPU(functools.partial(process_image, predictor, data_type="object"))
    process_scene = spaces.GPU(functools.partial(process_image, predictor, data_type="indoor"))
    process_human = spaces.GPU(functools.partial(process_image, predictor, data_type="object"))

    # Define markdown content
    HEADER_MD = """
# StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal
"""
    # Create interface
    demo = gr.Blocks(
        title="Stable Normal Estimation",
        css="""
        .slider .inner { width: 5px; background: #FFF; }
        .viewport { aspect-ratio: 4/3; }
        .tabs button.selected { font-size: 20px !important; color: crimson !important; }
        h1, h2, h3 { text-align: center; display: block; }
        .md_feedback li { margin-bottom: 0px !important; }
        """,
    )
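
    # All components and event handlers below are created inside the `with demo:`
    # context so that Gradio registers them with this Blocks app.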
    with demo:
        gr.Markdown(HEADER_MD)

        with gr.Tabs() as tabs:
            # Object Tab
            with gr.Tab("Object"):
                with gr.Row():
                    with gr.Column():
                        object_input = gr.Image(label="Input Object Image", type="filepath")
                        object_sharpness = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=DEFAULT_SHARPNESS,
                            step=1,
                            label="Sharpness (inference steps)",
                            info="Higher values produce sharper results but take longer",
                        )
                        with gr.Row():
                            object_submit_btn = gr.Button("Compute Normal", variant="primary")
                            object_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        object_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )
                # Example gallery for the Object tab (cached under examples_object)
                object_example_dir = os.path.join("files", "object")
                Examples(
                    fn=process_object,
                    examples=(
                        sorted(
                            os.path.join(object_example_dir, name)
                            for name in os.listdir(object_example_dir)
                        )
                        if os.path.isdir(object_example_dir)
                        else []
                    ),
                    inputs=[object_input],
                    outputs=[object_output_slider],
                    cache_examples=True,
                    directory_name="examples_object",
                    examples_per_page=50,
                )
            # Scene Tab
            with gr.Tab("Scene"):
                with gr.Row():
                    with gr.Column():
                        scene_input = gr.Image(label="Input Scene Image", type="filepath")
                        scene_sharpness = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=DEFAULT_SHARPNESS,
                            step=1,
                            label="Sharpness (inference steps)",
                            info="Higher values produce sharper results but take longer",
                        )
                        with gr.Row():
                            scene_submit_btn = gr.Button("Compute Normal", variant="primary")
                            scene_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        scene_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )
                # Example gallery for the Scene tab (cached under examples_scene)
                scene_example_dir = os.path.join("files", "scene")
                Examples(
                    fn=process_scene,
                    examples=(
                        sorted(
                            os.path.join(scene_example_dir, name)
                            for name in os.listdir(scene_example_dir)
                        )
                        if os.path.isdir(scene_example_dir)
                        else []
                    ),
                    inputs=[scene_input],
                    outputs=[scene_output_slider],
                    cache_examples=True,
                    directory_name="examples_scene",
                    examples_per_page=50,
                )
            # Human Tab
            with gr.Tab("Human"):
                with gr.Row():
                    with gr.Column():
                        human_input = gr.Image(label="Input Human Image", type="filepath")
                        human_sharpness = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=DEFAULT_SHARPNESS,
                            step=1,
                            label="Sharpness (inference steps)",
                            info="Higher values produce sharper results but take longer",
                        )
                        with gr.Row():
                            human_submit_btn = gr.Button("Compute Normal", variant="primary")
                            human_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        human_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )
                # Example gallery for the Human tab (cached under examples_human)
                human_example_dir = os.path.join("files", "human")
                Examples(
                    fn=process_human,
                    examples=(
                        sorted(
                            os.path.join(human_example_dir, name)
                            for name in os.listdir(human_example_dir)
                        )
                        if os.path.isdir(human_example_dir)
                        else []
                    ),
                    inputs=[human_input],
                    outputs=[human_output_slider],
                    cache_examples=True,
                    directory_name="examples_human",
                    examples_per_page=50,
                )
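
        # Shared pre-flight check used by all three submit buttons: raising gr.Error
        # aborts the event chain, so the .success() handlers below never dispatch work
        # to the GPU when no image has been provided. The second argument is the
        # (unused) sharpness value passed alongside the image input.
        def check_input_image(path_input, _sharpness):
            if path_input is None:
                raise gr.Error("Please upload an image.")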
        # Event Handlers for Object Tab
        object_submit_btn.click(
            fn=check_input_image,
            inputs=[object_input, object_sharpness],
            outputs=None,
            queue=False,
        ).success(
            fn=process_object,
            inputs=[object_input, object_sharpness],
            outputs=[object_output_slider],
        )

        object_reset_btn.click(
            fn=lambda: (None, DEFAULT_SHARPNESS, None),
            inputs=[],
            outputs=[object_input, object_sharpness, object_output_slider],
            queue=False,
        )
        # Event Handlers for Scene Tab
        scene_submit_btn.click(
            fn=check_input_image,
            inputs=[scene_input, scene_sharpness],
            outputs=None,
            queue=False,
        ).success(
            fn=process_scene,
            inputs=[scene_input, scene_sharpness],
            outputs=[scene_output_slider],
        )

        scene_reset_btn.click(
            fn=lambda: (None, DEFAULT_SHARPNESS, None),
            inputs=[],
            outputs=[scene_input, scene_sharpness, scene_output_slider],
            queue=False,
        )
        # Event Handlers for Human Tab
        human_submit_btn.click(
            fn=check_input_image,
            inputs=[human_input, human_sharpness],
            outputs=None,
            queue=False,
        ).success(
            fn=process_human,
            inputs=[human_input, human_sharpness],
            outputs=[human_output_slider],
        )

        human_reset_btn.click(
            fn=lambda: (None, DEFAULT_SHARPNESS, None),
            inputs=[],
            outputs=[human_input, human_sharpness, human_output_slider],
            queue=False,
        )

    return demo

def main():
    demo = create_demo()
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()
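
# To run locally (assuming this script is saved as app.py and that torch, gradio,
# spaces, gradio_imageslider, and Pillow are installed):
#   python app.py
# The interface is then served at http://0.0.0.0:7860.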