app.py
ADDED

import gradio as gr
import torch
import pickle
from torchvision import transforms
from torchvision.transforms import ToTensor, RandomErasing
from PIL import Image
import numpy as np
from ultralytics import YOLO
import io


# Load the YOLO model for bird detection
yolo_model = YOLO('yolov5su.pt')


# Unpickler that remaps GPU-pickled tensor storages onto the CPU
class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')  # Ensure loading on CPU
        else:
            return super().find_class(module, name)


# Load your model using the custom unpickler
with open("./model/model_resultsconvnext_large.pkl", "rb") as file:
    model = CPU_Unpickler(file).load()
model = model['convnext_large']['model']
model.eval()


# Function to detect the bird region (COCO class 14 = bird)
def detect_bird_region(image):
    results = yolo_model(image, verbose=False)
    bird_boxes = results[0].boxes[results[0].boxes.cls == 14]
    if len(bird_boxes) > 0:
        return bird_boxes[0].xyxy[0].cpu().numpy()  # Coordinates of the first detected bird
    return None


# Preprocessing function for inference
def preprocess_image(image):
    bird_box = detect_bird_region(image)
    if bird_box is not None:
        image = image.crop(bird_box)  # Crop to bird region

    # Apply validation transformations
    val_transform = transforms.Compose([
        transforms.Resize((229, 229)),
        ToTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return val_transform(image).unsqueeze(0)  # Add batch dimension
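
# Note: for an RGB input, preprocess_image returns a float tensor of shape
# [1, 3, 229, 229], which predict below feeds directly to the classifier.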


# Prediction function
def predict(image):
    # Preprocess the image
    image = preprocess_image(image)

    # Perform prediction
    with torch.no_grad():
        outputs = model(image)
        predicted_class = torch.argmax(outputs, dim=1).item()

    # Map the predicted class to bird names using bird_folders
    bird_folders = {
        0: "019.Gray_Catbird",
        1: "025.Pelagic_Cormorant",
        2: "026.Bronzed_Cowbird",
        3: "029.American_Crow",
        4: "039.Least_Flycatcher",
        5: "073.Blue_Jay",
        6: "085.Horned_Lark",
        7: "099.Ovenbird",
        8: "104.American_Pipit",
        9: "119.Field_Sparrow",
        10: "127.Savannah_Sparrow",
        11: "129.Song_Sparrow",
        12: "135.Bank_Swallow",
        13: "137.Cliff_Swallow",
        14: "138.Tree_Swallow",
        15: "142.Black_Tern",
        16: "143.Caspian_Tern",
        17: "144.Common_Tern",
        18: "167.Hooded_Warbler",
        19: "176.Prairie_Warbler",
        20: "177.Prothonotary_Warbler",
        21: "179.Tennessee_Warbler",
        22: "182.Yellow_Warbler",
        23: "183.Northern_Waterthrush",
        24: "185.Bohemian_Waxwing",
        25: "186.Cedar_Waxwing",
        26: "188.Pileated_Woodpecker",
        27: "192.Downy_Woodpecker",
        28: "195.Carolina_Wren",
        29: "199.Winter_Wren"
    }
    return bird_folders[predicted_class]  # Return bird name as output


# Gradio Interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label"  # Display class label as output
)

# Launch Gradio App
interface.launch()
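

# A minimal local smoke test sketch, left commented out because interface.launch()
# above blocks; "sample_bird.jpg" is a hypothetical file name, not part of this repo.
# if __name__ == "__main__":
#     sample = Image.open("sample_bird.jpg").convert("RGB")
#     print(predict(sample))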