import json
import math
import sqlite3

import streamlit as st
import torch
import torchvision
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision import transforms

# Set the page title
st.title("Global Bird Classification App")

# Input latitude and longitude (optional)
latitude = st.number_input("Enter latitude (optional)", value=None, format="%f")
longitude = st.number_input("Enter longitude (optional)", value=None, format="%f")

st.text('Please fill in the coordinates before uploading an image.')

# Upload an image
uploaded_file = st.file_uploader("Please select an image", type=["jpg", "jpeg", "png"])

lang = st.selectbox(
    "Result Language",
    options=[2, 1, 0],
    format_func=lambda x: {
        2: "Latina (Nomen Scientificum)",
        1: "English (IOC 10.1)",
        0: "中文 (中国大陆)",
    }[x]
)
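# The selected value (0, 1 or 2) is later used as an index into each species' name entry when results are displayed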


# ImageNet-style preprocessing: resize to 224x224 and normalize with ImageNet statistics
classify_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Classify the image, optionally restricting predictions to species found near the given coordinates
def classify_objects(classification_model, image, species_list):
    input_tensor = classify_transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = classification_model(input_tensor)[0]
        filtered = get_filtered_predictions(logits, species_list)
        return softmax(filtered)


def softmax(tuples):
    # `torch.nn.functional.softmax` expects a Tensor, so compute the softmax over (index, logit) pairs by hand.
    # Subtracting the largest logit first keeps math.exp from overflowing on large values.
    values = [t[1] for t in tuples]
    max_value = max(values)
    exp_values = [math.exp(v - max_value) for v in values]
    sum_exp_values = sum(exp_values)
    softmax_values = [ev / sum_exp_values for ev in exp_values]
    updated_tuples = [(t[0], softmax_values[i]) for i, t in enumerate(tuples)]
    updated_tuples.sort(key=lambda t: t[1], reverse=True)

    return updated_tuples


def get_filtered_predictions(predictions: torch.Tensor, species_list: list[int] | None) -> list[tuple[int, float]]:
    # Map each class index to its raw logit, keeping only indices from the regional species list when one is given
    original = {index: value.item() for index, value in enumerate(predictions)}

    if species_list:
        filtered_predictions = [(key, value) for key, value in original.items() if key in species_list]
    else:
        filtered_predictions = [(key, value) for key, value in original.items()]

    return filtered_predictions


# Lookup of species classes recorded around a coordinate, backed by the avonet.db SQLite database
class DistributionDB:
    def __init__(self, db_path):
        self.con = sqlite3.connect(db_path)
        self.cur = self.con.cursor()

    def get_list(self, lat, lng) -> list:
        # Bounding-box lookup: species whose recorded range rectangle contains the point (lat, lng).
        # Parameterized query, so the coordinate values are never interpolated into the SQL string.
        self.cur.execute('''
SELECT m.cls
FROM distributions AS d
LEFT OUTER JOIN places AS p
  ON p.worldid = d.worldid
LEFT OUTER JOIN sp_cls_map AS m
  ON d.species = m.species
WHERE p.south <= ?
  AND p.north >= ?
  AND p.east >= ?
  AND p.west <= ?
GROUP BY d.species, m.cls;
''', (lat, lat, lng, lng))

        return [row[0] for row in self.cur]

    def close(self):
        self.cur.close()
        self.con.close()


# If the user uploads an image
if uploaded_file is not None:
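    # Fetch the per-species label metadata (bird_info.json) from the Hugging Face Hub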
    try:
        label_map_path = hf_hub_download(repo_id='sunjiao/osea', filename='bird_info.json')
        st.success("Successfully downloaded labels from Hugging Face Hub!")
    except Exception as e:
        st.error(f"Failed to download the file: {e}")
        st.stop()

    # bird_info maps each class index to its names (index 0: Chinese, 1: English, 2: scientific name)
    with open(label_map_path, 'r', encoding='utf-8') as f:
        bird_info = json.load(f)

    species_list = None
    # 0.0 is a valid coordinate, so test explicitly against None
    if latitude is not None and longitude is not None:
        try:
            sqlite_path = hf_hub_download(repo_id='sunjiao/osea', filename='avonet.db')
            st.success("Successfully downloaded distribution database from Hugging Face Hub!")
        except Exception as e:
            st.error(f"Failed to download the file: {e}")
            st.stop()

        # Restrict candidate species to those recorded around the given coordinates
        db = DistributionDB(sqlite_path)
        species_list = db.get_list(latitude, longitude)
        db.close()

    # Open the image and force RGB so PNGs with an alpha channel match the model's 3-channel input
    image = Image.open(uploaded_file).convert("RGB")

    # Display the uploaded image
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Fetch the classifier checkpoint (hf_hub_download caches it locally, so repeat runs reuse the file)
    try:
        weight_dict = hf_hub_download(repo_id='sunjiao/osea', filename='pytorch_model.bin')
        st.success("Successfully downloaded model weights from Hugging Face Hub!")
    except Exception as e:
        st.error(f"Failed to download the file: {e}")
        st.stop()

    model = torchvision.models.resnet34(num_classes=11000)
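    # Load the fine-tuned weights (map_location avoids CUDA deserialization errors on CPU-only machines)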
    model.load_state_dict(torch.load(weight_dict, map_location=device))
    model.to(device)  # move the model to the same device the input tensor is sent to
    model.eval()

    results = classify_objects(model, image, species_list)

    top3_results = results[:3]

    # Display the top 3 results and their probabilities
    st.subheader("Classification Results (Top 3):")
    for result in top3_results:
        st.write(f"{bird_info[result[0]][lang]}: {result[1]:.4f}")

# Display latitude and longitude if provided
if latitude is not None and longitude is not None:
    st.write(f"Entered Latitude: {latitude}, Longitude: {longitude}")