Sun Jiao commited on
Commit
b886d74
·
1 Parent(s): 4413319

fix errors.

Browse files
Files changed (2) hide show
  1. app.py +34 -16
  2. requirements.txt +4 -3
app.py CHANGED
@@ -4,21 +4,24 @@ import sqlite3
4
 
5
  import streamlit as st
6
  import torch
 
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
  from torchvision import transforms
10
- from transformers import AutoModelForImageClassification
11
 
12
  # Set the page title
13
  st.title("Global Bird Classification App")
14
 
15
- # Upload an image
16
- uploaded_file = st.file_uploader("Please select an image", type=["jpg", "jpeg", "png"])
17
-
18
  # Input latitude and longitude (optional)
19
  latitude = st.number_input("Enter latitude (optional)", value=None, format="%f")
20
  longitude = st.number_input("Enter longitude (optional)", value=None, format="%f")
21
 
 
 
 
 
 
22
  lang = st.selectbox(
23
  "Result Language",
24
  options=[2, 1, 0],
@@ -102,18 +105,29 @@ GROUP BY d.species, m.cls;
102
  # If the user uploads an image
103
  if uploaded_file is not None:
104
  try:
105
- sqlite_path = hf_hub_download(repo_id='sunjiao/osea', filename='avonet.db')
106
- st.success(f"Successfully downloaded distribution database from Hugging Face Hub!")
107
-
108
  label_map_path = hf_hub_download(repo_id='sunjiao/osea', filename='bird_info.json')
109
  st.success(f"Successfully downloaded labels from Hugging Face Hub!")
110
  except Exception as e:
111
  st.error(f"Failed to download the file: {e}")
112
  st.stop()
113
 
114
- db = DistributionDB(sqlite_path)
115
- species_list = db.get_list(latitude, longitude)
116
- db.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  # Open the image
119
  image = Image.open(uploaded_file)
@@ -121,17 +135,21 @@ if uploaded_file is not None:
121
  # Display the uploaded image
122
  st.image(image, caption="Uploaded Image", use_container_width=True)
123
 
124
- model = AutoModelForImageClassification.from_pretrained('sunjiao/osea')
 
 
 
 
 
 
 
 
 
125
 
126
  results = classify_objects(model, image, species_list)
127
 
128
  top3_results = results[:3]
129
 
130
- with open(label_map_path, 'r') as f:
131
- data = f.read()
132
-
133
- bird_info = json.loads(data)
134
-
135
  # Display the top 3 results and their probabilities
136
  st.subheader("Classification Results (Top 3):")
137
  for result in top3_results:
 
4
 
5
  import streamlit as st
6
  import torch
7
+ import torchvision
8
  from PIL import Image
9
  from huggingface_hub import hf_hub_download
10
  from torchvision import transforms
11
+ from transformers import AutoModelForImageClassification, AutoConfig
12
 
13
  # Set the page title
14
  st.title("Global Bird Classification App")
15
 
 
 
 
16
  # Input latitude and longitude (optional)
17
  latitude = st.number_input("Enter latitude (optional)", value=None, format="%f")
18
  longitude = st.number_input("Enter longitude (optional)", value=None, format="%f")
19
 
20
+ st.text('Please fill the coordinates before upload image.')
21
+
22
+ # Upload an image
23
+ uploaded_file = st.file_uploader("Please select an image", type=["jpg", "jpeg", "png"])
24
+
25
  lang = st.selectbox(
26
  "Result Language",
27
  options=[2, 1, 0],
 
105
  # If the user uploads an image
106
  if uploaded_file is not None:
107
  try:
 
 
 
108
  label_map_path = hf_hub_download(repo_id='sunjiao/osea', filename='bird_info.json')
109
  st.success(f"Successfully downloaded labels from Hugging Face Hub!")
110
  except Exception as e:
111
  st.error(f"Failed to download the file: {e}")
112
  st.stop()
113
 
114
+ with open(label_map_path, 'r') as f:
115
+ data = f.read()
116
+
117
+ bird_info = json.loads(data)
118
+
119
+ species_list = None
120
+ if latitude and longitude:
121
+ try:
122
+ sqlite_path = hf_hub_download(repo_id='sunjiao/osea', filename='avonet.db')
123
+ st.success(f"Successfully downloaded distribution database from Hugging Face Hub!")
124
+ except Exception as e:
125
+ st.error(f"Failed to download the file: {e}")
126
+ st.stop()
127
+
128
+ db = DistributionDB(sqlite_path)
129
+ species_list = db.get_list(latitude, longitude)
130
+ db.close()
131
 
132
  # Open the image
133
  image = Image.open(uploaded_file)
 
135
  # Display the uploaded image
136
  st.image(image, caption="Uploaded Image", use_container_width=True)
137
 
138
+ try:
139
+ weight_dict = hf_hub_download(repo_id='sunjiao/osea', filename='pytorch_model.bin')
140
+ st.success(f"Successfully downloaded weight dict from Hugging Face Hub!")
141
+ except Exception as e:
142
+ st.error(f"Failed to download the file: {e}")
143
+ st.stop()
144
+
145
+ model = torchvision.models.resnet34(num_classes=11000)
146
+ model.load_state_dict(torch.load(weight_dict, map_location=device))
147
+ model.eval()
148
 
149
  results = classify_objects(model, image, species_list)
150
 
151
  top3_results = results[:3]
152
 
 
 
 
 
 
153
  # Display the top 3 results and their probabilities
154
  st.subheader("Classification Results (Top 3):")
155
  for result in top3_results:
requirements.txt CHANGED
@@ -2,8 +2,9 @@ huggingface_hub
2
  transformers
3
  ImageHash==4.3.1
4
  openpyxl==3.1.5
5
- Pillow==11.0.0
6
  pyshp==2.3.1
7
- torch==2.5.0
8
- torchvision==0.20.0
 
9
  tqdm==4.67.1
 
2
  transformers
3
  ImageHash==4.3.1
4
  openpyxl==3.1.5
5
+ Pillow
6
  pyshp==2.3.1
7
+ streamlit
8
+ torch
9
+ torchvision
10
  tqdm==4.67.1