import dataclasses as dc import io from functools import cache from typing import Any import gradio as gr import pillow_heif from environs import Env from finegrain import EditorAPIContext from gradio_image_annotation import image_annotator from gradio_imageslider import ImageSlider from PIL import Image pillow_heif.register_heif_opener() pillow_heif.register_avif_opener() env = Env() env.read_env() with env.prefixed("ERASER_"): API_USER: str | None = env.str("API_USER") API_PASSWORD: str | None = env.str("API_PASSWORD") API_URL: str | None = env.str("API_URL", None) CA_BUNDLE: str | None = env.str("CA_BUNDLE", None) @cache def _ctx() -> EditorAPIContext: assert API_USER is not None assert API_PASSWORD is not None ctx = EditorAPIContext(user=API_USER, password=API_PASSWORD, priority="low") if CA_BUNDLE: ctx.verify = CA_BUNDLE if API_URL: ctx.base_url = API_URL return ctx def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image: if image.width <= shortest_side and image.height <= shortest_side: return image if image.width < image.height: return image.resize(size=(shortest_side, int(shortest_side * image.height / image.width))) return image.resize(size=(int(shortest_side * image.width / image.height), shortest_side)) @dc.dataclass(kw_only=True) class ProcessParams: image: Image.Image prompt: str | None = None bbox: tuple[int, int, int, int] | None = None async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image: with io.BytesIO() as f: params.image.save(f, format="JPEG") response = await ctx.request("POST", "state/upload", files={"file": f}) st_input = response.json()["state"] if params.bbox: segment_input_st = st_input segment_params = {"bbox": list(params.bbox)} else: assert params.prompt segment_input_st = await ctx.ensure_skill( f"infer-bbox/{st_input}", {"product_name": params.prompt}, ) segment_params = {} st_mask = await ctx.ensure_skill(f"segment/{segment_input_st}", segment_params) st_erased = await ctx.ensure_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"}) response = await ctx.request( "GET", f"state/image/{st_erased}", params={"format": "JPEG", "resolution": "DISPLAY"}, ) f = io.BytesIO() f.write(response.content) f.seek(0) return Image.open(f) def process_bbox(prompts: dict[str, Any]) -> tuple[Image.Image, Image.Image]: assert isinstance(img := prompts["image"], Image.Image) assert isinstance(boxes := prompts["boxes"], list) assert len(boxes) == 1 assert isinstance(box := boxes[0], dict) resized_img = resize(img) bbox = [box[k] for k in ["xmin", "ymin", "xmax", "ymax"]] if resized_img.width != img.width: bbox = [int(v * resized_img.width / img.width) for v in bbox] output_image = _ctx().run_one_sync( _process, ProcessParams( image=resized_img, bbox=(bbox[0], bbox[1], bbox[2], bbox[3]), ), ) return (img, output_image) def on_change_bbox(prompts: dict[str, Any] | None): return gr.update(interactive=prompts is not None and len(prompts["boxes"]) > 0) def process_prompt(img: Image.Image, prompt: str) -> tuple[Image.Image, Image.Image]: resized_img = resize(img) output_image = _ctx().run_one_sync( _process, ProcessParams(image=resized_img, prompt=prompt), ) return (img, output_image) def on_change_prompt(img: Image.Image | None, prompt: str | None): return gr.update(interactive=bool(img and prompt)) TITLE = """
Erase any object from your image just by naming it — no manual work required! Not only will the object disappear, but so will its effects on the scene, like shadows or reflections.
This space is powered by Refiners, our open source micro-framework for simple foundation model adaptation. If you enjoyed it, please consider starring Refiners on GitHub!