BenK0y commited on
Commit
18862e8
·
verified ·
1 Parent(s): 5da7355

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -30
app.py CHANGED
@@ -14,10 +14,10 @@ load_dotenv()
14
  # Get the API key from the environment
15
  API_KEY = os.getenv("GOOGLE_API_KEY")
16
 
17
- # Set up the model with the API key
18
  genai.configure(api_key=API_KEY)
19
 
20
- # Set up the model
21
  generation_config = {
22
  "temperature": 0.7,
23
  "top_p": 0.9,
@@ -48,6 +48,11 @@ model = genai.GenerativeModel(model_name="gemini-1.5-flash-latest",
48
  generation_config=generation_config,
49
  safety_settings=safety_settings)
50
 
 
 
 
 
 
51
  def input_image_setup(file_loc):
52
  if not (img := Path(file_loc)).exists():
53
  raise FileNotFoundError(f"Could not find image: {img}")
@@ -60,24 +65,17 @@ def input_image_setup(file_loc):
60
  ]
61
  return image_parts
62
 
63
- def generate_gemini_response(input_prompt, image_loc):
64
- image_prompt = input_image_setup(image_loc)
65
- prompt_parts = [input_prompt, image_prompt[0]]
 
 
 
 
 
66
  response = model.generate_content(prompt_parts)
67
  return response.text
68
 
69
- input_prompt = """ give me the info of the car:
70
- - plate:
71
- - model:
72
- - color: """
73
-
74
- def upload_file(files):
75
- if not files:
76
- return None, "Image not uploaded"
77
- file_paths = [file.name for file in files]
78
- response = generate_gemini_response(input_prompt, file_paths[0])
79
- return file_paths[0], response
80
-
81
  # Object detection part
82
  def detect_objects(image):
83
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
@@ -89,13 +87,57 @@ def detect_objects(image):
89
  target_sizes = torch.tensor([image.size[::-1]])
90
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
91
 
 
92
  draw = ImageDraw.Draw(image)
 
 
93
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
94
- box = [round(i, 2) for i in box.tolist()]
95
- draw.rectangle(box, outline="red", width=3)
96
- draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}", fill="red")
 
 
 
 
 
 
 
 
97
 
98
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  with gr.Blocks() as demo:
101
  header = gr.Label("RADARPICK: Vous avez été radarisé!")
@@ -105,15 +147,7 @@ with gr.Blocks() as demo:
105
 
106
  file_output = gr.Textbox(label="Generated Caption/Post Content")
107
 
108
- def process_generate(files):
109
- if not files:
110
- return None, "Image not uploaded"
111
- file_path = files[0].name
112
- image = Image.open(file_path)
113
- detected_image = detect_objects(image)
114
- return detected_image, upload_file(files)[1]
115
-
116
  upload_button.upload(fn=lambda files: files[0].name if files else None, inputs=[upload_button], outputs=image_output)
117
  generate_button.click(fn=process_generate, inputs=[upload_button], outputs=[image_output, file_output])
118
 
119
- demo.launch()
 
14
  # Get the API key from the environment
15
  API_KEY = os.getenv("GOOGLE_API_KEY")
16
 
17
+ # Set up the generative AI model with the API key
18
  genai.configure(api_key=API_KEY)
19
 
20
+ # Set up the generative model
21
  generation_config = {
22
  "temperature": 0.7,
23
  "top_p": 0.9,
 
48
  generation_config=generation_config,
49
  safety_settings=safety_settings)
50
 
51
+ input_prompt_template = """give me the info of the car:
52
+ - plate:
53
+ - model:
54
+ - color: """
55
+
56
  def input_image_setup(file_loc):
57
  if not (img := Path(file_loc)).exists():
58
  raise FileNotFoundError(f"Could not find image: {img}")
 
65
  ]
66
  return image_parts
67
 
68
+ def generate_gemini_response(input_prompt, image):
69
+ image_parts = [
70
+ {
71
+ "mime_type": "image/jpeg",
72
+ "data": image
73
+ }
74
+ ]
75
+ prompt_parts = [input_prompt, image_parts[0]]
76
  response = model.generate_content(prompt_parts)
77
  return response.text
78
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # Object detection part
80
  def detect_objects(image):
81
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
 
87
  target_sizes = torch.tensor([image.size[::-1]])
88
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
89
 
90
+ detected_cars = []
91
  draw = ImageDraw.Draw(image)
92
+
93
+ # Loop through detections and filter only "car" class (ID 3 for COCO dataset)
94
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
95
+ if model.config.id2label[label.item()] == 'car' and score.item() > 0.9:
96
+ box = [round(i, 2) for i in box.tolist()]
97
+ # Crop the detected car
98
+ cropped_car = image.crop(box)
99
+ # Convert the cropped image to bytes
100
+ cropped_car_bytes = image_to_bytes(cropped_car)
101
+ detected_cars.append((cropped_car_bytes, box))
102
+
103
+ # Draw bounding box around the car
104
+ draw.rectangle(box, outline="red", width=3)
105
+ draw.text((box[0], box[1]), f"Car: {round(score.item(), 2)}", fill="red")
106
 
107
+ return image, detected_cars
108
+
109
+ def image_to_bytes(img):
110
+ # Convert a PIL image to bytes
111
+ from io import BytesIO
112
+ img_bytes = BytesIO()
113
+ img.save(img_bytes, format="JPEG")
114
+ img_bytes = img_bytes.getvalue()
115
+ return img_bytes
116
+
117
+ def upload_file(files):
118
+ if not files:
119
+ return None, "Image not uploaded"
120
+ file_paths = [file.name for file in files]
121
+ return file_paths[0]
122
+
123
+ def process_generate(files):
124
+ if not files:
125
+ return None, "Image not uploaded"
126
+
127
+ # Load the image
128
+ file_path = files[0].name
129
+ image = Image.open(file_path)
130
+
131
+ # Detect cars and return cropped car images
132
+ detected_image, detected_cars = detect_objects(image)
133
+
134
+ # Generate responses for each car
135
+ car_info_list = []
136
+ for car_bytes, box in detected_cars:
137
+ car_info = generate_gemini_response(input_prompt_template, car_bytes)
138
+ car_info_list.append(f"Car at {box}:\n{car_info}\n")
139
+
140
+ return detected_image, "\n".join(car_info_list)
141
 
142
  with gr.Blocks() as demo:
143
  header = gr.Label("RADARPICK: Vous avez été radarisé!")
 
147
 
148
  file_output = gr.Textbox(label="Generated Caption/Post Content")
149
 
 
 
 
 
 
 
 
 
150
  upload_button.upload(fn=lambda files: files[0].name if files else None, inputs=[upload_button], outputs=image_output)
151
  generate_button.click(fn=process_generate, inputs=[upload_button], outputs=[image_output, file_output])
152
 
153
+ demo.launch()