mervenoyan committed
Commit 8f43ecc · 0 Parent(s)

clear commit history
.gitattributes ADDED
@@ -0,0 +1,39 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ transformers-4.47.0.dev0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
+ examples/bee.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/cats.png filter=lfs diff=lfs merge=lfs -text
+ examples/emu.jpg filter=lfs diff=lfs merge=lfs -text
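Every pattern above routes matching files through Git LFS rather than regular git storage. As a rough illustration (not part of the commit), Python's `fnmatch` can approximate which paths these globs capture; the patterns below are copied from the file, while `events.out.tfevents.123` is a made-up filename:

```python
import fnmatch

# Subset of the globs from the .gitattributes above.
LFS_PATTERNS = ["*.bin", "*.npz", "examples/bee.jpg", "*tfevents*"]

def tracked_by_lfs(path: str) -> bool:
    # fnmatch is only an approximation: unlike gitattributes globs,
    # its '*' also matches '/' path separators.
    return any(fnmatch.fnmatch(path, pat) for pat in LFS_PATTERNS)

for path in ["vae-oid.npz", "examples/bee.jpg", "events.out.tfevents.123", "app.py"]:
    print(path, "->", "LFS" if tracked_by_lfs(path) else "regular git")
```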
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Paligemma2
+ emoji: 🌖
+ colorFrom: pink
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.6.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
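The block between the `---` markers is the Space's YAML front matter; the Hub reads it to pick the SDK, its version, and the entry point. A minimal sketch of reading it locally (assumes PyYAML is installed; the split-based parsing is illustrative, not Hub code):

```python
import yaml  # pip install pyyaml

with open("README.md", encoding="utf-8") as f:
    text = f.read()

# The front matter sits between the first two '---' markers.
_, front_matter, _body = text.split("---", 2)
config = yaml.safe_load(front_matter)
print(config["sdk"], config["sdk_version"], config["app_file"])
# -> gradio 5.6.0 app.py
```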
app.py ADDED
@@ -0,0 +1,327 @@
+ import os
+ import gradio as gr
+ import PIL.Image
+ import transformers
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
+ import torch
+ import string
+ import functools
+ import re
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import spaces
+
+
+ model_id = "gv-hf/paligemma2-10b-mix-448"
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+ ###### Transformers Inference
+ @spaces.GPU
+ def infer(
+     image: PIL.Image.Image,
+     text: str,
+     max_new_tokens: int
+ ) -> str:
+     inputs = processor(text=text, images=image, return_tensors="pt").to(device)
+     with torch.inference_mode():
+         generated_ids = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=False
+         )
+     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
+     return result[0][len(text):].lstrip("\n")
+
+ ##### Parse segmentation output tokens into masks
+ ##### Also returns bounding boxes with their labels
+
+ def parse_segmentation(input_image, input_text):
+     out = infer(input_image, input_text, max_new_tokens=100)
+     objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+     labels = set(obj.get('name') for obj in objs if obj.get('name'))
+     color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+     highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+     annotated_img = (
+         input_image,
+         [
+             (
+                 obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                 obj['name'] or '',
+             )
+             for obj in objs
+             if 'mask' in obj or 'xyxy' in obj
+         ],
+     )
+     has_annotations = bool(annotated_img[1])
+     return annotated_img
+
+
+
+ ######## Demo
+
+ INTRO_TEXT = """## PaliGemma 2 demo\n\n
+ | [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
+ | [Blogpost](https://huggingface.co/blog/paligemma)
+ |\n\n
+ PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
+ built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
+ vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile
+ model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual question
+ answering, text reading, object detection, and object segmentation.
+ \n\n
+ This Space runs a model fine-tuned on a mix of downstream tasks, **with inference via 🤗 transformers**.
+ See the [Blogpost](https://huggingface.co/blog/paligemma2) and
+ [README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
+ for detailed information on how to use and fine-tune PaliGemma models.
+ \n\n
+ **This is an experimental research model.** Make sure to add appropriate guardrails when using the model in applications.
+ """
+
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(INTRO_TEXT)
+     with gr.Tab("Text Generation"):
+         with gr.Column():
+             image = gr.Image(type="pil")
+             text_input = gr.Text(label="Input Text")
+
+         text_output = gr.Text(label="Text Output")
+         chat_btn = gr.Button()
+         tokens = gr.Slider(
+             label="Max New Tokens",
+             info="Set to larger for longer generation.",
+             minimum=10,
+             maximum=100,
+             value=20,
+             step=10,
+         )
+
+         chat_inputs = [
+             image,
+             text_input,
+             tokens
+         ]
+         chat_outputs = [
+             text_output
+         ]
+         chat_btn.click(
+             fn=infer,
+             inputs=chat_inputs,
+             outputs=chat_outputs,
+         )
+
+         examples = [["./examples/password.jpg", "What is the password?"],
+                     ["./examples/howto.jpg", "What does this image show?"],
+                     ["./examples/billard.jpg", "How many red balls are there?"],
+                     ["./examples/bowie.jpg", "Who is this?"],
+                     ["./examples/emu.jpg", "What animal is this?"],
+                     ["./examples/bee.jpg", "What is on the flower?"],
+                     ["./examples/ulges.jpg", "Who is the author of this book?"]]
+         gr.Markdown("Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), [mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) and [merve](https://huggingface.co/merve).")
+
+         gr.Examples(
+             examples=examples,
+             inputs=chat_inputs,
+         )
+     with gr.Tab("Segment/Detect"):
+         image = gr.Image(type="pil")
+         seg_input = gr.Text(label="Entities to Segment/Detect")
+         seg_btn = gr.Button("Submit")
+         annotated_image = gr.AnnotatedImage(label="Output")
+
+         examples = [["./examples/cats.png", "segment cats"],
+                     ["./examples/bee.jpg", "detect bee"],
+                     ["./examples/barsik.jpg", "segment cat"],
+                     ["./examples/bird.jpg", "segment bird ; bird ; plant"]]
+         gr.Markdown("Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), [mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) and [merve](https://huggingface.co/merve).")
+         gr.Examples(
+             examples=examples,
+             inputs=[image, seg_input],
+         )
+
+         seg_inputs = [
+             image,
+             seg_input
+         ]
+         seg_outputs = [
+             annotated_image
+         ]
+         seg_btn.click(
+             fn=parse_segmentation,
+             inputs=seg_inputs,
+             outputs=seg_outputs,
+         )
+
+
+
+
+
+ ### Postprocessing Utils for Segmentation Tokens
+ ### Segmentation tokens are passed to another VAE which decodes them to a mask
+
+ _MODEL_PATH = 'vae-oid.npz'
+
+ _SEGMENT_DETECT_RE = re.compile(
+     r'(.*?)' +
+     r'<loc(\d{4})>' * 4 + r'\s*' +
+     '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+     r'\s*([^;<>]+)? ?(?:; )?',
+ )
+
+
+ def _get_params(checkpoint):
+     """Converts PyTorch checkpoint to Flax params."""
+
+     def transp(kernel):
+         return np.transpose(kernel, (2, 3, 1, 0))
+
+     def conv(name):
+         return {
+             'bias': checkpoint[name + '.bias'],
+             'kernel': transp(checkpoint[name + '.weight']),
+         }
+
+     def resblock(name):
+         return {
+             'Conv_0': conv(name + '.0'),
+             'Conv_1': conv(name + '.2'),
+             'Conv_2': conv(name + '.4'),
+         }
+
+     return {
+         '_embeddings': checkpoint['_vq_vae._embedding'],
+         'Conv_0': conv('decoder.0'),
+         'ResBlock_0': resblock('decoder.2.net'),
+         'ResBlock_1': resblock('decoder.3.net'),
+         'ConvTranspose_0': conv('decoder.4'),
+         'ConvTranspose_1': conv('decoder.6'),
+         'ConvTranspose_2': conv('decoder.8'),
+         'ConvTranspose_3': conv('decoder.10'),
+         'Conv_1': conv('decoder.12'),
+     }
+
+
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+     batch_size, num_tokens = codebook_indices.shape
+     assert num_tokens == 16, codebook_indices.shape
+     unused_num_embeddings, embedding_dim = embeddings.shape
+
+     encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+     encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+     return encodings
+
+
+ @functools.cache
+ def _get_reconstruct_masks():
+     """Reconstructs masks from codebook indices.
+     Returns:
+         A function that expects indices shaped `[B, 16]` of dtype int32, each
+         ranging from 0 to 127 (inclusive), and that returns decoded masks sized
+         `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
+     """
+
+     class ResBlock(nn.Module):
+         features: int
+
+         @nn.compact
+         def __call__(self, x):
+             original_x = x
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+             return x + original_x
+
+     class Decoder(nn.Module):
+         """Upscales quantized vectors to mask."""
+
+         @nn.compact
+         def __call__(self, x):
+             num_res_blocks = 2
+             dim = 128
+             num_upsample_layers = 4
+
+             x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+             x = nn.relu(x)
+
+             for _ in range(num_res_blocks):
+                 x = ResBlock(features=dim)(x)
+
+             for _ in range(num_upsample_layers):
+                 x = nn.ConvTranspose(
+                     features=dim,
+                     kernel_size=(4, 4),
+                     strides=(2, 2),
+                     padding=2,
+                     transpose_kernel=True,
+                 )(x)
+                 x = nn.relu(x)
+                 dim //= 2
+
+             x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+             return x
+
+     def reconstruct_masks(codebook_indices):
+         quantized = _quantized_values_from_codebook_indices(
+             codebook_indices, params['_embeddings']
+         )
+         return Decoder().apply({'params': params}, quantized)
+
+     with open(_MODEL_PATH, 'rb') as f:
+         params = _get_params(dict(np.load(f)))
+
+     return jax.jit(reconstruct_masks, backend='cpu')
+
+
+ def extract_objs(text, width, height, unique_labels=False):
+     """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+     objs = []
+     seen = set()
+     while text:
+         m = _SEGMENT_DETECT_RE.match(text)
+         if not m:
+             break
+         gs = list(m.groups())
+         before = gs.pop(0)
+         name = gs.pop()
+         y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+         y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+         seg_indices = gs[4:20]
+         if seg_indices[0] is None:
+             mask = None
+         else:
+             seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+             m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+             m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+             m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+             mask = np.zeros([height, width])
+             if y2 > y1 and x2 > x1:
+                 mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+         content = m.group()
+         if before:
+             objs.append(dict(content=before))
+             content = content[len(before):]
+         while unique_labels and name in seen:
+             name = (name or '') + "'"
+         seen.add(name)
+         objs.append(dict(
+             content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+         text = text[len(before) + len(content):]
+
+     if text:
+         objs.append(dict(content=text))
+
+     return objs
+
+ #########
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch(debug=True)
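For reference, `extract_objs` relies on PaliGemma's structured output format: each object is four `<locXXXX>` tokens giving `y1, x1, y2, x2` on a 0–1023 grid (divided by 1024, then scaled to the image size), optionally followed by sixteen `<segXXX>` codebook indices and a label. A self-contained sketch of the detection-only case, which needs no VAE weights; the sample output string below is invented for illustration:

```python
import re

# Same pattern as _SEGMENT_DETECT_RE in app.py above.
_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)

# Invented model output for a prompt like "detect bee" on a 448x448 image.
out = "<loc0100><loc0200><loc0900><loc0800> bee"
m = _RE.match(out)
gs = list(m.groups())
y1, x1, y2, x2 = [int(v) / 1024 for v in gs[1:5]]  # normalized to [0, 1)
name = gs[-1]
width = height = 448
box = (round(x1 * width), round(y1 * height), round(x2 * width), round(y2 * height))
print(name, box)  # -> bee (88, 44, 350, 394)
```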
examples/barsik.jpg ADDED
examples/bee.jpg ADDED

Git LFS Details
  • SHA256: 8b21ba78250f852ca5990063866b1ace6432521d0251bde7f8de783b22c99a6d
  • Pointer size: 132 bytes
  • Size of remote file: 5.37 MB
examples/billard1.jpg ADDED
examples/billard2.jpg ADDED
examples/bowie.jpg ADDED
examples/cats.png ADDED

Git LFS Details
  • SHA256: 33e53e0656ed8336c87a0e6f6441e3d75c4d64c3af225b11fca44fe95a8bc487
  • Pointer size: 131 bytes
  • Size of remote file: 678 kB
examples/emu.jpg ADDED

Git LFS Details
  • SHA256: 82909d85d16cf1f4b2080d7962545b05611b7b80de3d26633ad7931185bf2cee
  • Pointer size: 131 bytes
  • Size of remote file: 212 kB
examples/givt.jpg ADDED
examples/howto.jpg ADDED
examples/password.jpg ADDED
examples/ulges.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ jax
+ flax
+ spaces
+ transformers
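None of these requirements are pinned, so the Space resolves whatever versions pip offers at build time (gradio itself comes from `sdk_version` in the README, not from this file). An illustrative check of what actually got installed:

```python
# Illustrative only: report the resolved versions of the unpinned requirements.
from importlib.metadata import version, PackageNotFoundError

for pkg in ["torch", "jax", "flax", "spaces", "transformers"]:
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed")
```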
vae-oid.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5586010257b8536dddefab65e7755077f21d5672d5674dacf911f73ae95a4447
+ size 8479556
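What git actually stores here is not the 8.5 MB archive itself but this Git LFS pointer: the three `version`/`oid`/`size` lines above. A minimal sketch of parsing such a pointer (assumes the checkout has not run `git lfs pull`, so the file on disk still contains the pointer text):

```python
# Minimal sketch: read a Git LFS pointer file into a dict of its fields.
def read_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = read_lfs_pointer("vae-oid.npz")
print(ptr["oid"])   # sha256:5586010257b8536d...
print(ptr["size"])  # 8479556 (bytes)
```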