File size: 5,023 Bytes
79b13b9 b886d74 79b13b9 b886d74 79b13b9 b886d74 79b13b9 b886d74 79b13b9 b886d74 79b13b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import json
import math
import sqlite3
import streamlit as st
import torch
import torchvision
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoConfig
# Set the page title
st.title("Global Bird Classification App")

# Optional coordinates. When both are provided, the results are restricted to
# species recorded around that location. The bounds reject impossible values
# up front: an out-of-range coordinate would match no place in the
# distribution database, and the location filter would then be skipped
# without any feedback to the user.
latitude = st.number_input(
    "Enter latitude (optional)",
    min_value=-90.0,
    max_value=90.0,
    value=None,
    format="%f",
)
longitude = st.number_input(
    "Enter longitude (optional)",
    min_value=-180.0,
    max_value=180.0,
    value=None,
    format="%f",
)
st.text('Please fill the coordinates before upload image.')

# Upload an image
uploaded_file = st.file_uploader("Please select an image", type=["jpg", "jpeg", "png"])

# `lang` is used as the index into each bird_info record when the results are
# printed: 2 = scientific name, 1 = English (IOC 10.1), 0 = Chinese.
lang = st.selectbox(
    "Result Language",
    options=[2, 1, 0],
    format_func=lambda x: {
        2: "Latina (Nomen Scientificum)",
        1: "English (IOC 10.1)",
        0: "中文 (中国大陆)",
    }[x],
)
# Preprocessing applied to each image before it reaches the classifier:
#   1. resize to a fixed 224x224 (the aspect ratio is not preserved),
#   2. convert to a float tensor with values in [0, 1],
#   3. normalize each channel with the standard ImageNet mean / std.
classify_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Run inference on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Whole-image classification (no cropping / detection step).
def classify_objects(classification_model, image, species_list):
    """Classify a PIL image and return ranked class probabilities.

    The whole image is preprocessed with ``classify_transforms`` and fed to
    the model as a batch of one.

    Args:
        classification_model: the classifier; it must already live on
            ``device``, because the input tensor is moved there.
        image: a PIL image of any mode (RGB, RGBA, palette, grayscale, ...).
        species_list: optional class ids to restrict the results to;
            ``None`` or an empty list keeps every class.

    Returns:
        ``(class_index, probability)`` pairs, highest probability first.
        The probabilities are renormalized over the kept classes only.
    """
    # PNG uploads are often RGBA or palette ("P") images. The normalization
    # and the model expect exactly three channels, so force RGB first.
    rgb_image = image.convert("RGB")
    input_tensor = classify_transforms(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = classification_model(input_tensor)[0]
    # Copy the scores to host memory as plain floats in one transfer,
    # instead of pulling thousands of 0-d tensors off the device one by one.
    scores = logits.cpu().tolist()
    filtered = get_filtered_predictions(scores, species_list)
    return softmax(filtered)
def softmax(tuples):
    """Turn ``(index, logit)`` pairs into ``(index, probability)`` pairs.

    `torch.nn.functional.softmax` requires a `Tensor` input, so this is a
    small pure-Python implementation over an (index, value) list.

    The maximum value is subtracted before exponentiating. This leaves the
    result mathematically unchanged but stops ``math.exp`` from raising
    ``OverflowError`` on large logits (above roughly 709).

    Args:
        tuples: sequence of pairs whose first item is kept as the label and
            whose second item is a score accepted by ``float()`` (a Python
            number or a 0-d tensor).

    Returns:
        A new list of ``(index, probability)`` pairs sorted by probability,
        highest first; equal probabilities keep their input order.
        Empty input yields an empty list.
    """
    if not tuples:
        return []
    values = [float(t[1]) for t in tuples]
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = sum(exp_values)  # >= 1.0, since the peak term is exp(0)
    updated_tuples = [(t[0], ev / total) for t, ev in zip(tuples, exp_values)]
    updated_tuples.sort(key=lambda pair: pair[1], reverse=True)
    return updated_tuples
def get_filtered_predictions(predictions: list[float], species_list: list[int]) -> list[tuple[int, float]]:
    """Pair each score with its class index, optionally keeping only some classes.

    Args:
        predictions: per-class scores in class-index order (any iterable,
            e.g. a list of floats or a 1-D tensor).
        species_list: class ids to keep. ``None`` or an empty list means
            "no filter" and every class is returned, so an empty
            location lookup falls back to unrestricted results.

    Returns:
        ``(class_index, score)`` pairs in ascending class-index order.
    """
    if not species_list:
        return list(enumerate(predictions))
    # Build a set once so each membership test is O(1). Testing against the
    # list was O(len(species_list)) per class, for thousands of classes.
    allowed = set(species_list)
    return [(index, value) for index, value in enumerate(predictions) if index in allowed]
class DistributionDB:
    """Query helper over the SQLite species-distribution database.

    Call ``close()`` when done, or use it as a context manager
    (``with DistributionDB(path) as db: ...``), which closes on exit.
    """

    def __init__(self, db_path):
        self.con = sqlite3.connect(db_path)
        self.cur = self.con.cursor()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def get_list(self, lat: float, lng: float) -> list:
        """Return class ids of species recorded in places whose box contains (lat, lng).

        A place matches when ``south <= lat <= north`` and
        ``west <= lng <= east``.
        NOTE(review): a box crossing the antimeridian (west > east) can never
        match this test; confirm whether the data contains such boxes.

        Species with no class in ``sp_cls_map`` (NULL ``cls``) are skipped.
        One row is returned per distinct (species, class) pair, so a class id
        can appear more than once when several species map to it.
        """
        # Bound parameters instead of f-string interpolation: nothing from the
        # caller is spliced into the SQL text, so values such as inf/nan
        # cannot produce a malformed statement.
        self.cur.execute(
            '''
            SELECT m.cls
            FROM distributions AS d
            LEFT OUTER JOIN places AS p
            ON p.worldid = d.worldid
            LEFT OUTER JOIN sp_cls_map AS m
            ON d.species = m.species
            WHERE p.south <= :lat
            AND p.north >= :lat
            AND p.east >= :lng
            AND p.west <= :lng
            AND m.cls IS NOT NULL
            GROUP BY d.species, m.cls;
            ''',
            {"lat": lat, "lng": lng},
        )
        return [row[0] for row in self.cur]

    def close(self):
        self.cur.close()
        self.con.close()
@st.cache_resource
def _load_model(weights_path):
    """Build the ResNet-34 classifier once per process and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction.
    Without caching, the 11000-class checkpoint would be reloaded each time.
    """
    model = torchvision.models.resnet34(num_classes=11000)
    # NOTE(review): torch.load unpickles the file. It comes from a fixed HF
    # repo, but consider weights_only=True where the installed torch supports it.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    # load_state_dict copies values into the module's existing parameters, so
    # the module itself must be moved to the device the inputs are sent to.
    model.to(device)
    model.eval()
    return model


def _download_or_stop(filename, description):
    """Fetch `filename` from the sunjiao/osea HF repo; show an error and stop the app on failure."""
    try:
        path = hf_hub_download(repo_id='sunjiao/osea', filename=filename)
    except Exception as e:
        st.error(f"Failed to download the file: {e}")
        st.stop()
    st.success(f"Successfully downloaded {description} from Hugging Face Hub!")
    return path


# If the user uploads an image
if uploaded_file is not None:
    label_map_path = _download_or_stop('bird_info.json', 'labels')
    # The label file contains non-ASCII names (e.g. Chinese), so do not rely
    # on the platform's default text encoding.
    with open(label_map_path, 'r', encoding='utf-8') as f:
        bird_info = json.load(f)

    species_list = None
    # Compare against None explicitly: 0.0 (equator / prime meridian) is a
    # valid coordinate but is falsy.
    if latitude is not None and longitude is not None:
        sqlite_path = _download_or_stop('avonet.db', 'distribution database')
        db = DistributionDB(sqlite_path)
        try:
            species_list = db.get_list(latitude, longitude)
        finally:
            db.close()

    # Open and display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    weight_dict = _download_or_stop('pytorch_model.bin', 'weight dict')
    model = _load_model(weight_dict)

    results = classify_objects(model, image, species_list)
    top3_results = results[:3]

    # Display the top 3 results and their probabilities
    st.subheader("Classification Results (Top 3):")
    for class_index, probability in top3_results:
        st.write(f"{bird_info[class_index][lang]}: {probability:.4f}")

    # Display latitude and longitude if provided
    if latitude is not None and longitude is not None:
        st.write(f"Entered Latitude: {latitude}, Longitude: {longitude}")