Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -14,10 +14,10 @@ load_dotenv()
|
|
14 |
# Get the API key from the environment
|
15 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
16 |
|
17 |
-
# Set up the model with the API key
|
18 |
genai.configure(api_key=API_KEY)
|
19 |
|
20 |
-
# Set up the model
|
21 |
generation_config = {
|
22 |
"temperature": 0.7,
|
23 |
"top_p": 0.9,
|
@@ -48,6 +48,11 @@ model = genai.GenerativeModel(model_name="gemini-1.5-flash-latest",
|
|
48 |
generation_config=generation_config,
|
49 |
safety_settings=safety_settings)
|
50 |
|
|
|
|
|
|
|
|
|
|
|
51 |
def input_image_setup(file_loc):
|
52 |
if not (img := Path(file_loc)).exists():
|
53 |
raise FileNotFoundError(f"Could not find image: {img}")
|
@@ -60,24 +65,17 @@ def input_image_setup(file_loc):
|
|
60 |
]
|
61 |
return image_parts
|
62 |
|
63 |
-
def generate_gemini_response(input_prompt,
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
66 |
response = model.generate_content(prompt_parts)
|
67 |
return response.text
|
68 |
|
69 |
-
input_prompt = """ give me the info of the car:
|
70 |
-
- plate:
|
71 |
-
- model:
|
72 |
-
- color: """
|
73 |
-
|
74 |
-
def upload_file(files):
|
75 |
-
if not files:
|
76 |
-
return None, "Image not uploaded"
|
77 |
-
file_paths = [file.name for file in files]
|
78 |
-
response = generate_gemini_response(input_prompt, file_paths[0])
|
79 |
-
return file_paths[0], response
|
80 |
-
|
81 |
# Object detection part
|
82 |
def detect_objects(image):
|
83 |
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
|
@@ -89,13 +87,57 @@ def detect_objects(image):
|
|
89 |
target_sizes = torch.tensor([image.size[::-1]])
|
90 |
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
|
91 |
|
|
|
92 |
draw = ImageDraw.Draw(image)
|
|
|
|
|
93 |
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
-
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
with gr.Blocks() as demo:
|
101 |
header = gr.Label("RADARPICK: Vous avez été radarisé!")
|
@@ -105,15 +147,7 @@ with gr.Blocks() as demo:
|
|
105 |
|
106 |
file_output = gr.Textbox(label="Generated Caption/Post Content")
|
107 |
|
108 |
-
def process_generate(files):
|
109 |
-
if not files:
|
110 |
-
return None, "Image not uploaded"
|
111 |
-
file_path = files[0].name
|
112 |
-
image = Image.open(file_path)
|
113 |
-
detected_image = detect_objects(image)
|
114 |
-
return detected_image, upload_file(files)[1]
|
115 |
-
|
116 |
upload_button.upload(fn=lambda files: files[0].name if files else None, inputs=[upload_button], outputs=image_output)
|
117 |
generate_button.click(fn=process_generate, inputs=[upload_button], outputs=[image_output, file_output])
|
118 |
|
119 |
-
demo.launch()
|
|
|
14 |
# Get the API key from the environment
|
15 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
16 |
|
17 |
+
# Set up the generative AI model with the API key
|
18 |
genai.configure(api_key=API_KEY)
|
19 |
|
20 |
+
# Set up the generative model
|
21 |
generation_config = {
|
22 |
"temperature": 0.7,
|
23 |
"top_p": 0.9,
|
|
|
48 |
generation_config=generation_config,
|
49 |
safety_settings=safety_settings)
|
50 |
|
51 |
+
# Prompt text that process_generate passes to generate_gemini_response together
# with each cropped car image; it asks the model for the car's plate, model and color.
input_prompt_template = """give me the info of the car:
|
52 |
+
- plate:
|
53 |
+
- model:
|
54 |
+
- color: """
|
55 |
+
|
56 |
def input_image_setup(file_loc):
|
57 |
if not (img := Path(file_loc)).exists():
|
58 |
raise FileNotFoundError(f"Could not find image: {img}")
|
|
|
65 |
]
|
66 |
return image_parts
|
67 |
|
68 |
+
def generate_gemini_response(input_prompt, image):
    """Ask the module-level Gemini ``model`` about one image and return its text.

    Args:
        input_prompt: Text instruction sent as the first prompt part.
        image: Image payload stored under the ``"data"`` key of the second
            prompt part; it is always labelled with MIME type ``image/jpeg``
            (``process_generate`` passes JPEG-encoded bytes here).

    Returns:
        The ``text`` attribute of the object returned by
        ``model.generate_content``.
    """
    jpeg_part = {"mime_type": "image/jpeg", "data": image}
    reply = model.generate_content([input_prompt, jpeg_part])
    return reply.text
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
# Object detection part
|
80 |
def detect_objects(image):
|
81 |
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
|
|
|
87 |
target_sizes = torch.tensor([image.size[::-1]])
|
88 |
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
|
89 |
|
90 |
+
detected_cars = []
|
91 |
draw = ImageDraw.Draw(image)
|
92 |
+
|
93 |
+
# Loop through detections and filter only "car" class (ID 3 for COCO dataset)
|
94 |
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
95 |
+
if model.config.id2label[label.item()] == 'car' and score.item() > 0.9:
|
96 |
+
box = [round(i, 2) for i in box.tolist()]
|
97 |
+
# Crop the detected car
|
98 |
+
cropped_car = image.crop(box)
|
99 |
+
# Convert the cropped image to bytes
|
100 |
+
cropped_car_bytes = image_to_bytes(cropped_car)
|
101 |
+
detected_cars.append((cropped_car_bytes, box))
|
102 |
+
|
103 |
+
# Draw bounding box around the car
|
104 |
+
draw.rectangle(box, outline="red", width=3)
|
105 |
+
draw.text((box[0], box[1]), f"Car: {round(score.item(), 2)}", fill="red")
|
106 |
|
107 |
+
return image, detected_cars
|
108 |
+
|
109 |
+
def image_to_bytes(img):
    """Encode a PIL image as JPEG and return the encoded bytes.

    Args:
        img: A PIL image. Images whose mode the JPEG writer cannot store
            (for example ``RGBA``, ``LA`` or ``P`` from PNG/GIF uploads) are
            converted to ``RGB`` first instead of failing in ``save``.

    Returns:
        The JPEG file contents as ``bytes``.
    """
    from io import BytesIO

    # Pillow's JPEG encoder only accepts these modes; anything else (notably
    # modes carrying an alpha channel or a palette) raises OSError on save.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")

    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    return buffer.getvalue()
