import os
import re

import gradio as gr
import numpy as np
import jax.numpy as jnp
import PIL.Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
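
# `COLORS` and `_SEGMENT_DETECT_RE` are referenced below (parse_segmentation,
# extract_objs) but were not defined in this file. The definitions here are a
# best-effort sketch adapted from the publicly available PaliGemma
# detection/segmentation demo; treat the exact palette and regex as assumptions.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Matches: optional leading text, four <loc####> box tokens, an optional run of
# sixteen <seg###> mask tokens, and a trailing label, optionally ";"-separated.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)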
# Model and Processor Setup
model_id = "gv-hf/paligemma2-3b-mix-448"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

HF_KEY = os.getenv("HF_KEY")
if not HF_KEY:
    raise ValueError("Please set the HF_KEY environment variable with your Hugging Face API token")

# Load model and processor
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True,
).eval().to(device)

processor = PaliGemmaProcessor.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True,
)
# Inference Function
def infer(image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded sequence echoes the prompt; strip it so only generated text remains.
    return result[0][len(text):].lstrip("\n")
# Image Captioning (with user input for improvement)
def generate_caption(image: PIL.Image.Image, caption_improvement: str) -> str:
    return infer(image, f"caption: {caption_improvement}", max_new_tokens=50)
# Object Detection/Segmentation
def parse_segmentation(input_image, input_text):
    out = infer(input_image, input_text, max_new_tokens=200)
    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
    # labels, color_map, highlighted_text, and has_annotations are computed for a
    # potential highlighted-text output, but the current UI only shows the annotated image.
    labels = set(obj.get('name') for obj in objs if obj.get('name'))
    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
    annotated_img = (
        input_image,
        [
            (
                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
                obj['name'] or '',
            )
            for obj in objs
            if 'mask' in obj or 'xyxy' in obj
        ],
    )
    has_annotations = bool(annotated_img[1])
    return annotated_img
# Helper functions for object detection/segmentation
def _get_params(checkpoint):
    """Converts the PyTorch VQ-VAE mask-decoder checkpoint into a Flax-style param tree."""

    def transp(kernel):
        return np.transpose(kernel, (2, 3, 1, 0))

    def conv(name):
        return {
            'bias': checkpoint[name + '.bias'],
            'kernel': transp(checkpoint[name + '.weight']),
        }

    def resblock(name):
        return {
            'Conv_0': conv(name + '.0'),
            'Conv_1': conv(name + '.2'),
            'Conv_2': conv(name + '.4'),
        }

    return {
        '_embeddings': checkpoint['_vq_vae._embedding'],
        'Conv_0': conv('decoder.0'),
        'ResBlock_0': resblock('decoder.2.net'),
        'ResBlock_1': resblock('decoder.3.net'),
        'ConvTranspose_0': conv('decoder.4'),
        'ConvTranspose_1': conv('decoder.6'),
        'ConvTranspose_2': conv('decoder.8'),
        'ConvTranspose_3': conv('decoder.10'),
        'Conv_1': conv('decoder.12'),
    }
def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    """Looks up the 16 codebook indices and reshapes them into a 4x4 grid of embeddings."""
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    unused_num_embeddings, embedding_dim = embeddings.shape
    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
    return encodings
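
# `_get_reconstruct_masks` is called in extract_objs but is not defined in this
# file. The sketch below is adapted from the publicly available big_vision
# PaliGemma segmentation demo: it rebuilds the VQ-VAE mask decoder in Flax and
# loads its weights from a local checkpoint. The checkpoint path
# (_MASK_VAE_PATH, "vae-oid.npz") and the decoder hyperparameters (dim=128,
# kernel sizes, padding) are assumptions; adjust them to match the checkpoint
# you actually use. The decoder layer names are chosen to line up with the
# param tree produced by _get_params above.
import jax
import flax.linen as nn

_MASK_VAE_PATH = "vae-oid.npz"  # assumed local path to the mask-decoder checkpoint
_reconstruct_masks_fn = None    # cached jitted decoder


def _get_reconstruct_masks():
    """Returns a function mapping [B, 16] codebook indices to [B, 64, 64, 1] masks in [-1, 1]."""
    global _reconstruct_masks_fn
    if _reconstruct_masks_fn is not None:
        return _reconstruct_masks_fn

    class ResBlock(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            original_x = x
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
            return x + original_x

    class Decoder(nn.Module):
        """Upscales the 4x4 grid of quantized vectors to a 64x64 mask."""

        @nn.compact
        def __call__(self, x):
            dim = 128
            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
            x = nn.relu(x)
            for _ in range(2):          # ResBlock_0, ResBlock_1
                x = ResBlock(features=dim)(x)
            for _ in range(4):          # ConvTranspose_0 .. ConvTranspose_3
                x = nn.ConvTranspose(
                    features=dim, kernel_size=(4, 4), strides=(2, 2),
                    padding=2, transpose_kernel=True)(x)
                x = nn.relu(x)
                dim //= 2
            return nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)

    with open(_MASK_VAE_PATH, 'rb') as f:
        params = _get_params(dict(np.load(f)))

    def reconstruct_masks(codebook_indices):
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, params['_embeddings'])
        return Decoder().apply({'params': params}, quantized)

    _reconstruct_masks_fn = jax.jit(reconstruct_masks)
    return _reconstruct_masks_fn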
def extract_objs(text, width, height, unique_labels=False):
    """Parses model output containing <loc####>/<seg###> tokens into detection/segmentation objects."""
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # Box coordinates are encoded on a 0-1023 grid; rescale to pixel space.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        seg_indices = gs[4:20]
        if seg_indices[0] is None:
            mask = None
        else:
            # Decode the 16 <seg> tokens into a 64x64 mask and paste it into the box.
            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
            mask = np.zeros([height, width])
            if y2 > y1 and x2 > x1:
                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        text = text[len(before) + len(content):]
    if text:
        objs.append(dict(content=text))
    return objs
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# PaliGemma Multi-Modal App")
    gr.Markdown("Upload an image and explore its features using the PaliGemma model!")

    with gr.Tabs():
        # Tab 1: Image Captioning
        with gr.Tab("Image Captioning"):
            with gr.Row():
                with gr.Column():
                    caption_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    caption_improvement_input = gr.Textbox(label="Improvement Input", placeholder="Enter description to improve caption")
                    caption_btn = gr.Button("Generate Caption")
                with gr.Column():
                    caption_output = gr.Text(label="Generated Caption")
            caption_btn.click(fn=generate_caption, inputs=[caption_image, caption_improvement_input], outputs=[caption_output])

        # Tab 2: Segment/Detect
        with gr.Tab("Segment/Detect"):
            with gr.Row():
                with gr.Column():
                    detect_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    detect_text = gr.Textbox(label="Entities to Detect", placeholder="e.g. 'detect cat ; dog' or 'segment car'")
                    detect_btn = gr.Button("Detect/Segment")
                with gr.Column():
                    detect_output = gr.AnnotatedImage(label="Annotated Image")
            detect_btn.click(fn=parse_segmentation, inputs=[detect_image, detect_text], outputs=[detect_output])

# Launch the App
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)