Mehyaar commited on
Commit
31216ea
·
verified ·
1 Parent(s): dcff7c5
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import pickle
4
+ from torchvision import transforms
5
+ from torchvision.transforms import ToTensor, RandomErasing
6
+ from PIL import Image
7
+ import numpy as np
8
+ from ultralytics import YOLO
9
+ import io
10
+
11
+
12
+ # Load the YOLO model for bird detection
13
+ yolo_model = YOLO('yolov5su.pt')
14
+
15
+ class CPU_Unpickler(pickle.Unpickler):
16
+ def find_class(self, module, name):
17
+ if module == 'torch.storage' and name == '_load_from_bytes':
18
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu') # Ensure loading on CPU
19
+ else:
20
+ return super().find_class(module, name)
21
+
22
+ # Load your model using the custom unpickler
23
+ with open(".\model\model_resultsconvnext_large.pkl", "rb") as file:
24
+ model = CPU_Unpickler(file).load()
25
+ model = model['convnext_large']['model']
26
+ model.eval()
27
+ # Function to detect bird region
28
+ def detect_bird_region(image):
29
+ results = yolo_model(image, verbose=False)
30
+ bird_boxes = results[0].boxes[results[0].boxes.cls == 14]
31
+ if len(bird_boxes) > 0:
32
+ return bird_boxes[0].xyxy[0].cpu().numpy() # Coordinates of the first detected bird
33
+ return None
34
+
35
+ # Preprocessing function for inference
36
+ def preprocess_image(image):
37
+ bird_box = detect_bird_region(image)
38
+ if bird_box is not None:
39
+ image = image.crop(bird_box) # Crop to bird region
40
+
41
+ # Apply validation transformations
42
+ val_transform = transforms.Compose([
43
+ transforms.Resize((229, 229)),
44
+ ToTensor(),
45
+ transforms.ConvertImageDtype(torch.float32),
46
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
47
+ ])
48
+ return val_transform(image).unsqueeze(0) # Add batch dimension
49
+
50
+ # Prediction function
51
+ def predict(image):
52
+ # Preprocess the image
53
+ image = preprocess_image(image)
54
+
55
+ # Perform prediction
56
+ with torch.no_grad():
57
+ outputs = model(image)
58
+ predicted_class = torch.argmax(outputs, dim=1).item()
59
+
60
+ # Map the predicted class to bird names using bird_folders
61
+ bird_folders = {
62
+ 0: "019.Gray_Catbird",
63
+ 1: "025.Pelagic_Cormorant",
64
+ 2: "026.Bronzed_Cowbird",
65
+ 3: "029.American_Crow",
66
+ 4: "039.Least_Flycatcher",
67
+ 5: "073.Blue_Jay",
68
+ 6: "085.Horned_Lark",
69
+ 7: "099.Ovenbird",
70
+ 8: "104.American_Pipit",
71
+ 9: "119.Field_Sparrow",
72
+ 10: "127.Savannah_Sparrow",
73
+ 11: "129.Song_Sparrow",
74
+ 12: "135.Bank_Swallow",
75
+ 13: "137.Cliff_Swallow",
76
+ 14: "138.Tree_Swallow",
77
+ 15: "142.Black_Tern",
78
+ 16: "143.Caspian_Tern",
79
+ 17: "144.Common_Tern",
80
+ 18: "167.Hooded_Warbler",
81
+ 19: "176.Prairie_Warbler",
82
+ 20: "177.Prothonotary_Warbler",
83
+ 21: "179.Tennessee_Warbler",
84
+ 22: "182.Yellow_Warbler",
85
+ 23: "183.Northern_Waterthrush",
86
+ 24: "185.Bohemian_Waxwing",
87
+ 25: "186.Cedar_Waxwing",
88
+ 26: "188.Pileated_Woodpecker",
89
+ 27: "192.Downy_Woodpecker",
90
+ 28: "195.Carolina_Wren",
91
+ 29: "199.Winter_Wren"
92
+ }
93
+ return bird_folders[predicted_class] # Return bird name as output
94
+
95
+ # Gradio Interface
96
+ interface = gr.Interface(
97
+ fn=predict,
98
+ inputs=gr.Image(type="pil"),
99
+ outputs="label" # Display class label as output
100
+ )
101
+
102
+ # Launch Gradio App
103
+ interface.launch()