# RADARPICKv3 / app.py
# Author: BenK0y
# Last commit: "Update app.py" (54c509a, verified)
import mimetypes
import os
from functools import lru_cache
from pathlib import Path

import google.generativeai as genai
import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from PIL import Image, ImageDraw
from transformers import DetrForObjectDetection, DetrImageProcessor
# Populate os.environ from a local .env file, read the Gemini API key from
# it, and hand that key to the google-generativeai SDK.
load_dotenv()
API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=API_KEY)
# Sampling parameters applied to every request made through `model`.
generation_config = dict(
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    max_output_tokens=4000,
)
# Every harm category uses the same blocking threshold, so the list is built
# from the category names rather than spelled out entry by entry.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
# Module-level Gemini client; generate_gemini_response sends each vehicle
# crop to it together with input_prompt_template.
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash-latest",
    generation_config=generation_config,
    safety_settings=safety_settings,
)
# Prompt sent with every vehicle crop. "introuvable" (French for "not
# found") is the placeholder the model is asked to use for missing fields.
input_prompt_template = """give me the info of the car/truck (if an info is not available just write "introuvable"):
- plate:
- model:
- color: """
def input_image_setup(file_loc):
    """Read an image file and wrap it as a Gemini inline-data part.

    Args:
        file_loc: Path (str or os.PathLike) of the image on disk.

    Returns:
        A one-element list ``[{"mime_type": str, "data": bytes}]``. The MIME
        type is guessed from the file extension; ``"image/jpeg"`` is used
        when the extension is unknown or does not map to an image type.

    Raises:
        FileNotFoundError: If *file_loc* is not an existing regular file.
    """
    img = Path(file_loc)
    # is_file() rather than exists(): a directory would otherwise pass the
    # check and then fail inside read_bytes().
    if not img.is_file():
        raise FileNotFoundError(f"Could not find image: {img}")
    guessed, _ = mimetypes.guess_type(img.name)
    # Label PNG/WebP/... files correctly instead of calling everything JPEG.
    mime_type = guessed if guessed and guessed.startswith("image/") else "image/jpeg"
    return [{"mime_type": mime_type, "data": img.read_bytes()}]
def generate_gemini_response(input_prompt, image, mime_type="image/jpeg"):
    """Ask the module-level Gemini ``model`` to describe one image.

    Args:
        input_prompt: Text prompt sent before the image.
        image: Encoded image bytes (JPEG when produced by ``image_to_bytes``).
        mime_type: MIME type of *image*; defaults to ``"image/jpeg"``.

    Returns:
        The model's text answer. If the response carries no text, for
        example because the safety settings blocked it, a short
        "introuvable" notice is returned instead of raising.
    """
    image_part = {"mime_type": mime_type, "data": image}
    response = model.generate_content([input_prompt, image_part])
    try:
        return response.text
    except ValueError:
        # The google-generativeai SDK raises ValueError from `.text` when no
        # valid text Part was returned (e.g. blocked by safety_settings).
        # Report it so the other detected vehicles are still processed.
        return "introuvable (réponse vide ou bloquée)"
# Object detection part
@lru_cache(maxsize=1)
def _load_detr(checkpoint="facebook/detr-resnet-50", revision="no_timm"):
    """Load the DETR image processor and model once, then reuse the pair.

    ``from_pretrained`` reads the weights, so doing it on every request is
    expensive; the cache keeps one loaded pair per (checkpoint, revision).
    """
    processor = DetrImageProcessor.from_pretrained(checkpoint, revision=revision)
    detr_model = DetrForObjectDetection.from_pretrained(checkpoint, revision=revision)
    return processor, detr_model