|
116 |
+
|
117 |
+
def upload_file(files):
    """Return the path of the first uploaded file.

    Args:
        files: Sequence of uploaded-file objects, each exposing a ``name``
            attribute holding its path.

    Returns:
        The ``name`` of the first entry when ``files`` is non-empty.
        When ``files`` is empty or ``None`` the 2-tuple
        ``(None, "Image not uploaded")`` is returned instead.

    NOTE(review): the two branches return different shapes (a tuple vs. a
    single string); callers must handle both. Confirm which one is intended.
    """
    if files:
        # ``name`` is read from every entry, not just the first one.
        names = [entry.name for entry in files]
        return names[0]
    return None, "Image not uploaded"
|
122 |
+
|
123 |
+
def process_generate(files):
    """Detect cars in the first uploaded image and describe each one with Gemini.

    Args:
        files: Sequence of uploaded-file objects; only the first entry is
            used and its ``name`` attribute is taken as the image path.

    Returns:
        A 2-tuple ``(image, text)``:
        * ``(None, "Image not uploaded")`` when ``files`` is empty or ``None``;
        * otherwise the annotated image returned by ``detect_objects`` and the
          per-car Gemini answers joined by newlines, or a short notice when
          ``detect_objects`` reported no cars.
    """
    if not files:
        return None, "Image not uploaded"

    file_path = files[0].name

    # Image.open is lazy and keeps the file open. Decoding inside the context
    # manager releases the handle, and convert("RGB") yields an independent
    # in-memory copy, so RGBA/palette uploads can later be drawn on with named
    # colours and have their crops encoded as JPEG.
    with Image.open(file_path) as source:
        image = source.convert("RGB")

    # detect_objects returns the annotated image plus (jpeg_bytes, box) pairs.
    detected_image, detected_cars = detect_objects(image)

    if not detected_cars:
        # Avoid handing the UI an empty string that looks like a failure.
        return detected_image, "No car detected in the image"

    car_info_list = []
    for car_bytes, box in detected_cars:
        car_info = generate_gemini_response(input_prompt_template, car_bytes)
        car_info_list.append(f"Car at {box}:\n{car_info}\n")

    return detected_image, "\n".join(car_info_list)
|
141 |
|
142 |
with gr.Blocks() as demo:
|
143 |
header = gr.Label("RADARPICK: Vous avez été radarisé!")
|
|
|
147 |
|
148 |
file_output = gr.Textbox(label="Generated Caption/Post Content")
|
149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
upload_button.upload(fn=lambda files: files[0].name if files else None, inputs=[upload_button], outputs=image_output)
|
151 |
generate_button.click(fn=process_generate, inputs=[upload_button], outputs=[image_output, file_output])
|
152 |
|
153 |
+
demo.launch()
|