Spaces:
Runtime error
Runtime error
Jonas Rheiner
commited on
Commit
·
4b77aea
1
Parent(s):
710c658
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +206 -64
- dataset_examples/africa/3962011747224020_1024.jpg +0 -0
- dataset_examples/africa/3973126792769679_1024.jpg +0 -0
- dataset_examples/africa/4471109009586514_1024.jpg +0 -0
- dataset_examples/asia/106261888221766_1024.jpg +0 -0
- dataset_examples/asia/138321512570044_1024.jpg +0 -0
- dataset_examples/asia/147206360658971_1024.jpg +0 -0
- dataset_examples/europe/1423684677989158_1024.jpg +0 -0
- dataset_examples/europe/1483506425323136_1024.jpg +0 -0
- dataset_examples/europe/150428493699453_1024.jpg +0 -0
- dataset_examples/north america/1000276020376482_1024.jpg +0 -0
- dataset_examples/north america/757125001639938_1024.jpg +0 -0
- dataset_examples/north america/843371313196684_1024.jpg +0 -0
- dataset_examples/oceania/100141636075718_1024.jpg +0 -0
- dataset_examples/oceania/1899604010244250_1024.jpg +0 -0
- dataset_examples/oceania/821397982117703_1024.jpg +0 -0
- dataset_examples/south america/103386512798674_1024.jpg +0 -0
- dataset_examples/south america/1677652415776082_1024.jpg +0 -0
- dataset_examples/south america/327973242016483_1024.jpg +0 -0
- examples/1000276020376482_1024.jpg +0 -0
- examples/100141636075718_1024.jpg +0 -0
- examples/103386512798674_1024.jpg +0 -0
- examples/106261888221766_1024.jpg +0 -0
- examples/138321512570044_1024.jpg +0 -0
- examples/1423684677989158_1024.jpg +0 -0
- examples/147206360658971_1024.jpg +0 -0
- examples/1483506425323136_1024.jpg +0 -0
- examples/150428493699453_1024.jpg +0 -0
- examples/1677652415776082_1024.jpg +0 -0
- examples/1899604010244250_1024.jpg +0 -0
- examples/327973242016483_1024.jpg +0 -0
- examples/3962011747224020_1024.jpg +0 -0
- examples/3973126792769679_1024.jpg +0 -0
- examples/4471109009586514_1024.jpg +0 -0
- examples/757125001639938_1024.jpg +0 -0
- examples/821397982117703_1024.jpg +0 -0
- examples/843371313196684_1024.jpg +0 -0
- kerger-test-images/Africa_Botswana_-24.358520377382_23.5184910801.jpg +0 -0
- kerger-test-images/Africa_Kenya_-0.21870999999999_37.023791.jpg +0 -0
- kerger-test-images/Africa_Madagascar_-16.078452454738_46.73369803641.jpg +0 -0
- kerger-test-images/Africa_South Africa_-23.590135077274_28.785944164821.jpg +0 -0
- kerger-test-images/Africa_Tanzania_-3.3676537025657_36.716512872377.jpg +0 -0
- kerger-test-images/Africa_Uganda_1.1212866787272_33.915204986261.jpg +0 -0
- kerger-test-images/Asia_Israel_31.708865303742_34.94966916063.jpg +0 -0
- kerger-test-images/Asia_Japan_35.381304970616_134.65860211972.jpg +0 -0
- kerger-test-images/Asia_Pakistan_24.910493840503_69.506229024537.jpg +0 -0
- kerger-test-images/Asia_Russia_54.597757883015_48.163689656865.jpg +0 -0
- kerger-test-images/Asia_Russia_56.018311493214_38.359778952407.jpg +0 -0
- kerger-test-images/Asia_Russia_60.27835356798_29.754665851696.jpg +0 -0
- kerger-test-images/Asia_Thailand_19.824843951089_99.694080339609.jpg +0 -0
app.py
CHANGED
@@ -4,92 +4,208 @@ from transformers import CLIPProcessor, CLIPModel
|
|
4 |
import torch
|
5 |
import itertools
|
6 |
import os
|
|
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
|
|
|
|
|
|
|
10 |
|
11 |
-
continents = ["Africa", "Asia", "Europe",
|
|
|
12 |
countries_per_continent = {
|
13 |
"Africa": [
|
14 |
-
"Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
|
15 |
-
"Central African Republic", "
|
16 |
-
"Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
|
17 |
-
"Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
|
18 |
-
"Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
|
19 |
-
"Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
|
20 |
-
"Sierra Leone", "Somalia", "South Africa", "
|
21 |
"Tunisia", "Uganda", "Zambia", "Zimbabwe"
|
22 |
],
|
23 |
"Asia": [
|
24 |
-
"Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
|
25 |
-
"Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
|
26 |
-
"Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
|
27 |
-
"Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
|
28 |
-
"Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
|
29 |
-
"Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
|
30 |
"Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
|
31 |
],
|
32 |
"Europe": [
|
33 |
-
"Albania", "
|
34 |
-
"Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
|
35 |
-
"Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
|
36 |
-
"Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
|
37 |
-
"Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
|
38 |
-
"Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
|
39 |
-
"Turkey", "Ukraine", "United Kingdom"
|
40 |
],
|
41 |
"North America": [
|
42 |
-
"Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
|
43 |
-
"Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
|
44 |
-
"Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
|
45 |
"Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
|
46 |
],
|
47 |
"Oceania": [
|
48 |
-
"Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
|
49 |
"Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
|
50 |
],
|
51 |
"South America": [
|
52 |
-
"Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
|
53 |
"Peru", "Suriname", "Uruguay", "Venezuela"
|
54 |
]
|
55 |
}
|
56 |
-
countries = list(set(itertools.chain.from_iterable(
|
|
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
"
|
61 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
}
|
63 |
|
|
|
64 |
def predict(input_img):
|
65 |
-
inputs = processor(text=[f"A photo from {
|
|
|
|
|
66 |
with torch.no_grad():
|
67 |
-
outputs =
|
68 |
logits_per_image = outputs.logits_per_image
|
69 |
probs = logits_per_image.softmax(dim=-1)
|
70 |
pred_id = probs.argmax().cpu().item()
|
71 |
-
continent_probs = {label: prob for label,
|
72 |
-
|
73 |
-
|
74 |
predicted_continent_countries = countries_per_continent[continents[pred_id]]
|
75 |
-
inputs = processor(text=[f"A photo from {
|
|
|
|
|
76 |
with torch.no_grad():
|
77 |
-
outputs =
|
78 |
logits_per_image = outputs.logits_per_image
|
79 |
probs = logits_per_image.softmax(dim=-1)
|
80 |
-
country_probs = {label: prob for label, prob in zip(
|
81 |
-
|
82 |
return continent_probs, country_probs
|
83 |
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
continent_probs, country_probs = predict(input_img)
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
def next_versus_image():
|
90 |
-
# VERSUS_GT["continent"] = "test"
|
91 |
-
# VERSUS_GT["country"] = "test"
|
92 |
-
return "versus_images/[email protected]"
|
93 |
|
94 |
def get_example_images(dir):
|
95 |
image_extensions = (".jpg", ".jpeg", ".png")
|
@@ -100,39 +216,65 @@ def get_example_images(dir):
|
|
100 |
image_files.append(os.path.join(root, file))
|
101 |
return image_files
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
demo = gr.Blocks()
|
104 |
with demo:
|
105 |
with gr.Tab("Image Geolocation Demo"):
|
106 |
with gr.Row():
|
107 |
with gr.Column():
|
108 |
-
image = gr.Image(label="Image", type="pil",
|
|
|
109 |
predict_btn = gr.Button("Predict")
|
110 |
-
example_images = get_example_images("
|
111 |
-
example_images.extend(get_example_images("versus_images"))
|
112 |
-
gr.Examples(examples=example_images,
|
|
|
113 |
with gr.Column():
|
114 |
continents_label = gr.Label(label="Continents")
|
115 |
-
country_label = gr.Label(
|
|
|
116 |
# continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
|
117 |
-
predict_btn.click(predict, inputs=image, outputs=[
|
|
|
118 |
|
119 |
with gr.Tab("Versus Mode"):
|
|
|
120 |
with gr.Row():
|
121 |
with gr.Column():
|
122 |
-
versus_image = gr.Image(
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
126 |
with gr.Row():
|
127 |
next_img_btn = gr.Button("Try new image")
|
128 |
versus_btn = gr.Button("Submit guess")
|
129 |
with gr.Column():
|
130 |
-
versus_output = gr.
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
|
137 |
if __name__ == "__main__":
|
138 |
-
demo.launch()
|
|
|
4 |
import torch
|
5 |
import itertools
|
6 |
import os
|
7 |
+
import plotly.graph_objects as go
|
8 |
|
9 |
+
|
10 |
+
CUDA_AVAILABLE = torch.cuda.is_available()
|
11 |
+
print(f"CUDA={CUDA_AVAILABLE}")
|
12 |
+
device = "cuda" if CUDA_AVAILABLE else "cpu"
|
13 |
+
print(f"count={torch.cuda.device_count()}")
|
14 |
+
print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
|
15 |
+
|
16 |
+
continent_model = CLIPModel.from_pretrained("model-checkpoints/continent")
|
17 |
+
country_model = CLIPModel.from_pretrained("model-checkpoints/country")
|
18 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
|
19 |
+
continent_model = continent_model.to(device)
|
20 |
+
country_model = country_model.to(device)
|
21 |
+
|
22 |
|
23 |
+
continents = ["Africa", "Asia", "Europe",
|
24 |
+
"North America", "Oceania", "South America"]
|
25 |
countries_per_continent = {
|
26 |
"Africa": [
|
27 |
+
"Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
|
28 |
+
"Central African Republic", "Congo", "Democratic Republic of the Congo",
|
29 |
+
"Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
|
30 |
+
"Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
|
31 |
+
"Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
|
32 |
+
"Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
|
33 |
+
"Sierra Leone", "Somalia", "South Africa", "Sudan", "Tanzania", "Togo",
|
34 |
"Tunisia", "Uganda", "Zambia", "Zimbabwe"
|
35 |
],
|
36 |
"Asia": [
|
37 |
+
"Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
|
38 |
+
"Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
|
39 |
+
"Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
|
40 |
+
"Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
|
41 |
+
"Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
|
42 |
+
"Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
|
43 |
"Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
|
44 |
],
|
45 |
"Europe": [
|
46 |
+
"Albania", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
|
47 |
+
"Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
|
48 |
+
"Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
|
49 |
+
"Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
|
50 |
+
"Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
|
51 |
+
"Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
|
52 |
+
"Turkey", "Ukraine", "United Kingdom"
|
53 |
],
|
54 |
"North America": [
|
55 |
+
"Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
|
56 |
+
"Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
|
57 |
+
"Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
|
58 |
"Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
|
59 |
],
|
60 |
"Oceania": [
|
61 |
+
"Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
|
62 |
"Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
|
63 |
],
|
64 |
"South America": [
|
65 |
+
"Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
|
66 |
"Peru", "Suriname", "Uruguay", "Venezuela"
|
67 |
]
|
68 |
}
|
69 |
+
countries = list(set(itertools.chain.from_iterable(
|
70 |
+
countries_per_continent.values())))
|
71 |
|
72 |
+
INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
|
73 |
+
INITAL_VERSUS_STATE = {
|
74 |
+
"image": INTIAL_VERSUS_IMAGE,
|
75 |
+
"continent": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[0],
|
76 |
+
"country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
|
77 |
+
"lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
|
78 |
+
"lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
|
79 |
+
"score": {
|
80 |
+
"HUMAN": 0,
|
81 |
+
"AI": 0
|
82 |
+
},
|
83 |
+
"idx": 0
|
84 |
}
|
85 |
|
86 |
+
|
87 |
def predict(input_img):
|
88 |
+
inputs = processor(text=[f"A photo from {
|
89 |
+
geo}." for geo in continents], images=input_img, return_tensors="pt", padding=True)
|
90 |
+
inputs = inputs.to(device)
|
91 |
with torch.no_grad():
|
92 |
+
outputs = continent_model(**inputs)
|
93 |
logits_per_image = outputs.logits_per_image
|
94 |
probs = logits_per_image.softmax(dim=-1)
|
95 |
pred_id = probs.argmax().cpu().item()
|
96 |
+
continent_probs = {label: prob for label,
|
97 |
+
prob in zip(continents, probs.tolist()[0])}
|
98 |
+
|
99 |
predicted_continent_countries = countries_per_continent[continents[pred_id]]
|
100 |
+
inputs = processor(text=[f"A photo from {
|
101 |
+
geo}." for geo in predicted_continent_countries], images=input_img, return_tensors="pt", padding=True)
|
102 |
+
inputs = inputs.to(device)
|
103 |
with torch.no_grad():
|
104 |
+
outputs = country_model(**inputs)
|
105 |
logits_per_image = outputs.logits_per_image
|
106 |
probs = logits_per_image.softmax(dim=-1)
|
107 |
+
country_probs = {label: prob for label, prob in zip(
|
108 |
+
predicted_continent_countries, probs.tolist()[0])}
|
109 |
return continent_probs, country_probs
|
110 |
|
111 |
+
|
112 |
+
def make_versus_map(human_country, model_country, versus_state):
|
113 |
+
fig = go.Figure()
|
114 |
+
fig.add_trace(go.Scattergeo(
|
115 |
+
lon=[versus_state["lon"]],
|
116 |
+
lat=[versus_state["lat"]],
|
117 |
+
text=["📷"],
|
118 |
+
mode='text+markers',
|
119 |
+
hoverinfo='text',
|
120 |
+
hovertext=f"Photo taken in {versus_state['country']}, {
|
121 |
+
versus_state['continent']}",
|
122 |
+
marker=dict(size=14, color='#00B945'),
|
123 |
+
showlegend=False
|
124 |
+
|
125 |
+
))
|
126 |
+
if human_country == model_country:
|
127 |
+
fig.add_trace(go.Scattergeo(
|
128 |
+
locations=[human_country],
|
129 |
+
locationmode='country names',
|
130 |
+
text=["🧑🤖"],
|
131 |
+
mode='text',
|
132 |
+
hoverinfo='location',
|
133 |
+
showlegend=False
|
134 |
+
|
135 |
+
))
|
136 |
+
else:
|
137 |
+
fig.add_trace(go.Scattergeo(
|
138 |
+
locations=[human_country],
|
139 |
+
locationmode='country names',
|
140 |
+
text=["🧑"],
|
141 |
+
mode='text',
|
142 |
+
hoverinfo='location',
|
143 |
+
showlegend=False
|
144 |
+
|
145 |
+
))
|
146 |
+
fig.add_trace(go.Scattergeo(
|
147 |
+
locations=[model_country],
|
148 |
+
locationmode='country names',
|
149 |
+
text=["🤖"],
|
150 |
+
mode='text',
|
151 |
+
hoverinfo='location',
|
152 |
+
showlegend=False
|
153 |
+
|
154 |
+
))
|
155 |
+
fig.update_geos(
|
156 |
+
visible=True, resolution=110,
|
157 |
+
showcountries=True, countrycolor="grey", fitbounds="locations", projection_type="natural earth",
|
158 |
+
)
|
159 |
+
return fig
|
160 |
+
|
161 |
+
|
162 |
+
def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
|
163 |
+
human_points = 0
|
164 |
+
model_points = 0
|
165 |
+
if human_country == versus_state["country"]:
|
166 |
+
country_result = "✅"
|
167 |
+
human_points += 2
|
168 |
+
else:
|
169 |
+
country_result = "❌"
|
170 |
+
if human_continent == versus_state["continent"]:
|
171 |
+
continent_result = "✅"
|
172 |
+
human_points += 1
|
173 |
+
else:
|
174 |
+
continent_result = "❌"
|
175 |
+
human_result = f"The photo is from **{versus_state['country']}** {
|
176 |
+
country_result} in **{versus_state['continent']}** {continent_result}"
|
177 |
+
human_score_update = f"+{
|
178 |
+
human_points} points" if human_points > 0 else "Wrong guess, try a new image."
|
179 |
+
versus_state['score']['HUMAN'] += human_points
|
180 |
+
|
181 |
continent_probs, country_probs = predict(input_img)
|
182 |
+
model_country = max(country_probs, key=country_probs.get)
|
183 |
+
model_continent = max(continent_probs, key=continent_probs.get)
|
184 |
+
if model_country == versus_state["country"]:
|
185 |
+
model_country_result = "✅"
|
186 |
+
model_points += 2
|
187 |
+
else:
|
188 |
+
model_country_result = "❌"
|
189 |
+
if model_continent == versus_state["continent"]:
|
190 |
+
model_continent_result = "✅"
|
191 |
+
model_points += 1
|
192 |
+
else:
|
193 |
+
model_continent_result = "❌"
|
194 |
+
model_score_update = f"+{
|
195 |
+
model_points} points" if model_points > 0 else "The model was wrong, seems the world is not yet doomed."
|
196 |
+
versus_state['score']['AI'] += model_points
|
197 |
+
|
198 |
+
map = make_versus_map(human_country, model_country, versus_state)
|
199 |
+
return f"""
|
200 |
+
## {human_result}
|
201 |
+
### The AI 🤖 thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
|
202 |
+
|
203 |
+
🧑 {human_score_update}
|
204 |
+
🤖 {model_score_update}
|
205 |
+
|
206 |
+
### Score 🧑 {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} 🤖
|
207 |
+
""", continent_probs, country_probs, map, versus_state
|
208 |
|
|
|
|
|
|
|
|
|
209 |
|
210 |
def get_example_images(dir):
|
211 |
image_extensions = (".jpg", ".jpeg", ".png")
|
|
|
216 |
image_files.append(os.path.join(root, file))
|
217 |
return image_files
|
218 |
|
219 |
+
|
220 |
+
def next_versus_image(versus_state):
|
221 |
+
images = get_example_images("versus_images")
|
222 |
+
versus_state["idx"] += 1
|
223 |
+
if versus_state["idx"] > len(images):
|
224 |
+
versus_state["idx"] = 0
|
225 |
+
versus_image = images[versus_state["idx"]]
|
226 |
+
versus_state["continent"] = versus_image.split("/")[-1].split("_")[0]
|
227 |
+
versus_state["country"] = versus_image.split("/")[-1].split("_")[1]
|
228 |
+
versus_state["lat"] = versus_image.split("/")[-1].split("_")[2]
|
229 |
+
versus_state["lon"] = versus_image.split("/")[-1].split("_")[3]
|
230 |
+
versus_state["image"] = versus_image
|
231 |
+
return versus_image, versus_state, None, None
|
232 |
+
|
233 |
demo = gr.Blocks()
|
234 |
with demo:
|
235 |
with gr.Tab("Image Geolocation Demo"):
|
236 |
with gr.Row():
|
237 |
with gr.Column():
|
238 |
+
image = gr.Image(label="Image", type="pil",
|
239 |
+
sources=["upload", "clipboard"])
|
240 |
predict_btn = gr.Button("Predict")
|
241 |
+
example_images = get_example_images("kerger-test-images")
|
242 |
+
# example_images.extend(get_example_images("versus_images"))
|
243 |
+
gr.Examples(examples=example_images,
|
244 |
+
inputs=image, examples_per_page=24)
|
245 |
with gr.Column():
|
246 |
continents_label = gr.Label(label="Continents")
|
247 |
+
country_label = gr.Label(
|
248 |
+
num_top_classes=5, label="Top countries")
|
249 |
# continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
|
250 |
+
predict_btn.click(predict, inputs=image, outputs=[
|
251 |
+
continents_label, country_label])
|
252 |
|
253 |
with gr.Tab("Versus Mode"):
|
254 |
+
versus_state = gr.State(value=INITAL_VERSUS_STATE)
|
255 |
with gr.Row():
|
256 |
with gr.Column():
|
257 |
+
versus_image = gr.Image(
|
258 |
+
INITAL_VERSUS_STATE["image"], interactive=False)
|
259 |
+
continent_selection = gr.Radio(
|
260 |
+
continents, label="Continents", info="Where was this image taken? (1 Point)")
|
261 |
+
country_selection = gr.Dropdown(countries, label="Countries", info="Can you guess the exact country? (2 Points)"
|
262 |
+
),
|
263 |
with gr.Row():
|
264 |
next_img_btn = gr.Button("Try new image")
|
265 |
versus_btn = gr.Button("Submit guess")
|
266 |
with gr.Column():
|
267 |
+
versus_output = gr.Markdown()
|
268 |
+
# with gr.Accordion("View Map", open=False):
|
269 |
+
map = gr.Plot(label="Locations")
|
270 |
+
with gr.Accordion("Full Model Output", open=False):
|
271 |
+
continents_label = gr.Label(label="Continents")
|
272 |
+
country_label = gr.Label(
|
273 |
+
num_top_classes=5, label="Top countries")
|
274 |
+
next_img_btn.click(next_versus_image, inputs=[versus_state], outputs=[versus_image, versus_state, continent_selection, country_selection[0]])
|
275 |
+
versus_btn.click(versus_mode_inputs, inputs=[versus_image, continent_selection, country_selection[0], versus_state], outputs=[
|
276 |
+
versus_output, continents_label, country_label, map, versus_state])
|
277 |
|
278 |
|
279 |
if __name__ == "__main__":
|
280 |
+
demo.launch()
|
dataset_examples/africa/3962011747224020_1024.jpg
DELETED
Binary file (145 kB)
|
|
dataset_examples/africa/3973126792769679_1024.jpg
DELETED
Binary file (69.5 kB)
|
|
dataset_examples/africa/4471109009586514_1024.jpg
DELETED
Binary file (147 kB)
|
|
dataset_examples/asia/106261888221766_1024.jpg
DELETED
Binary file (131 kB)
|
|
dataset_examples/asia/138321512570044_1024.jpg
DELETED
Binary file (78.1 kB)
|
|
dataset_examples/asia/147206360658971_1024.jpg
DELETED
Binary file (96.2 kB)
|
|
dataset_examples/europe/1423684677989158_1024.jpg
DELETED
Binary file (84.3 kB)
|
|
dataset_examples/europe/1483506425323136_1024.jpg
DELETED
Binary file (46.8 kB)
|
|
dataset_examples/europe/150428493699453_1024.jpg
DELETED
Binary file (140 kB)
|
|
dataset_examples/north america/1000276020376482_1024.jpg
DELETED
Binary file (66 kB)
|
|
dataset_examples/north america/757125001639938_1024.jpg
DELETED
Binary file (162 kB)
|
|
dataset_examples/north america/843371313196684_1024.jpg
DELETED
Binary file (257 kB)
|
|
dataset_examples/oceania/100141636075718_1024.jpg
DELETED
Binary file (84.2 kB)
|
|
dataset_examples/oceania/1899604010244250_1024.jpg
DELETED
Binary file (46.2 kB)
|
|
dataset_examples/oceania/821397982117703_1024.jpg
DELETED
Binary file (75.7 kB)
|
|
dataset_examples/south america/103386512798674_1024.jpg
DELETED
Binary file (135 kB)
|
|
dataset_examples/south america/1677652415776082_1024.jpg
DELETED
Binary file (94.5 kB)
|
|
dataset_examples/south america/327973242016483_1024.jpg
DELETED
Binary file (92.1 kB)
|
|
examples/1000276020376482_1024.jpg
DELETED
Binary file (66 kB)
|
|
examples/100141636075718_1024.jpg
DELETED
Binary file (84.2 kB)
|
|
examples/103386512798674_1024.jpg
DELETED
Binary file (135 kB)
|
|
examples/106261888221766_1024.jpg
DELETED
Binary file (131 kB)
|
|
examples/138321512570044_1024.jpg
DELETED
Binary file (78.1 kB)
|
|
examples/1423684677989158_1024.jpg
DELETED
Binary file (84.3 kB)
|
|
examples/147206360658971_1024.jpg
DELETED
Binary file (96.2 kB)
|
|
examples/1483506425323136_1024.jpg
DELETED
Binary file (46.8 kB)
|
|
examples/150428493699453_1024.jpg
DELETED
Binary file (140 kB)
|
|
examples/1677652415776082_1024.jpg
DELETED
Binary file (94.5 kB)
|
|
examples/1899604010244250_1024.jpg
DELETED
Binary file (46.2 kB)
|
|
examples/327973242016483_1024.jpg
DELETED
Binary file (92.1 kB)
|
|
examples/3962011747224020_1024.jpg
DELETED
Binary file (145 kB)
|
|
examples/3973126792769679_1024.jpg
DELETED
Binary file (69.5 kB)
|
|
examples/4471109009586514_1024.jpg
DELETED
Binary file (147 kB)
|
|
examples/757125001639938_1024.jpg
DELETED
Binary file (162 kB)
|
|
examples/821397982117703_1024.jpg
DELETED
Binary file (75.7 kB)
|
|
examples/843371313196684_1024.jpg
DELETED
Binary file (257 kB)
|
|
kerger-test-images/Africa_Botswana_-24.358520377382_23.5184910801.jpg
ADDED
![]() |
kerger-test-images/Africa_Kenya_-0.21870999999999_37.023791.jpg
ADDED
![]() |
kerger-test-images/Africa_Madagascar_-16.078452454738_46.73369803641.jpg
ADDED
![]() |
kerger-test-images/Africa_South Africa_-23.590135077274_28.785944164821.jpg
ADDED
![]() |
kerger-test-images/Africa_Tanzania_-3.3676537025657_36.716512872377.jpg
ADDED
![]() |
kerger-test-images/Africa_Uganda_1.1212866787272_33.915204986261.jpg
ADDED
![]() |
kerger-test-images/Asia_Israel_31.708865303742_34.94966916063.jpg
ADDED
![]() |
kerger-test-images/Asia_Japan_35.381304970616_134.65860211972.jpg
ADDED
![]() |
kerger-test-images/Asia_Pakistan_24.910493840503_69.506229024537.jpg
ADDED
![]() |
kerger-test-images/Asia_Russia_54.597757883015_48.163689656865.jpg
ADDED
![]() |
kerger-test-images/Asia_Russia_56.018311493214_38.359778952407.jpg
ADDED
![]() |
kerger-test-images/Asia_Russia_60.27835356798_29.754665851696.jpg
ADDED
![]() |
kerger-test-images/Asia_Thailand_19.824843951089_99.694080339609.jpg
ADDED
![]() |