"""Gradio demo: detect vehicles in an uploaded photo and describe each one.

Pipeline: a DETR object detector finds cars/trucks, each detection is cropped
and sent to a Gemini vision model together with a fixed prompt, and the
annotated picture plus the per-vehicle answers are shown in the UI.
"""

import os
from functools import lru_cache
from io import BytesIO
from pathlib import Path

from dotenv import load_dotenv
import google.generativeai as genai
import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageDraw
import requests

# Load environment variables from .env file
load_dotenv()

# Get the API key from the environment
API_KEY = os.getenv("GOOGLE_API_KEY")

# Set up the generative AI client with the API key
genai.configure(api_key=API_KEY)

# Sampling parameters for the generative model
generation_config = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 40,
    "max_output_tokens": 4000,
}

safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
]

model = genai.GenerativeModel(
    model_name="gemini-1.5-flash-latest",
    generation_config=generation_config,
    safety_settings=safety_settings,
)

input_prompt_template = """give me the info of the car/truck (if an info is not available juste write "introuvable"): - plate: - model: - color: """

# Object-detection settings.
DETR_CHECKPOINT = "facebook/detr-resnet-50"
DETR_REVISION = "no_timm"
DETECTION_THRESHOLD = 0.9
VEHICLE_LABELS = frozenset({"car", "truck"})


def input_image_setup(file_loc):
    """Read an image file and wrap its bytes as a Gemini image part.

    Args:
        file_loc: Path (string or path-like) of the image on disk.

    Returns:
        A one-element list holding a dict with ``mime_type`` and ``data``.

    Raises:
        FileNotFoundError: If ``file_loc`` does not exist.
    """
    if not (img := Path(file_loc)).exists():
        raise FileNotFoundError(f"Could not find image: {img}")

    return [
        {
            "mime_type": "image/jpeg",
            "data": img.read_bytes(),
        }
    ]


def generate_gemini_response(input_prompt, image):
    """Ask the Gemini model about one image and return the text of its answer.

    Args:
        input_prompt: Text prompt sent alongside the image.
        image: JPEG-encoded image bytes.

    Returns:
        The ``text`` attribute of the model response.
    """
    image_part = {
        "mime_type": "image/jpeg",
        "data": image,
    }
    response = model.generate_content([input_prompt, image_part])
    return response.text


@lru_cache(maxsize=1)
def _load_detector():
    """Load the DETR processor and model once and reuse them on later calls."""
    processor = DetrImageProcessor.from_pretrained(
        DETR_CHECKPOINT, revision=DETR_REVISION
    )
    detector = DetrForObjectDetection.from_pretrained(
        DETR_CHECKPOINT, revision=DETR_REVISION
    )
    return processor, detector


# Object detection part
def detect_objects(image):
    """Find cars and trucks in ``image``, crop them and draw their boxes.

    Args:
        image: A PIL image. It is not modified; detection and drawing are done
            on RGB copies.

    Returns:
        A tuple ``(annotated_image, detected_cars)`` where ``annotated_image``
        is an RGB copy with a red rectangle and score label per vehicle, and
        ``detected_cars`` is a list of ``(jpeg_bytes, box)`` pairs, ``box``
        being ``[left, top, right, bottom]`` rounded to two decimals.
    """
    # A distinct local name keeps the module-level Gemini ``model`` unshadowed.
    processor, detector = _load_detector()

    # Normalise the mode so RGBA/palette/greyscale uploads follow the same
    # path as plain RGB photos.
    source = image.convert("RGB")

    inputs = processor(images=source, return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = detector(**inputs)

    # PIL reports (width, height); the post-processor is given (height, width).
    target_sizes = torch.tensor([source.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=DETECTION_THRESHOLD
    )[0]

    # Draw on a separate copy so that boxes painted for one vehicle never end
    # up inside the crop of another, overlapping vehicle.
    annotated = source.copy()
    draw = ImageDraw.Draw(annotated)

    detected_cars = []
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        if detector.config.id2label[label.item()] not in VEHICLE_LABELS:
            continue
        confidence = score.item()
        # Same cut-off as the explicit per-detection check in the original.
        if not confidence > DETECTION_THRESHOLD:
            continue

        box = [round(coord, 2) for coord in box.tolist()]

        # Crop from the untouched source and encode the crop as JPEG bytes.
        cropped_car_bytes = image_to_bytes(source.crop(box))
        detected_cars.append((cropped_car_bytes, box))

        # Draw bounding box and score on the annotated copy.
        draw.rectangle(box, outline="red", width=3)
        draw.text(
            (box[0], box[1]),
            f"véhicule: {round(confidence, 2)}",
            fill="red",
        )

    return annotated, detected_cars


def image_to_bytes(img):
    """Encode a PIL image as JPEG and return the raw bytes.

    Images that are not already in RGB mode are converted first, because the
    JPEG writer cannot store modes such as RGBA or P.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    return buffer.getvalue()


def _file_path(uploaded):
    """Return the filesystem path of one uploaded item.

    Accepts both objects exposing a ``.name`` attribute and plain path
    strings. NOTE(review): which of the two Gradio passes depends on the
    installed version -- confirm against the deployed one.
    """
    return getattr(uploaded, "name", uploaded)


def upload_file(files):
    """Return the path of the first uploaded file, or ``None`` if there is none.

    Used as the upload callback, whose single output is the image preview.
    """
    if not files:
        return None
    return _file_path(files[0])


def process_generate(files):
    """Run detection and description on the first uploaded image.

    Args:
        files: The uploads received from the upload button (may be empty).

    Returns:
        ``(None, "Image not uploaded")`` when nothing was uploaded; otherwise
        the annotated image and one text paragraph per detected vehicle,
        joined with newlines.
    """
    if not files:
        return None, "Image not uploaded"

    # Load the image; the context manager closes the underlying file as soon
    # as the pixel data has been copied by ``convert``.
    with Image.open(_file_path(files[0])) as opened:
        image = opened.convert("RGB")

    # Detect cars and get the cropped car images
    detected_image, detected_cars = detect_objects(image)

    # Generate a response for each car
    car_info_list = [
        f"véhicule aux coordonnées {box}:\n"
        f"{generate_gemini_response(input_prompt_template, car_bytes)}\n"
        for car_bytes, box in detected_cars
    ]

    return detected_image, "\n".join(car_info_list)


with gr.Blocks() as demo:
    header = gr.Label("RADARPICK: Vous avez pris en flag!")
    image_output = gr.Image()
    upload_button = gr.UploadButton(
        "Click to upload an image",
        file_types=["image"],
        file_count="multiple",
    )
    generate_button = gr.Button("Generate")

    file_output = gr.Textbox(label="Generated Content")

    # Preview the first uploaded file as soon as the upload finishes.
    upload_button.upload(
        fn=upload_file,
        inputs=[upload_button],
        outputs=image_output,
    )
    generate_button.click(
        fn=process_generate,
        inputs=[upload_button],
        outputs=[image_output, file_output],
    )

demo.launch()