import os

import gradio as gr
import jax.numpy as jnp
import numpy as np
import PIL.Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
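
# --- Module-level pieces referenced below but not defined in this file -----
# `extract_objs` uses `_SEGMENT_DETECT_RE` and `_get_reconstruct_masks()`, and
# the mask-decoder helpers need JAX/Flax. The definitions in this block are a
# hedged sketch adapted from Google's big_vision PaliGemma demo Space; the
# regex shape (one prefix group, four <locXXXX> groups, sixteen optional
# <segXXX> groups, one label group) matches how `extract_objs` unpacks its
# match groups further down. The checkpoint path is an assumption.
import functools
import re

import flax.linen as nn
import jax

# Assumed path of the VQ-VAE mask-decoder checkpoint used by the upstream demo.
_VAE_CHECKPOINT_PATH = "vae-oid.npz"

_SEGMENT_DETECT_RE = re.compile(
    r"(.*?)"
    + r"<loc(\d{4})>" * 4
    + r"\s*"
    + "(?:%s)?" % (r"<seg(\d{3})>" * 16)
    + r"\s*([^;<>]+)? ?(?:; )?"
)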

# Model and Processor Setup
model_id = "gv-hf/paligemma2-3b-mix-448"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_KEY = os.getenv("HF_KEY")
if not HF_KEY:
    raise ValueError("Please set the HF_KEY environment variable with your Hugging Face API token")

# Load model and processor
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True
).eval().to(device)

processor = PaliGemmaProcessor.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True
)
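
# Note (assumption): the model loads in full float32 precision here; on a
# CUDA device you could pass torch_dtype=torch.bfloat16 to from_pretrained
# above to roughly halve memory use, provided the checkpoint supports it.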

# Inference Function
def infer(image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
    """Runs greedy generation for `text` + `image` and returns only the new text."""
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded string echoes the prompt, so strip it before returning.
    return result[0][len(text):].lstrip("\n")

# Image Captioning (with user input for improvement)
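# Note: PaliGemma "mix" checkpoints are usually prompted with task prefixes
# such as "caption en" or "describe en"; the "caption: <hint>" prompt used
# below is this app's own convention, so results may vary.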
def generate_caption(image: PIL.Image.Image, caption_improvement: str) -> str:
    return infer(image, f"caption: {caption_improvement}", max_new_tokens=50)

# Object Detection/Segmentation
def parse_segmentation(input_image, input_text):
    """Runs detection/segmentation and builds the payload for gr.AnnotatedImage."""
    out = infer(input_image, input_text, max_new_tokens=200)
    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
    # Each annotation is (mask-or-box, label); plain-text spans carry neither.
    annotated_img = (
        input_image,
        [
            (
                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
                obj['name'] or '',
            )
            for obj in objs
            if 'mask' in obj or 'xyxy' in obj
        ],
    )
    return annotated_img

# Helper functions for object detection/segmentation
def _get_params(checkpoint):
    """Converts the PyTorch VQ-VAE decoder checkpoint into Flax-style nested params."""
    def transp(kernel):
        # Torch conv kernels are (out, in, H, W); Flax expects (H, W, in, out).
        return np.transpose(kernel, (2, 3, 1, 0))

    def conv(name):
        return {
            'bias': checkpoint[name + '.bias'],
            'kernel': transp(checkpoint[name + '.weight']),
        }

    def resblock(name):
        return {
            'Conv_0': conv(name + '.0'),
            'Conv_1': conv(name + '.2'),
            'Conv_2': conv(name + '.4'),
        }

    return {
        '_embeddings': checkpoint['_vq_vae._embedding'],
        'Conv_0': conv('decoder.0'),
        'ResBlock_0': resblock('decoder.2.net'),
        'ResBlock_1': resblock('decoder.3.net'),
        'ConvTranspose_0': conv('decoder.4'),
        'ConvTranspose_1': conv('decoder.6'),
        'ConvTranspose_2': conv('decoder.8'),
        'ConvTranspose_3': conv('decoder.10'),
        'Conv_1': conv('decoder.12'),
    }

def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    """Looks up the codebook embedding for each of the 16 indices and reshapes to a 4x4 grid."""
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    unused_num_embeddings, embedding_dim = embeddings.shape
    
    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
    return encodings
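
# `_get_reconstruct_masks` is called by `extract_objs` below but is not
# defined in this file. The sketch here is adapted from the upstream
# big_vision PaliGemma demo (an assumption, not the author's verified code):
# it loads the VQ-VAE decoder weights converted by `_get_params` and returns
# a jitted function mapping [B, 16] codebook indices to [B, 64, 64, 1] masks
# with values in [-1, 1].
@functools.cache
def _get_reconstruct_masks():
    class ResBlock(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            original_x = x
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
            return x + original_x

    class Decoder(nn.Module):
        """Upscales the 4x4 grid of quantized vectors to a 64x64 mask."""

        @nn.compact
        def __call__(self, x):
            num_res_blocks = 2
            dim = 128
            num_upsample_layers = 4

            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
            x = nn.relu(x)

            for _ in range(num_res_blocks):
                x = ResBlock(features=dim)(x)

            for _ in range(num_upsample_layers):
                x = nn.ConvTranspose(
                    features=dim,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    padding=2,
                    transpose_kernel=True,
                )(x)
                x = nn.relu(x)
                dim //= 2

            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
            return x

    def reconstruct_masks(codebook_indices):
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, params["_embeddings"]
        )
        return Decoder().apply({"params": params}, quantized)

    with open(_VAE_CHECKPOINT_PATH, "rb") as f:
        params = _get_params(dict(np.load(f)))

    return jax.jit(reconstruct_masks)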

def extract_objs(text, width, height, unique_labels=False):
    """Parses `<locXXXX>`/`<segXXX>` spans in `text` into boxes, masks, and labels."""
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
            
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        
        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
        seg_indices = gs[4:20]
        if seg_indices[0] is None:
            mask = None
        else:
            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
            mask = np.zeros([height, width])
            if y2 > y1 and x2 > x1:
                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0

        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        text = text[len(before) + len(content):]

    if text:
        objs.append(dict(content=text))

    return objs
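
# Illustrative example (assumption about the raw model output format): a
# detection answer looks like
#   "<loc0123><loc0456><loc0789><loc0999> cat ; <loc0100><loc0200><loc0300><loc0400> dog"
# with box coordinates given as y1/x1/y2/x2 scaled to 0-1023; segmentation
# answers additionally carry sixteen <segXXX> codebook tokens per object.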

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# PaliGemma Multi-Modal App")
    gr.Markdown("Upload an image and explore its features using the PaliGemma model!")

    with gr.Tabs():
        # Tab 1: Image Captioning
        with gr.Tab("Image Captioning"):
            with gr.Row():
                with gr.Column():
                    caption_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    caption_improvement_input = gr.Textbox(label="Improvement Input", placeholder="Enter description to improve caption")
                    caption_btn = gr.Button("Generate Caption")
                with gr.Column():
                    caption_output = gr.Text(label="Generated Caption")
            caption_btn.click(fn=generate_caption, inputs=[caption_image, caption_improvement_input], outputs=[caption_output])

        # Tab 2: Segment/Detect
        with gr.Tab("Segment/Detect"):
            with gr.Row():
                with gr.Column():
                    detect_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    detect_text = gr.Textbox(label="Entities to Detect", placeholder='e.g. "detect cat ; dog" or "segment car"')
                    detect_btn = gr.Button("Detect/Segment")
                with gr.Column():
                    detect_output = gr.AnnotatedImage(label="Annotated Image")
            detect_btn.click(fn=parse_segmentation, inputs=[detect_image, detect_text], outputs=[detect_output])

# Launch the App
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)