Mehyaar's picture
app.py
31216ea verified
raw
history blame
3.3 kB
import gradio as gr
import torch
import pickle
from torchvision import transforms
from torchvision.transforms import ToTensor, RandomErasing
from PIL import Image
import numpy as np
from ultralytics import YOLO
import io
# Load the YOLO model for bird detection
yolo_model = YOLO('yolov5su.pt')
class CPU_Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == 'torch.storage' and name == '_load_from_bytes':
return lambda b: torch.load(io.BytesIO(b), map_location='cpu') # Ensure loading on CPU
else:
return super().find_class(module, name)
# Load your model using the custom unpickler
with open(".\model\model_resultsconvnext_large.pkl", "rb") as file:
model = CPU_Unpickler(file).load()
model = model['convnext_large']['model']
model.eval()
# Function to detect bird region
def detect_bird_region(image):
results = yolo_model(image, verbose=False)
bird_boxes = results[0].boxes[results[0].boxes.cls == 14]
if len(bird_boxes) > 0:
return bird_boxes[0].xyxy[0].cpu().numpy() # Coordinates of the first detected bird
return None
# Preprocessing function for inference
def preprocess_image(image):
bird_box = detect_bird_region(image)
if bird_box is not None:
image = image.crop(bird_box) # Crop to bird region
# Apply validation transformations
val_transform = transforms.Compose([
transforms.Resize((229, 229)),
ToTensor(),
transforms.ConvertImageDtype(torch.float32),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return val_transform(image).unsqueeze(0) # Add batch dimension
# Prediction function
def predict(image):
# Preprocess the image
image = preprocess_image(image)
# Perform prediction
with torch.no_grad():
outputs = model(image)
predicted_class = torch.argmax(outputs, dim=1).item()
# Map the predicted class to bird names using bird_folders
bird_folders = {
0: "019.Gray_Catbird",
1: "025.Pelagic_Cormorant",
2: "026.Bronzed_Cowbird",
3: "029.American_Crow",
4: "039.Least_Flycatcher",
5: "073.Blue_Jay",
6: "085.Horned_Lark",
7: "099.Ovenbird",
8: "104.American_Pipit",
9: "119.Field_Sparrow",
10: "127.Savannah_Sparrow",
11: "129.Song_Sparrow",
12: "135.Bank_Swallow",
13: "137.Cliff_Swallow",
14: "138.Tree_Swallow",
15: "142.Black_Tern",
16: "143.Caspian_Tern",
17: "144.Common_Tern",
18: "167.Hooded_Warbler",
19: "176.Prairie_Warbler",
20: "177.Prothonotary_Warbler",
21: "179.Tennessee_Warbler",
22: "182.Yellow_Warbler",
23: "183.Northern_Waterthrush",
24: "185.Bohemian_Waxwing",
25: "186.Cedar_Waxwing",
26: "188.Pileated_Woodpecker",
27: "192.Downy_Woodpecker",
28: "195.Carolina_Wren",
29: "199.Winter_Wren"
}
return bird_folders[predicted_class] # Return bird name as output
# Gradio Interface
interface = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs="label" # Display class label as output
)
# Launch Gradio App
interface.launch()