Mohammed Abdeldayem committed on
Commit 3ce1aa7 · verified · 1 Parent(s): 5ffaa01

Create code.py

Files changed (1)
  1. code.py +90 -0
code.py ADDED
@@ -0,0 +1,90 @@
+import torch
+from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
+from PIL import Image
+from torchvision.transforms.functional import crop
+import gradio as gr
+
+# Load models during initialization
+def init():
+    global object_detection_model, captioning_model, tokenizer, captioning_processor
+
+    # Step 1: Load the YOLOv5 model from Hugging Face
+    try:
+        object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5/weights/best14.pt')  # Assuming model is locally available
+        print("YOLOv5 model loaded successfully.")
+    except Exception as e:
+        print(f"Error loading YOLOv5 model: {e}")
+
+    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
+    try:
+        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
+        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
+        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
+        print("ViT-GPT2 model loaded successfully.")
+    except Exception as e:
+        print(f"Error loading captioning model: {e}")
+
+# Utility function to crop objects from the image based on bounding boxes
+def crop_objects(image, boxes):
+    cropped_images = []
+    for box in boxes:
+        cropped_image = crop(image, int(box[1]), int(box[0]), int(box[3] - box[1]), int(box[2] - box[0]))
+        cropped_images.append(cropped_image)
+    return cropped_images
+
+# Gradio interface function
+def process_image(image):
+    try:
+        # Step 1: Perform object detection with YOLOv5
+        results = object_detection_model(image)
+        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes
+        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
+        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores
+
+        # Step 2: Generate caption for the whole image
+        original_inputs = captioning_processor(images=image, return_tensors="pt")
+        with torch.no_grad():
+            original_caption_ids = captioning_model.generate(**original_inputs)
+        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)
+
+        # Step 3: Crop detected objects and generate captions for each object
+        cropped_images = crop_objects(image, boxes)
+        captions = []
+        for cropped_image in cropped_images:
+            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
+            with torch.no_grad():
+                caption_ids = captioning_model.generate(**inputs)
+            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
+            captions.append(caption)
+
+        # Prepare the result for visualization
+        detection_results = []
+        for i, (label, box, score, caption) in enumerate(zip(labels, boxes, scores, captions)):
+            detection_results.append({
+                "label": label,
+                "caption": caption,
+                "bounding_box": [float(coord) for coord in box],  # Convert to float
+                "confidence_score": float(score)  # Convert to float
+            })
+
+        # Return the image with detections and the caption
+        return results.render()[0], detection_results, original_caption
+
+    except Exception as e:
+        return None, {"error": str(e)}, None
+
+# Initialize models
+init()
+
+# Gradio Interface
+interface = gr.Interface(
+    fn=process_image,  # Function to run
+    inputs=gr.Image(type="pil"),  # Input: image upload
+    outputs=[gr.Image(type="pil", label="Detected Objects"),  # Output 1: image with bounding boxes
+             gr.JSON(label="Object Captions & Bounding Boxes"),  # Output 2: JSON results for each object
+             gr.Textbox(label="Whole Image Caption")],  # Output 3: caption for the whole image
+    live=True
+)
+
+# Launch the Gradio app
+interface.launch()