File size: 4,377 Bytes
79b13b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import json
import math
import sqlite3

import streamlit as st
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision import transforms
from transformers import AutoModelForImageClassification

# Page header.
st.title("Global Bird Classification App")

# Image upload widget; yields None until the user picks a file.
uploaded_file = st.file_uploader(
    "Please select an image",
    type=["jpg", "jpeg", "png"],
)

# Optional coordinates: with value=None the inputs start empty (None) and the
# rest of the script checks them with `is not None`.
latitude = st.number_input("Enter latitude (optional)", value=None, format="%f")
longitude = st.number_input("Enter longitude (optional)", value=None, format="%f")

# Display names for the result-language codes. The selected code is later used
# as the index into each label entry (bird_info[class][lang]).
_LANGUAGE_NAMES = {
    2: "Latina (Nomen Scientificum)",
    1: "English (IOC 10.1)",
    0: "中文 (中国大陆)",
}
lang = st.selectbox(
    "Result Language",
    options=list(_LANGUAGE_NAMES),
    format_func=lambda code: _LANGUAGE_NAMES[code],
)


# Standard ImageNet per-channel mean / std used to normalize model input.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to a fixed 224x224 (aspect ratio is not
# preserved), convert to a [0, 1] float tensor, then normalize per channel.
classify_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


# crop and classification
def classify_objects(classification_model, image, species_list):
    input_tensor = classify_transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = classification_model(input_tensor)[0]
        filtered = get_filtered_predictions(logits, species_list)
        return softmax(filtered)


def softmax(tuples):
    """Apply softmax to the scores of ``(label, score)`` pairs and sort by probability.

    Implemented on plain Python numbers because ``torch.nn.functional.softmax``
    requires a ``Tensor`` input.

    Args:
        tuples: Sequence of pairs whose element 0 is a label and element 1 a
            numeric score (float or 0-d tensor).

    Returns:
        New list of ``(label, probability)`` pairs, sorted by descending
        probability (ties keep input order). Empty input gives an empty list.
    """
    if not tuples:
        return []

    labels = [t[0] for t in tuples]
    values = [float(t[1]) for t in tuples]

    # Subtract the maximum before exponentiating: mathematically identical, but
    # math.exp raises OverflowError for arguments above ~709.
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = math.fsum(exp_values)

    updated_tuples = [(label, ev / total) for label, ev in zip(labels, exp_values)]
    updated_tuples.sort(key=lambda t: t[1], reverse=True)
    return updated_tuples


def get_filtered_predictions(predictions: list[float], species_list: list[int]) -> list[tuple[int, float]]:
    """Pair each score with its class index, keeping only allowed classes.

    Args:
        predictions: Per-class scores, indexed by class id.
        species_list: Allowed class ids. When empty (or None), no filtering is
            applied and every class is returned.

    Returns:
        ``(class_index, score)`` tuples in ascending class-index order.
    """
    if not species_list:
        return list(enumerate(predictions))

    # Build the lookup set once: membership in a list would be O(len(species_list))
    # per class, i.e. O(N*M) over the whole prediction vector.
    allowed = set(species_list)
    return [(index, value) for index, value in enumerate(predictions) if index in allowed]


class DistributionDB:
    """Thin wrapper around the species-distribution SQLite database.

    Looks up which classifier classes have a recorded place whose bounding box
    contains a given coordinate.
    """

    # Bounding-box lookup. Coordinates are bound as named parameters instead of
    # being formatted into the SQL text.
    # NOTE(review): a box stored with west > east (e.g. one crossing the
    # antimeridian) can never satisfy this condition — confirm how `places`
    # represents such regions.
    _RANGE_QUERY = """
        SELECT m.cls
        FROM distributions AS d
        LEFT OUTER JOIN places AS p
          ON p.worldid = d.worldid
        LEFT OUTER JOIN sp_cls_map AS m
          ON d.species = m.species
        WHERE p.south <= :lat
          AND p.north >= :lat
          AND p.east >= :lng
          AND p.west <= :lng
        GROUP BY d.species, m.cls;
    """

    def __init__(self, db_path):
        """Open a connection, plus one reusable cursor, to the database at *db_path*."""
        self.con = sqlite3.connect(db_path)
        self.cur = self.con.cursor()

    def get_list(self, lat, lng) -> list:
        """Return the class ids (``sp_cls_map.cls``) of species recorded at (lat, lng).

        A species matches when one of its ``places`` rows satisfies
        ``south <= lat <= north`` and ``west <= lng <= east``. Rows are grouped
        per (species, cls). An id may be ``None`` for a species with no entry in
        ``sp_cls_map`` (left outer join).

        Returns an empty list when either coordinate is ``None``. Without this
        guard the query could not be evaluated for missing coordinates.
        """
        if lat is None or lng is None:
            return []
        self.cur.execute(self._RANGE_QUERY, {"lat": lat, "lng": lng})
        return [row[0] for row in self.cur]

    def close(self):
        """Close the cursor and the underlying connection."""
        self.cur.close()
        self.con.close()


@st.cache_resource
def _load_classifier():
    """Load the 'sunjiao/osea' classifier once per server process.

    Streamlit re-executes this script on every widget interaction; caching avoids
    reloading the weights each time. The model is moved to ``device`` so it
    matches the input tensor that ``classify_objects`` places there.
    """
    model = AutoModelForImageClassification.from_pretrained('sunjiao/osea')
    model.to(device)
    model.eval()  # inference only
    return model


# If the user uploads an image
if uploaded_file is not None:
    try:
        sqlite_path = hf_hub_download(repo_id='sunjiao/osea', filename='avonet.db')
        st.success("Successfully downloaded distribution database from Hugging Face Hub!")

        label_map_path = hf_hub_download(repo_id='sunjiao/osea', filename='bird_info.json')
        st.success("Successfully downloaded labels from Hugging Face Hub!")
    except Exception as e:
        # Top-level UI boundary: report any download failure and halt this run.
        st.error(f"Failed to download the file: {e}")
        st.stop()

    # Coordinates are optional. Only restrict candidates by range when both are
    # given; an empty species list means "no filtering" for the classifier.
    species_list = []
    if latitude is not None and longitude is not None:
        db = DistributionDB(sqlite_path)
        try:
            species_list = db.get_list(latitude, longitude)
        finally:
            db.close()

    # Open and display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    model = _load_classifier()
    results = classify_objects(model, image, species_list)
    top3_results = results[:3]

    # Explicit UTF-8: the label file presumably holds non-ASCII names (the
    # Chinese-name option is lang 0), and the platform default encoding is not
    # guaranteed to be UTF-8.
    with open(label_map_path, 'r', encoding='utf-8') as f:
        bird_info = json.load(f)

    # Display the top 3 results and their probabilities
    st.subheader("Classification Results (Top 3):")
    if not top3_results:
        st.warning("No candidate species are recorded for the entered location.")
    for class_index, probability in top3_results:
        st.write(f"{bird_info[class_index][lang]}: {probability:.4f}")

# Display latitude and longitude if provided
if latitude is not None and longitude is not None:
    st.write(f"Entered Latitude: {latitude}, Longitude: {longitude}")