Spaces:
Runtime error
Runtime error
File size: 3,939 Bytes
c29ac1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import os
import string
import functools
import re
import numpy as np
import spaces
# Checkpoint the generation model weights are loaded from.
model_id = "agentsea/paligemma-3b-ft-widgetcap-waveui-448"
# The processor is loaded from a different repository than the model weights.
processor_id = "google/paligemma-3b-pt-448"
# Palette of hex colour strings.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Loaded once at import time: put the model in eval mode and move it to `device`.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(processor_id)
###### Transformers Inference
@spaces.GPU
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int
) -> str:
    """Generate a completion for ``text`` conditioned on ``image``.

    Decoding is greedy (``do_sample=False``) and runs under
    ``torch.inference_mode()``.

    Args:
        image: Input image handed to the processor.
        text: Prompt handed to the processor.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt removed), with leading
        newlines stripped.
    """
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    # Length of the tokenised prompt, used below to cut it off the output.
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # The generated sequence starts with the prompt tokens. Drop them by token
    # position instead of slicing the decoded string by ``len(text)``: the
    # character count is only right when the decoded prompt round-trips to
    # exactly ``text``, which fails e.g. when the prompt contains text that
    # ``skip_special_tokens=True`` removes.
    new_token_ids = generated_ids[:, prompt_len:]
    result = processor.batch_decode(new_token_ids, skip_special_tokens=True)
    return result[0].lstrip("\n")
def parse_segmentation(input_image, input_text):
    """Run detection for ``input_text`` on ``input_image``.

    Args:
        input_image: PIL image coming from the ``gr.Image`` input.
        input_text: Prompt describing what to detect.

    Returns:
        A ``(image, annotations)`` tuple for the ``gr.AnnotatedImage`` output,
        where ``annotations`` is a list of ``(mask_or_xyxy_box, label)`` pairs.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Fail with a readable message instead of an opaque error deeper in
        # the processor / on ``input_image.size``.
        raise gr.Error("Please provide an image.")

    out = infer(input_image, input_text, max_new_tokens=100)
    width, height = input_image.size
    objs = extract_objs(out.lstrip("\n"), width, height, unique_labels=True)

    # Plain-text segments carry neither a mask nor a box and are skipped.
    annotations = [
        (
            obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
            obj['name'] or '',
        )
        for obj in objs
        if 'mask' in obj or 'xyxy' in obj
    ]
    return input_image, annotations
######## Demo
# Markdown shown at the top of the page.
INTRO_TEXT = """## PaliGemma WaveUI\n\n
Bla bla
"""

# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost; confirm which widgets belong inside the tab.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)

    with gr.Tab("Detection"):
        # Inputs, trigger and output of the detection tab.
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Entities to Detect")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")

        # One bundled example: an image path plus its prompt.
        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
        gr.Markdown("Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), [mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) and [merve](https://huggingface.co/merve).")
        gr.Examples(examples=examples, inputs=[image, seg_input])

        # Wire the button: (image, prompt) -> parse_segmentation -> annotated image.
        seg_inputs = [image, seg_input]
        seg_outputs = [annotated_image]
        seg_btn.click(fn=parse_segmentation, inputs=seg_inputs, outputs=seg_outputs)
# One detection: optional leading text, four ``<locNNNN>`` tokens, an optional
# run of sixteen ``<segNNN>`` tokens, an optional label and an optional "; "
# separator.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)


def extract_objs(text, width, height, unique_labels=False):
    """Return objs for a string with "<loc>" and "<seg>" tokens.

    The text is consumed left to right, one regex match at a time. Text that
    precedes a detection, and any unmatched remainder, is returned as a
    ``{'content': ...}`` entry. Each detection becomes a dict with keys
    ``content`` (the matched text), ``xyxy`` (pixel box), ``mask`` and ``name``.

    Args:
        text: Model output to parse.
        width: Image width in pixels, used to scale x coordinates.
        height: Image height in pixels, used to scale y coordinates.
        unique_labels: If true, append ``'`` to repeated labels until every
            label in the result is distinct.

    Returns:
        A list of dicts in the order the segments appear in ``text``.
        ``<seg>`` tokens are matched but not decoded, so ``mask`` is always
        ``None``.
    """
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        # Groups: leading text, 4 loc values, 16 seg values, label.
        before, *tokens, name = m.groups()

        # The label capture is greedy and keeps the blank that precedes the
        # "; " separator ("cat ; cat" -> "cat " and "cat"). Strip it so equal
        # labels compare equal and are displayed cleanly.
        if name is not None:
            name = name.strip() or None

        # The four loc values are in y1, x1, y2, x2 order on a 0..1024 scale.
        y1, x1, y2, x2 = (int(v) / 1024 for v in tokens[:4])
        xyxy = (
            round(x1 * width),
            round(y1 * height),
            round(x2 * width),
            round(y2 * height),
        )

        if before:
            objs.append(dict(content=before))
        content = m.group()[len(before):]

        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)

        objs.append(dict(content=content, xyxy=xyxy, mask=None, name=name))
        # match() anchors at index 0, so m.end() is exactly the consumed length.
        text = text[m.end():]

    if text:
        objs.append(dict(content=text))
    return objs
#########
if __name__ == "__main__":
    # Enable the request queue (max_size=10) and launch the app in debug mode.
    demo.queue(max_size=10).launch(debug=True)