cosmo3769 commited on
Commit
06e0fe9
·
verified ·
1 Parent(s): 617aebf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import re
4
+ import cv2
5
+ from PIL import ImageDraw, Image
6
+
7
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
8
+
9
+ mix_model_id = "google/paligemma-3b-mix-224"
10
+ mix_model = PaliGemmaForConditionalGeneration.from_pretrained(mix_model_id)
11
+ mix_processor = AutoProcessor.from_pretrained(mix_model_id)
12
+
13
+ # Helper function to parse multiple <loc> tags and return a list of coordinate sets and labels
14
+ def parse_multiple_locations(decoded_output):
15
+ # Regex pattern to match four <locxxxx> tags and the label at the end (e.g., 'cat')
16
+ loc_pattern = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s+(\w+)"
17
+
18
+ matches = re.findall(loc_pattern, decoded_output)
19
+ coords_and_labels = []
20
+
21
+ for match in matches:
22
+ # Extract the coordinates and label
23
+ y1 = int(match[0]) / 1000
24
+ x1 = int(match[1]) / 1000
25
+ y2 = int(match[2]) / 1000
26
+ x2 = int(match[3]) / 1000
27
+ label = match[4]
28
+
29
+ coords_and_labels.append({
30
+ 'label': label,
31
+ 'bbox': [y1, x1, y2, x2]
32
+ })
33
+
34
+ return coords_and_labels
35
+
36
+ # Helper function to draw bounding boxes and labels for all objects on the image
37
+ def draw_multiple_bounding_boxes(image, coords_and_labels):
38
+ draw = ImageDraw.Draw(image)
39
+ width, height = image.size
40
+
41
+ for obj in coords_and_labels:
42
+ # Extract the bounding box coordinates
43
+ y1, x1, y2, x2 = obj['bbox'][0] * height, obj['bbox'][1] * width, obj['bbox'][2] * height, obj['bbox'][3] * width
44
+
45
+ # Draw bounding box and label
46
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
47
+ draw.text((x1, y1), obj['label'], fill="red")
48
+
49
+ return image
50
+
51
+ # Define inference function
52
+ def process_image(image, prompt):
53
+ # Process the image and prompt using the processor
54
+ inputs = mix_processor(image.convert("RGB"), prompt, return_tensors="pt")
55
+
56
+ try:
57
+ # Generate output from the model
58
+ output = mix_model.generate(**inputs, max_new_tokens=100)
59
+
60
+ # Decode the output from the model
61
+ decoded_output = mix_processor.decode(output[0], skip_special_tokens=True)
62
+
63
+ # Extract bounding box coordinates and labels
64
+ coords_and_labels = parse_multiple_locations(decoded_output)
65
+
66
+ if coords_and_labels:
67
+ # Draw bounding boxes and labels on the image
68
+ image_with_boxes = draw_multiple_bounding_boxes(image, coords_and_labels)
69
+
70
+ # Prepare the coordinates and labels for the UI
71
+ labels_and_coords = "\n".join([f"Label: {obj['label']}, Coordinates: {obj['bbox']}" for obj in coords_and_labels])
72
+
73
+ # Return the modified image and the list of coordinates+labels
74
+ return image_with_boxes, labels_and_coords
75
+ else:
76
+ return "No bounding boxes detected."
77
+
78
+ except IndexError as e:
79
+ print(f"IndexError: {e}")
80
+ return "An error occurred during processing."
81
+
82
+ # Define the Gradio interface
83
+ inputs = [
84
+ gr.Image(type="pil"),
85
+ gr.Textbox(label="Prompt", placeholder="Enter your question")
86
+ ]
87
+ outputs = [
88
+ gr.Image(label="Output Image with Bounding Boxes"),
89
+ gr.Textbox(label="Bounding Box Coordinates and Labels")
90
+ ]
91
+
92
+ # Create the Gradio app
93
+ demo = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="Object Detection with Mix PaliGemma Model",
94
+ description="Upload an image and get object detections with bounding boxes and labels.")
95
+
96
+ # Launch the app
97
+ demo.launch(debug=True)