import dataclasses as dc
import io
from typing import Any
import gradio as gr
import pillow_heif
from environs import Env
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from PIL import Image
from fg import EditorAPIContext
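
# EditorAPIContext comes from the local `fg` helper module; as used below it wraps
# Finegrain Editor API access: auth headers, SSE-based status polling (`sse_await`)
# and running async calls from synchronous Gradio handlers (`run_one_sync`).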
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()
env = Env()
env.read_env()
with env.prefixed("ERASER_"):
    API_URL: str = str(env.str("API_URL", "https://api.finegrain.ai/editor"))
    API_USER: str | None = env.str("API_USER")
    API_PASSWORD: str | None = env.str("API_PASSWORD")
    CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)

assert API_USER is not None
assert API_PASSWORD is not None

CTX = EditorAPIContext(uri=API_URL, user=API_USER, password=API_PASSWORD)
if CA_BUNDLE:
    CTX.verify = CA_BUNDLE
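

# Downscale large inputs so the shortest side is at most 768 px before sending them to the API.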
def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image:
    if image.width <= shortest_side and image.height <= shortest_side:
        return image
    if image.width < image.height:
        return image.resize(size=(shortest_side, int(shortest_side * image.height / image.width)))
    return image.resize(size=(int(shortest_side * image.width / image.height), shortest_side))
@dc.dataclass(kw_only=True)
class ProcessParams:
    image: Image.Image
    prompt: str | None = None
    bbox: tuple[int, int, int, int] | None = None
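

# The erase pipeline chains four Editor API calls:
#   1. upload the image to create an input state,
#   2. infer a bounding box from the prompt (skipped when a bbox is provided),
#   3. segment the object inside the bounding box,
#   4. erase the object and fetch the resulting image.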
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
    with io.BytesIO() as f:
        params.image.save(f, format="JPEG")
        f.seek(0)  # rewind so the upload sends the full JPEG rather than an empty stream
        async with ctx as client:
            response = await client.post(
                f"{ctx.uri}/state/upload",
                files={"file": f},
                headers=ctx.auth_headers,
            )
    response.raise_for_status()
    st_input = response.json()["state"]
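
    # Segment directly inside the user-provided bbox, or infer a bbox from the prompt first.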
    if params.bbox:
        segment_input_st = st_input
        segment_params = {"bbox": list(params.bbox)}
    else:
        assert params.prompt
        async with ctx as client:
            response = await client.post(
                f"{ctx.uri}/skills/infer-bbox/{st_input}",
                json={"product_name": params.prompt},
                headers=ctx.auth_headers,
            )
        response.raise_for_status()
        st_bbox = response.json()["state"]
        await ctx.sse_await(st_bbox)
        segment_input_st = st_bbox
        segment_params = {}
    async with ctx as client:
        response = await client.post(
            f"{ctx.uri}/skills/segment/{segment_input_st}",
            json=segment_params,
            headers=ctx.auth_headers,
        )
    response.raise_for_status()
    st_mask = response.json()["state"]
    await ctx.sse_await(st_mask)
    erase_params: dict[str, str | bool] = {
        "mode": "free",  # new API
        "restore_original_resolution": False,  # legacy API
    }
    async with ctx as client:
        response = await client.post(
            f"{ctx.uri}/skills/erase/{st_input}/{st_mask}",
            json=erase_params,
            headers=ctx.auth_headers,
        )
    response.raise_for_status()
    st_erased = response.json()["state"]
    await ctx.sse_await(st_erased)
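
    # Fetch the erased image as a JPEG at display resolution.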
    async with ctx as client:
        response = await client.get(
            f"{ctx.uri}/state/image/{st_erased}",
            params={"format": "JPEG", "resolution": "DISPLAY"},
            headers=ctx.auth_headers,
        )
    response.raise_for_status()

    f = io.BytesIO()
    f.write(response.content)
    f.seek(0)
    return Image.open(f)
def process_bbox(prompts: dict[str, Any]) -> tuple[Image.Image, Image.Image]:
    assert isinstance(img := prompts["image"], Image.Image)
    assert isinstance(boxes := prompts["boxes"], list)
    assert len(boxes) == 1
    assert isinstance(box := boxes[0], dict)

    resized_img = resize(img)
    bbox = [box[k] for k in ["xmin", "ymin", "xmax", "ymax"]]
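    # The resize preserves aspect ratio, so a single width ratio rescales all bbox coordinates.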
    if resized_img.width != img.width:
        bbox = [int(v * resized_img.width / img.width) for v in bbox]

    output_image = CTX.run_one_sync(
        _process,
        ProcessParams(
            image=resized_img,
            bbox=(bbox[0], bbox[1], bbox[2], bbox[3]),
        ),
    )
    return (img, output_image)
def on_change_bbox(prompts: dict[str, Any] | None):
    return gr.update(interactive=prompts is not None and len(prompts["boxes"]) > 0)
def process_prompt(img: Image.Image, prompt: str) -> tuple[Image.Image, Image.Image]:
    resized_img = resize(img)
    output_image = CTX.run_one_sync(
        _process,
        ProcessParams(image=resized_img, prompt=prompt),
    )
    return (img, output_image)
def on_change_prompt(img: Image.Image | None, prompt: str | None):
    return gr.update(interactive=bool(img and prompt))
TITLE = """
<center>
  <div style="
    background-color: #ff9100;
    color: #1f2937;
    padding: 0.5rem 1rem;
    font-size: 1.25rem;
  ">
    🚀 For an optimized version of this space, try out the
    <a href="https://finegrain.ai/editor?utm_source=hf&utm_campaign=object-eraser" target="_blank">Finegrain Editor</a>!
    You'll find all our AI tools there, in a nice UI. 🚀
  </div>
  <h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
    Object Eraser Powered By Refiners
  </h1>
  <p>
    Erase any object from your image just by naming it. No manual work required!
    Not only will the object disappear, but so will its effects on the scene, like shadows or reflections.
  </p>
  <p>
    This space is powered by Refiners, our open-source micro-framework for simple foundation model adaptation.
    If you enjoyed it, please consider starring Refiners on GitHub!
  </p>
  <a href="https://github.com/finegrain-ai/refiners" target="_blank">
    <img src="https://img.shields.io/github/stars/finegrain-ai/refiners?style=social" />
  </a>
</center>
"""
with gr.Blocks() as demo:
    gr.HTML(TITLE)
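
    # Two ways to select the object to erase: a free-text prompt or a single bounding box.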
with gr.Tab("By prompt", id="tab_prompt"):
with gr.Row():
with gr.Column():
iimg = gr.Image(type="pil", label="Input")
prompt = gr.Textbox(label="What should we erase?")
with gr.Column():
oimg = ImageSlider(label="Output")
with gr.Row():
btn = gr.ClearButton(components=[oimg], value="Erase Object", interactive=False)
for inp in [iimg, prompt]:
inp.change(
fn=on_change_prompt,
inputs=[iimg, prompt],
outputs=[btn],
)
btn.click(
fn=process_prompt,
inputs=[iimg, prompt],
outputs=[oimg],
api_name=False,
)
        examples = [
            [
                "examples/white-towels-rattan-basket-white-table-with-bright-room-background.jpg",
                "soap",
            ],
            [
                "examples/interior-decor-with-mirror-potted-plant.jpg",
                "potted plant",
            ],
            [
                "examples/detail-ball-basketball-court-sunset.jpg",
                "basketball",
            ],
            [
                "examples/still-life-device-table_23-2150994394.jpg",
                "glass of water",
            ],
            [
                "examples/knife-fork-green-checkered-napkin_140725-63576.jpg",
                "knife and fork",
            ],
            [
                "examples/city-night-with-architecture-vibrant-lights_23-2149836930.jpg",
                "frontmost black car on right lane",
            ],
            [
                "examples/close-up-coffee-latte-wooden-table_23-2147893063.jpg",
                "coffee cup on plate",
            ],
            [
                "examples/empty-chair-with-vase-plant_74190-2078.jpg",
                "chair",
            ],
        ]

        ex = gr.Examples(
            examples=examples,
            inputs=[iimg, prompt],
            outputs=[oimg],
            fn=process_prompt,
            cache_examples=True,
        )
with gr.Tab("By bounding box", id="tab_bb"):
with gr.Row():
with gr.Column():
annotator = image_annotator(
image_type="pil",
disable_edit_boxes=True,
show_download_button=False,
show_share_button=False,
single_box=True,
label="Input",
)
with gr.Column():
oimg = ImageSlider(label="Output")
with gr.Row():
btn = gr.ClearButton(components=[oimg], value="Erase Object", interactive=False)
annotator.change(
fn=on_change_bbox,
inputs=[annotator],
outputs=[btn],
)
btn.click(
fn=process_bbox,
inputs=[annotator],
outputs=[oimg],
api_name=False,
)
        examples = [
            {
                "image": "examples/white-towels-rattan-basket-white-table-with-bright-room-background.jpg",
                "boxes": [{"xmin": 836, "ymin": 475, "xmax": 1125, "ymax": 1013}],
            },
            {
                "image": "examples/interior-decor-with-mirror-potted-plant.jpg",
                "boxes": [{"xmin": 47, "ymin": 907, "xmax": 397, "ymax": 1633}],
            },
            {
                "image": "examples/detail-ball-basketball-court-sunset.jpg",
                "boxes": [{"xmin": 673, "ymin": 954, "xmax": 911, "ymax": 1186}],
            },
            {
                "image": "examples/still-life-device-table_23-2150994394.jpg",
                "boxes": [{"xmin": 429, "ymin": 586, "xmax": 571, "ymax": 834}],
            },
            {
                "image": "examples/knife-fork-green-checkered-napkin_140725-63576.jpg",
                "boxes": [{"xmin": 972, "ymin": 226, "xmax": 1092, "ymax": 1023}],
            },
            {
                "image": "examples/city-night-with-architecture-vibrant-lights_23-2149836930.jpg",
                "boxes": [{"xmin": 215, "ymin": 637, "xmax": 411, "ymax": 855}],
            },
            {
                "image": "examples/close-up-coffee-latte-wooden-table_23-2147893063.jpg",
                "boxes": [{"xmin": 255, "ymin": 456, "xmax": 1080, "ymax": 1064}],
            },
            {
                "image": "examples/empty-chair-with-vase-plant_74190-2078.jpg",
                "boxes": [{"xmin": 35, "ymin": 320, "xmax": 383, "ymax": 983}],
            },
        ]

        ex = gr.Examples(
            examples=examples,
            inputs=[annotator],
            outputs=[oimg],
            fn=process_bbox,
            cache_examples=True,
        )
demo.queue(max_size=30, api_open=False)
demo.launch(show_api=False)