Mohammed Abdeldayem committed
Create code.py
code.py
ADDED
@@ -0,0 +1,90 @@
import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
from torchvision.transforms.functional import crop
import gradio as gr

# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model with custom weights via torch.hub
    try:
        # Assumes the fine-tuned weights are available locally at this path
        object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5/weights/best14.pt')
        print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")

# Utility function to crop objects from the image based on xyxy bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        # torchvision's crop expects (top, left, height, width)
        cropped_image = crop(image, int(box[1]), int(box[0]), int(box[3] - box[1]), int(box[2] - box[0]))
        cropped_images.append(cropped_image)
    return cropped_images

# Gradio interface function
def process_image(image):
    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
        boxes = results.xyxy[0][:, :4].cpu().numpy()    # Bounding boxes (x1, y1, x2, y2)
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()    # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each one
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the results for visualization
        detection_results = []
        for label, box, score, caption in zip(labels, boxes, scores, captions):
            detection_results.append({
                "label": label,
                "caption": caption,
                "bounding_box": [float(coord) for coord in box],  # Convert to plain floats for JSON
                "confidence_score": float(score)
            })

        # Return the annotated image (converted to PIL), the per-object results, and the whole-image caption
        annotated_image = Image.fromarray(results.render()[0])
        return annotated_image, detection_results, original_caption

    except Exception as e:
        return None, {"error": str(e)}, None

# Initialize models
init()

# Gradio Interface (gr.Image / gr.JSON / gr.Textbox replace the deprecated gr.inputs / gr.outputs API)
interface = gr.Interface(
    fn=process_image,                                            # Function to run
    inputs=gr.Image(type="pil"),                                 # Input: image upload
    outputs=[gr.Image(type="pil", label="Detected Objects"),     # Output 1: image with bounding boxes
             gr.JSON(label="Object Captions & Bounding Boxes"),  # Output 2: per-object results
             gr.Textbox(label="Whole Image Caption")],           # Output 3: caption for the whole image
    live=True
)

# Launch the Gradio app
interface.launch()
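
For quick debugging without the Gradio UI, the interface.launch() call could temporarily be replaced with a direct call to process_image. A minimal sketch, assuming a local test image; "sample.jpg" is only a placeholder filename, not part of the repository:

# Hypothetical smoke test: run process_image directly on a local file.
# "sample.jpg" is a placeholder path; swap in any RGB image you have on disk.
test_image = Image.open("sample.jpg").convert("RGB")
annotated, detections, caption = process_image(test_image)
print("Whole-image caption:", caption)
for detection in detections:
    print(f'{detection["label"]} ({detection["confidence_score"]:.2f}): {detection["caption"]}')

If an exception was raised inside process_image, detections will instead be a dict of the form {"error": "..."} and the other two return values will be None.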
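
Since the Space installs its own dependencies, a requirements.txt roughly along these lines would likely be needed; this is a sketch rather than the committed file, and the exact package set and versions are assumptions:

# requirements.txt (sketch; versions left unpinned, adjust to the Space's runtime)
torch
torchvision
transformers
gradio
Pillow
# torch.hub.load('ultralytics/yolov5', ...) also pulls in that repository's own
# requirements, e.g. opencv-python, pandas, PyYAML, tqdm, matplotlib, seaborn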