File size: 4,377 Bytes
79b13b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import json
import math
import sqlite3
import streamlit as st
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision import transforms
from transformers import AutoModelForImageClassification
# Page header.
st.title("Global Bird Classification App")

# Image to classify.
uploaded_file = st.file_uploader("Please select an image", type=["jpg", "jpeg", "png"])

# Optional location inputs; with value=None each widget starts empty (None).
latitude = st.number_input("Enter latitude (optional)", value=None, format="%f")
longitude = st.number_input("Enter longitude (optional)", value=None, format="%f")

# Display labels for the result language. The selected key is later used to
# pick the name from each species' entry in the label map; dict order
# (2, 1, 0) is the order the options are shown in.
_LANGUAGE_LABELS = {
    2: "Latina (Nomen Scientificum)",
    1: "English (IOC 10.1)",
    0: "中文 (中国大陆)",
}
lang = st.selectbox(
    "Result Language",
    options=list(_LANGUAGE_LABELS),
    format_func=_LANGUAGE_LABELS.__getitem__,
)
# Per-channel normalisation statistics (the standard ImageNet mean/std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline for every image before classification: a fixed
# 224x224 resize (aspect ratio is not preserved), conversion to a tensor,
# then per-channel normalisation with the statistics above.
classify_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def classify_objects(classification_model, image, species_list):
    """Classify a bird image and return candidate species ranked by probability.

    Args:
        classification_model: model called on a single-image batch; its output
            is indexed with ``[0]`` to obtain the logits.
        image: PIL image to classify (any mode; it is converted to RGB).
        species_list: class indices to keep; an empty list keeps every class.

    Returns:
        ``(class_index, probability)`` tuples sorted by descending probability,
        where the softmax is taken over the kept classes only.
    """
    # The normalisation in `classify_transforms` has exactly three channel
    # statistics, so RGBA PNGs, grayscale and palette images must be
    # converted first or the transform fails.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    input_tensor = classify_transforms(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        # NOTE(review): `[0]` is assumed to yield a 1-D logits vector for the
        # single image in the batch. For a Hugging Face `ModelOutput` it would
        # instead yield the `(1, num_classes)` logits tensor — confirm against
        # the model actually loaded.
        logits = classification_model(input_tensor)[0]
    filtered = get_filtered_predictions(logits, species_list)
    return softmax(filtered)
def softmax(tuples):
    """Apply a numerically stable softmax to the scores of ``(key, score)`` pairs.

    `torch.nn.functional.softmax` requires a `Tensor` input, whereas this works
    on a plain list of tuples, hence the hand-written version.

    Args:
        tuples: sequence of ``(key, score)`` pairs; each score must be
            accepted by ``float()`` (e.g. a Python float or a 0-d tensor).

    Returns:
        A new list of ``(key, probability)`` pairs sorted by descending
        probability (ties keep their input order). Empty input yields ``[]``.
    """
    if not tuples:
        # Nothing to normalise; avoids dividing by an empty sum.
        return []
    keys = [t[0] for t in tuples]
    scores = [float(t[1]) for t in tuples]
    # Shifting by the maximum leaves the softmax unchanged mathematically but
    # keeps math.exp from overflowing on large logits (exp(x) overflows
    # for x greater than about 709).
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    result = [(k, e / total) for k, e in zip(keys, exps)]
    result.sort(key=lambda pair: pair[1], reverse=True)
    return result
def get_filtered_predictions(predictions: list[float], species_list: list[int]) -> list[tuple[int, float]]:
    """Pair each prediction score with its class index, optionally restricted.

    Args:
        predictions: per-class scores in class-index order (any iterable).
        species_list: class indices to keep; when empty or None, every class
            is kept.

    Returns:
        ``(class_index, score)`` tuples in ascending class-index order.
    """
    indexed = enumerate(predictions)
    if not species_list:
        return list(indexed)
    # Build the lookup set once: O(1) membership per class instead of a
    # linear scan of species_list for every class.
    allowed = set(species_list)
    return [(index, score) for index, score in indexed if index in allowed]
