File size: 3,850 Bytes
9dbce78
778cfce
 
 
 
7eaf7dd
 
5da7355
7eaf7dd
778cfce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa31549
778cfce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4111fe7
778cfce
dbf11a5
778cfce
 
 
133dd82
 
 
 
778cfce
87627b2
778cfce
 
 
4111fe7
778cfce
f8fe675
5da7355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd2c4ba
0178eee
778cfce
 
0178eee
778cfce
 
 
e68acf0
778cfce
 
5da7355
 
 
 
778cfce
 
e68acf0
f8fe675
7eaf7dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import functools
import mimetypes
import os
from pathlib import Path

import google.generativeai as genai
import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection

# Read variables defined in a local .env file
load_dotenv()

# Gemini credentials come from GOOGLE_API_KEY (None when the variable is unset)
API_KEY = os.environ.get("GOOGLE_API_KEY")

# Pass the key to the google.generativeai client
genai.configure(api_key=API_KEY)

# Generation parameters passed to the Gemini model
generation_config = dict(
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    max_output_tokens=4000,
)

# Every harm category listed here shares the same blocking threshold
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Gemini model handle, built from the generation and safety settings above
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash-latest",
    safety_settings=safety_settings,
    generation_config=generation_config,
)

def input_image_setup(file_loc):
    """Read an image file and wrap it as an inline-data part for Gemini.

    Args:
        file_loc: Location of the image on disk (``str`` or ``Path``).

    Returns:
        A one-element list containing a dict with two keys:
        ``"mime_type"`` (the image MIME type) and ``"data"`` (the raw
        bytes of the file).

    Raises:
        FileNotFoundError: If ``file_loc`` does not exist.
    """
    img = Path(file_loc)
    if not img.exists():
        raise FileNotFoundError(f"Could not find image: {img}")

    # Derive the MIME type from the file extension so that PNG/GIF/... uploads
    # are not mislabelled as JPEG. Unknown or non-image extensions keep the
    # previous default of "image/jpeg".
    guessed, _ = mimetypes.guess_type(img.name)
    if guessed and guessed.startswith("image/"):
        mime_type = guessed
    else:
        mime_type = "image/jpeg"

    return [{"mime_type": mime_type, "data": img.read_bytes()}]

def generate_gemini_response(input_prompt, image_loc):
    """Send a text prompt plus one image to Gemini and return the reply text.

    Args:
        input_prompt: Instruction text placed before the image.
        image_loc: Path of the image file to attach.

    Returns:
        The ``text`` attribute of the response from ``model.generate_content``.
    """
    first_image = input_image_setup(image_loc)[0]
    return model.generate_content([input_prompt, first_image]).text

# Prompt sent with every image. The leading space and the trailing spaces
# after "car:" and "color:" are part of the text and are spelled out here.
input_prompt = (
    " give me the info of the car: \n"
    "- plate:\n"
    "- model:\n"
    "- color: "
)

def upload_file(files):
    """Describe the first uploaded file with Gemini.

    Args:
        files: Sequence of uploaded file objects exposing a ``name``
            attribute (the path on disk); may be empty or ``None``.

    Returns:
        ``(path, text)`` for the first file, or
        ``(None, "Image not uploaded")`` when nothing was uploaded.
    """
    if files:
        uploaded_paths = [uploaded.name for uploaded in files]
        first_path = uploaded_paths[0]
        description = generate_gemini_response(input_prompt, first_path)
        return first_path, description
    return None, "Image not uploaded"

# Object detection part
@functools.lru_cache(maxsize=1)
def _load_detr():
    """Load the DETR processor and model once and reuse them on later calls."""
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    return processor, detector


def detect_objects(image, threshold=0.9):
    """Run DETR object detection on a PIL image and draw the detections on it.

    Args:
        image: ``PIL.Image.Image`` to analyse.
        threshold: Minimum confidence score for a detection to be kept.
            Defaults to 0.9.

    Returns:
        The annotated image, with a red box and a ``"label: score"`` caption
        for each detection. This is the object that was passed in when it is
        already in RGB mode, otherwise an RGB copy of it.
    """
    # Loading the checkpoint is expensive, so it is cached instead of being
    # repeated on every request. The local name also avoids shadowing the
    # module-level Gemini `model`.
    processor, detector = _load_detr()

    # Uploads may be RGBA, greyscale or palette images; convert them so the
    # processor always receives three colour channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        outputs = detector(**inputs)

    # PIL reports size as (width, height); reverse it to (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    draw = ImageDraw.Draw(image)
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        corners = [round(coord, 2) for coord in box.tolist()]
        draw.rectangle(corners, outline="red", width=3)
        label_name = detector.config.id2label[label.item()]
        draw.text((corners[0], corners[1]), f"{label_name}: {round(score.item(), 2)}", fill="red")

    return image

with gr.Blocks() as demo:
    # Components are created in this order: banner, image preview, the two
    # buttons, then the text output.
    header = gr.Label("RADARPICK: Vous avez été radarisé!")
    image_output = gr.Image()
    upload_button = gr.UploadButton(
        "Click to upload an image",
        file_types=["image"],
        file_count="multiple",
    )
    generate_button = gr.Button("Generer")

    file_output = gr.Textbox(label="Generated Caption/Post Content")

    def process_generate(files):
        """Annotate the first uploaded image and fetch its Gemini description.

        Returns ``(annotated_image, text)``, or
        ``(None, "Image not uploaded")`` when there are no files.
        """
        if files:
            first_path = files[0].name
            annotated = detect_objects(Image.open(first_path))
            return annotated, upload_file(files)[1]
        return None, "Image not uploaded"

    # On upload, show the first file in the image component.
    upload_button.upload(
        fn=lambda files: files[0].name if files else None,
        inputs=[upload_button],
        outputs=image_output,
    )
    # On click, replace the image with the annotated one and fill the textbox.
    generate_button.click(
        fn=process_generate,
        inputs=[upload_button],
        outputs=[image_output, file_output],
    )

# Launch the Gradio app. There is no `if __name__ == "__main__"` guard, so
# this call also runs whenever the module is imported.
demo.launch()