def detect_objects(image, threshold=0.9, vehicle_labels=("car", "truck")):
    """Detect vehicles with DETR, crop each one, and draw its bounding box.

    Args:
        image: PIL image to analyse. It is not modified: an RGB copy is
            analysed and cropped, and boxes are drawn on a second copy.
        threshold: Detections must score strictly above this value.
        vehicle_labels: COCO label names treated as vehicles.

    Returns:
        ``(annotated, detected_cars)``. ``annotated`` is an RGB copy of
        *image* with a red box and score drawn for each vehicle.
        ``detected_cars`` is a list of ``(jpeg_bytes, box)`` pairs, where
        ``box`` is ``[x0, y0, x1, y1]`` in pixels rounded to 2 decimals.
    """
    processor, detr_model = _load_detr()
    # Work in RGB: crops are JPEG-encoded by image_to_bytes, and JPEG cannot
    # store alpha or palette modes coming from PNG/GIF uploads.
    source = image.convert("RGB")
    annotated = source.copy()

    inputs = processor(images=source, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = detr_model(**inputs)
    # PIL's .size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([source.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    id2label = detr_model.config.id2label
    wanted = frozenset(vehicle_labels)
    draw = ImageDraw.Draw(annotated)
    detected_cars = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if id2label[label.item()] not in wanted or score.item() <= threshold:
            continue
        box = [round(coord, 2) for coord in box.tolist()]
        # Crop from the clean image so boxes drawn for earlier detections
        # never appear inside the crops sent to Gemini.
        detected_cars.append((image_to_bytes(source.crop(box)), box))
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"véhicule: {round(score.item(), 2)}", fill="red")
    return annotated, detected_cars
def image_to_bytes(img, image_format="JPEG"):
    """Encode a PIL image and return the encoded bytes.

    Args:
        img: PIL image to encode. It is not modified.
        image_format: Pillow format name; defaults to ``"JPEG"``.

    Returns:
        The encoded image as ``bytes``.
    """
    from io import BytesIO

    # JPEG cannot hold alpha or palette data (Pillow raises OSError for
    # RGBA, P, LA, ...), so flatten such images to RGB before saving.
    if image_format.upper() == "JPEG" and img.mode not in ("RGB", "L", "CMYK"):
        img = img.convert("RGB")
    buffer = BytesIO()
    img.save(buffer, format=image_format)
    return buffer.getvalue()
def upload_file(files):
    """Return the path of the first uploaded file, or None if there is none.

    Args:
        files: Sequence of uploaded file objects exposing a ``.name`` path,
            or None/empty when nothing was uploaded.

    Returns:
        ``files[0].name``, or ``None``. Both paths return a single value so
        the function can drive a single output component.
    """
    if not files:
        return None
    return files[0].name
def process_generate(files):
    """Detect vehicles in the first uploaded image and describe each with Gemini.

    Args:
        files: Uploaded file objects, each with a ``.name`` path. Only the
            first one is processed.

    Returns:
        ``(annotated_image, text)``: the image with vehicle boxes drawn on
        it, plus one Gemini description per detected vehicle.
        Returns ``(None, "Image not uploaded")`` when nothing was uploaded.
    """
    if not files:
        return None, "Image not uploaded"

    file_path = files[0].name
    # Decode inside a context manager so the file handle is released, and
    # normalise to RGB: PNG/GIF uploads may be RGBA or palette images, which
    # JPEG encoding of the crops cannot handle.
    with Image.open(file_path) as src:
        image = src.convert("RGB")

    detected_image, detected_cars = detect_objects(image)
    if not detected_cars:
        return detected_image, "Aucun véhicule détecté."

    car_info_list = []
    for car_bytes, box in detected_cars:
        car_info = generate_gemini_response(input_prompt_template, car_bytes)
        car_info_list.append(f"véhicule aux coordonnées {box}:\n{car_info}\n")
    return detected_image, "\n".join(car_info_list)
with gr.Blocks() as demo:
    header = gr.Label("RADARPICK: Vous avez pris en flag!")
    image_output = gr.Image()
    upload_button = gr.UploadButton("Click to upload an image", file_types=["image"], file_count="multiple")
    generate_button = gr.Button("Generate")
    file_output = gr.Textbox(label="Generated Content")
    # Preview the first uploaded file as soon as the upload completes.
    upload_button.upload(fn=lambda files: files[0].name if files else None, inputs=[upload_button], outputs=image_output)
    # On demand: detect vehicles and show the annotated image + Gemini text.
    generate_button.click(fn=process_generate, inputs=[upload_button], outputs=[image_output, file_output])

# Only start the server when run as a script, so importing this module
# (tests, tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()