class DistributionDB:
    """Query helper over the species distribution SQLite database.

    Maps a geographic point to the classifier class indices of species whose
    distribution places contain that point.
    """

    def __init__(self, db_path):
        self.con = sqlite3.connect(db_path)
        self.cur = self.con.cursor()

    def get_list(self, lat, lng) -> list:
        """Return class indices of species recorded at the given coordinate.

        Args:
            lat: latitude, or None when not provided.
            lng: longitude, or None when not provided.

        Returns:
            One entry per (species, class) pair whose place bounds
            (south/north/west/east) contain ``(lat, lng)``. Empty when either
            coordinate is missing or nothing matches. Species without a row in
            ``sp_cls_map`` are omitted.
        """
        if lat is None or lng is None:
            return []
        # Bound parameters instead of f-string interpolation: avoids SQL
        # injection and the "no such column: None" error for missing input.
        # NOTE: places whose bounds cross the antimeridian (west > east) never
        # satisfy this range test.
        self.cur.execute(
            '''
            SELECT m.cls
            FROM distributions AS d
            LEFT OUTER JOIN places AS p
                ON p.worldid = d.worldid
            LEFT OUTER JOIN sp_cls_map AS m
                ON d.species = m.species
            WHERE p.south <= :lat
              AND p.north >= :lat
              AND p.east >= :lng
              AND p.west <= :lng
            GROUP BY d.species, m.cls;
            ''',
            {"lat": lat, "lng": lng},
        )
        # The LEFT JOIN on sp_cls_map yields NULL for unmapped species; those
        # can never match a class index, so drop them.
        return [row[0] for row in self.cur if row[0] is not None]

    def close(self):
        """Close the cursor and the underlying connection."""
        self.cur.close()
        self.con.close()
@st.cache_resource
def _load_model():
    """Load the bird classifier once per server process and place it on ``device``.

    Streamlit re-executes this whole script on every widget interaction;
    caching avoids reloading the weights on each rerun.
    """
    model = AutoModelForImageClassification.from_pretrained('sunjiao/osea')
    # classify_objects sends the input tensor to `device`, so the weights must
    # live there too or inference fails with a device mismatch on CUDA hosts.
    model.to(device)
    model.eval()
    return model


# Run the pipeline only once the user has uploaded an image.
if uploaded_file is not None:
    try:
        sqlite_path = hf_hub_download(repo_id='sunjiao/osea', filename='avonet.db')
        st.success("Successfully downloaded distribution database from Hugging Face Hub!")
        label_map_path = hf_hub_download(repo_id='sunjiao/osea', filename='bird_info.json')
        st.success("Successfully downloaded labels from Hugging Face Hub!")
    except Exception as e:
        # Top-level UI boundary: report the failure and end this run.
        st.error(f"Failed to download the file: {e}")
        st.stop()

    # Location is optional. Without both coordinates there is nothing to look
    # up, and an empty species list means "do not filter by location".
    species_list = []
    if latitude is not None and longitude is not None:
        db = DistributionDB(sqlite_path)
        try:
            species_list = db.get_list(latitude, longitude)
        finally:
            # Release the connection even if the query raises.
            db.close()

    # Open and display the uploaded image.
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    model = _load_model()
    results = classify_objects(model, image, species_list)
    top3_results = results[:3]

    # Names may be non-ASCII (a Chinese result language is selectable), so
    # read the label map as UTF-8 instead of the platform default encoding.
    with open(label_map_path, encoding="utf-8") as f:
        bird_info = json.load(f)

    # Display the top 3 results and their probabilities.
    st.subheader("Classification Results (Top 3):")
    for class_index, probability in top3_results:
        st.write(f"{bird_info[class_index][lang]}: {probability:.4f}")

    # Echo the coordinates back if both were provided.
    if latitude is not None and longitude is not None:
        st.write(f"Entered Latitude: {latitude}, Longitude: {longitude}")