Spaces:
Runtime error
Runtime error
Jonas Rheiner
commited on
Commit
Β·
e20beac
1
Parent(s):
8b18a0c
Reformat
Browse files
app.py
CHANGED
@@ -18,52 +18,158 @@ device = "cuda" if CUDA_AVAILABLE else "cpu"
|
|
18 |
print(f"count={torch.cuda.device_count()}")
|
19 |
print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
|
20 |
|
21 |
-
continent_model = CLIPModel.from_pretrained(
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
continent_model = continent_model.to(device)
|
25 |
country_model = country_model.to(device)
|
26 |
|
27 |
|
28 |
-
continents = ["Africa", "Asia", "Europe",
|
29 |
-
"North America", "Oceania", "South America"]
|
30 |
countries_per_continent = {
|
31 |
"Africa": [
|
32 |
-
"Botswana",
|
33 |
-
"
|
34 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
],
|
36 |
"Asia": [
|
37 |
-
"Bangladesh",
|
38 |
-
"
|
39 |
-
"
|
40 |
-
"
|
41 |
-
"
|
42 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
],
|
44 |
"Europe": [
|
45 |
-
"Albania",
|
46 |
-
"
|
47 |
-
"
|
48 |
-
"
|
49 |
-
"
|
50 |
-
"
|
51 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
],
|
53 |
"North America": [
|
54 |
-
"Canada",
|
55 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
],
|
57 |
"Oceania": [
|
58 |
-
"Australia",
|
|
|
|
|
|
|
|
|
|
|
59 |
],
|
60 |
"South America": [
|
61 |
-
"Argentina",
|
62 |
-
"
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
}
|
65 |
-
countries = list(set(itertools.chain.from_iterable(
|
66 |
-
countries_per_continent.values())))
|
67 |
|
68 |
country_to_center_coords = {
|
69 |
"Indonesia": (-2.4833826, 117.8902853),
|
@@ -181,7 +287,7 @@ country_to_center_coords = {
|
|
181 |
"Djibouti": (11.8145966, 42.8453061),
|
182 |
"Senegal": (14.4750607, -14.4529612),
|
183 |
"Bermuda": (32.3040273, -64.7563086),
|
184 |
-
"United States": (39.7837304, -100.445882)
|
185 |
}
|
186 |
|
187 |
INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
|
@@ -191,29 +297,35 @@ INITAL_VERSUS_STATE = {
|
|
191 |
"country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
|
192 |
"lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
|
193 |
"lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
|
194 |
-
"score": {
|
195 |
-
|
196 |
-
"AI": 0
|
197 |
-
},
|
198 |
-
"idx": 0
|
199 |
}
|
200 |
|
201 |
|
202 |
def predict(input_img):
|
203 |
-
inputs = processor(
|
204 |
-
|
|
|
|
|
|
|
|
|
205 |
inputs = inputs.to(device)
|
206 |
with torch.no_grad():
|
207 |
outputs = continent_model(**inputs)
|
208 |
logits_per_image = outputs.logits_per_image
|
209 |
probs = logits_per_image.softmax(dim=-1)
|
210 |
pred_id = probs.argmax().cpu().item()
|
211 |
-
continent_probs = {
|
212 |
-
|
|
|
213 |
model_continent = continents[pred_id]
|
214 |
predicted_continent_countries = countries_per_continent[model_continent]
|
215 |
-
inputs = processor(
|
216 |
-
|
|
|
|
|
|
|
|
|
217 |
inputs = inputs.to(device)
|
218 |
with torch.no_grad():
|
219 |
outputs = country_model(**inputs)
|
@@ -221,26 +333,37 @@ def predict(input_img):
|
|
221 |
probs = logits_per_image.softmax(dim=-1)
|
222 |
pred_id = probs.argmax().cpu().item()
|
223 |
model_country = predicted_continent_countries[pred_id]
|
224 |
-
country_probs = {
|
225 |
-
predicted_continent_countries, probs.tolist()[0])
|
226 |
-
|
|
|
227 |
hash = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
|
228 |
metadata_block = gr.Accordion(visible=False)
|
229 |
metadata_map = None
|
230 |
if hash in EXAMPLE_METADATA.keys():
|
231 |
model_result = ""
|
232 |
-
if
|
|
|
|
|
|
|
233 |
model_result = "The AI π€ correctly guessed continent and country β
β
."
|
234 |
-
elif model_continent == EXAMPLE_METADATA[hash][
|
235 |
model_result = "The AI π€ only guessed the correct continent β β
."
|
236 |
-
elif
|
|
|
|
|
|
|
237 |
model_result = "The AI π€ only guessed the correct country β
β."
|
238 |
else:
|
239 |
model_result = "The AI π€ failed to guess country and continent β β."
|
240 |
-
metadata_block = gr.Accordion(
|
|
|
|
|
|
|
241 |
metadata_map = make_versus_map(None, model_country, EXAMPLE_METADATA[hash])
|
242 |
return continent_probs, country_probs, metadata_block, metadata_map
|
243 |
|
|
|
244 |
def make_versus_map(human_country, model_country, versus_state):
|
245 |
if human_country:
|
246 |
human_coordinates = country_to_center_coords[human_country]
|
@@ -248,64 +371,66 @@ def make_versus_map(human_country, model_country, versus_state):
|
|
248 |
human_coordinates = (None, None)
|
249 |
model_coordinates = country_to_center_coords[model_country]
|
250 |
fig = go.Figure()
|
251 |
-
fig.add_trace(
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
versus_state['continent']}"],
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
showlegend=True,
|
260 |
-
name="π· Photo Location"
|
261 |
-
))
|
262 |
-
if human_country == model_country:
|
263 |
-
fig.add_trace(go.Scattermapbox(
|
264 |
-
lat=[human_coordinates[0], model_coordinates[0]],
|
265 |
-
lon=[human_coordinates[1], model_coordinates[1]],
|
266 |
-
text=f"π§ π€ Human & AI guess {human_country}",
|
267 |
-
mode='markers',
|
268 |
-
hoverinfo='text',
|
269 |
-
marker=dict(size=14, color='#FF9500'),
|
270 |
showlegend=True,
|
271 |
-
name="
|
272 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
else:
|
274 |
if human_country:
|
275 |
-
fig.add_trace(
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
showlegend=True,
|
283 |
-
name="
|
284 |
-
)
|
285 |
-
|
286 |
-
|
287 |
-
lon=[model_coordinates[1]],
|
288 |
-
text=[f"π€ AI guesses {model_country}"],
|
289 |
-
mode='markers',
|
290 |
-
hoverinfo='text',
|
291 |
-
marker=dict(size=14, color='#474747'),
|
292 |
-
showlegend=True,
|
293 |
-
name="π€ AI Guess"
|
294 |
-
))
|
295 |
-
|
296 |
fig.update_layout(
|
297 |
mapbox=dict(
|
298 |
style="carto-positron",
|
299 |
center=dict(lat=float(versus_state["lat"]), lon=float(versus_state["lon"])),
|
300 |
-
zoom=2
|
301 |
),
|
302 |
margin={"r": 0, "t": 0, "l": 0, "b": 0},
|
303 |
-
legend=dict(
|
304 |
-
yanchor="bottom",
|
305 |
-
y=0.01,
|
306 |
-
xanchor="left",
|
307 |
-
x=0.01
|
308 |
-
)
|
309 |
)
|
310 |
return fig
|
311 |
|
@@ -323,12 +448,13 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
|
|
323 |
human_points += 1
|
324 |
else:
|
325 |
continent_result = "β"
|
326 |
-
human_result = f"The photo is from **{versus_state['country']}** {
|
327 |
-
|
328 |
-
|
329 |
-
|
|
|
330 |
|
331 |
-
continent_probs, country_probs, _,_ = predict(input_img)
|
332 |
model_country = max(country_probs, key=country_probs.get)
|
333 |
model_continent = max(continent_probs, key=continent_probs.get)
|
334 |
if model_country == versus_state["country"]:
|
@@ -341,11 +467,16 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
|
|
341 |
model_points += 1
|
342 |
else:
|
343 |
model_continent_result = "β"
|
344 |
-
model_score_update =
|
345 |
-
|
|
|
|
|
|
|
|
|
346 |
|
347 |
map = make_versus_map(human_country, model_country, versus_state)
|
348 |
-
return
|
|
|
349 |
## {human_result}
|
350 |
### The AI π€ thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
|
351 |
|
@@ -353,7 +484,12 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
|
|
353 |
π€ {model_score_update}
|
354 |
|
355 |
### Score π§ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} π€
|
356 |
-
""",
|
|
|
|
|
|
|
|
|
|
|
357 |
|
358 |
|
359 |
def get_example_images(dir):
|
@@ -393,45 +529,65 @@ for img_path in example_images:
|
|
393 |
|
394 |
demo = gr.Blocks(title="Thesis Demo")
|
395 |
with demo:
|
396 |
-
gr.HTML(
|
|
|
397 |
<h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
|
398 |
|
399 |
<h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
|
400 |
<p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example image the ground truth is available and will be displayed together with the model predictions.</p>
|
401 |
<p>In the <b>"Versus Mode"</b> tab to play against the AI, guessing the country and continent where images where taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?
|
402 |
|
403 |
-
"""
|
404 |
-
|
405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
with gr.Tab("Image Geolocation Demo"):
|
407 |
with gr.Row():
|
408 |
with gr.Column():
|
409 |
-
image = gr.Image(
|
410 |
-
|
|
|
411 |
predict_btn = gr.Button("Predict")
|
412 |
example_images = get_example_images("kerger-test-images")
|
413 |
# example_images.extend(get_example_images("versus_images"))
|
414 |
-
gr.Examples(examples=example_images,
|
415 |
-
inputs=image, examples_per_page=24)
|
416 |
with gr.Column():
|
417 |
with gr.Accordion(visible=False) as metadata_block:
|
418 |
map = gr.Plot(label="Locations")
|
419 |
with gr.Group():
|
420 |
continents_label = gr.Label(label="Continents")
|
421 |
-
country_label = gr.Label(
|
422 |
-
|
423 |
-
|
424 |
-
|
|
|
|
|
425 |
|
426 |
with gr.Tab("Versus Mode"):
|
427 |
versus_state = gr.State(value=INITAL_VERSUS_STATE)
|
428 |
with gr.Row():
|
429 |
with gr.Column():
|
430 |
-
versus_image = gr.Image(
|
431 |
-
INITAL_VERSUS_STATE["image"], interactive=False)
|
432 |
continent_selection = gr.Radio(
|
433 |
-
continents,
|
434 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
with gr.Row():
|
436 |
next_img_btn = gr.Button("Try new image")
|
437 |
versus_btn = gr.Button("Submit guess")
|
@@ -443,11 +599,28 @@ with demo:
|
|
443 |
with gr.Group():
|
444 |
continents_label = gr.Label(label="Continents")
|
445 |
country_label = gr.Label(
|
446 |
-
num_top_classes=5, label="Top countries"
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
|
453 |
if __name__ == "__main__":
|
|
|
18 |
print(f"count={torch.cuda.device_count()}")
|
19 |
print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
|
20 |
|
21 |
+
continent_model = CLIPModel.from_pretrained(
|
22 |
+
"jrheiner/thesis-clip-geoloc-continent",
|
23 |
+
token=os.getenv("token"),
|
24 |
+
)
|
25 |
+
country_model = CLIPModel.from_pretrained(
|
26 |
+
"jrheiner/thesis-clip-geoloc-country",
|
27 |
+
token=os.getenv("token"),
|
28 |
+
)
|
29 |
+
processor = CLIPProcessor.from_pretrained(
|
30 |
+
"jrheiner/thesis-clip-geoloc-continent",
|
31 |
+
token=os.getenv("token"),
|
32 |
+
)
|
33 |
continent_model = continent_model.to(device)
|
34 |
country_model = country_model.to(device)
|
35 |
|
36 |
|
37 |
+
continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]
|
|
|
38 |
countries_per_continent = {
|
39 |
"Africa": [
|
40 |
+
"Botswana",
|
41 |
+
"Eswatini",
|
42 |
+
"Ghana",
|
43 |
+
"Kenya",
|
44 |
+
"Lesotho",
|
45 |
+
"Nigeria",
|
46 |
+
"Senegal",
|
47 |
+
"South Africa",
|
48 |
+
"Rwanda",
|
49 |
+
"Uganda",
|
50 |
+
"Tanzania",
|
51 |
+
"Madagascar",
|
52 |
+
"Djibouti",
|
53 |
+
"Mali",
|
54 |
+
"Libya",
|
55 |
+
"Morocco",
|
56 |
+
"Somalia",
|
57 |
+
"Tunisia",
|
58 |
+
"Egypt",
|
59 |
+
"RΓ©union",
|
60 |
],
|
61 |
"Asia": [
|
62 |
+
"Bangladesh",
|
63 |
+
"Bhutan",
|
64 |
+
"Cambodia",
|
65 |
+
"China",
|
66 |
+
"India",
|
67 |
+
"Indonesia",
|
68 |
+
"Israel",
|
69 |
+
"Japan",
|
70 |
+
"Jordan",
|
71 |
+
"Kyrgyzstan",
|
72 |
+
"Laos",
|
73 |
+
"Malaysia",
|
74 |
+
"Mongolia",
|
75 |
+
"Nepal",
|
76 |
+
"Palestine",
|
77 |
+
"Philippines",
|
78 |
+
"Singapore",
|
79 |
+
"South Korea",
|
80 |
+
"Sri Lanka",
|
81 |
+
"Taiwan",
|
82 |
+
"Thailand",
|
83 |
+
"United Arab Emirates",
|
84 |
+
"Vietnam",
|
85 |
+
"Afghanistan",
|
86 |
+
"Azerbaijan",
|
87 |
+
"Cyprus",
|
88 |
+
"Iran",
|
89 |
+
"Syria",
|
90 |
+
"Tajikistan",
|
91 |
+
"Turkey",
|
92 |
+
"Russia",
|
93 |
+
"Pakistan",
|
94 |
+
"Hong Kong",
|
95 |
],
|
96 |
"Europe": [
|
97 |
+
"Albania",
|
98 |
+
"Andorra",
|
99 |
+
"Austria",
|
100 |
+
"Belgium",
|
101 |
+
"Bulgaria",
|
102 |
+
"Croatia",
|
103 |
+
"Czechia",
|
104 |
+
"Denmark",
|
105 |
+
"Estonia",
|
106 |
+
"Finland",
|
107 |
+
"France",
|
108 |
+
"Germany",
|
109 |
+
"Greece",
|
110 |
+
"Hungary",
|
111 |
+
"Iceland",
|
112 |
+
"Ireland",
|
113 |
+
"Italy",
|
114 |
+
"Latvia",
|
115 |
+
"Lithuania",
|
116 |
+
"Luxembourg",
|
117 |
+
"Montenegro",
|
118 |
+
"Netherlands",
|
119 |
+
"North Macedonia",
|
120 |
+
"Norway",
|
121 |
+
"Poland",
|
122 |
+
"Portugal",
|
123 |
+
"Romania",
|
124 |
+
"Russia",
|
125 |
+
"Serbia",
|
126 |
+
"Slovakia",
|
127 |
+
"Slovenia",
|
128 |
+
"Spain",
|
129 |
+
"Sweden",
|
130 |
+
"Switzerland",
|
131 |
+
"Ukraine",
|
132 |
+
"United Kingdom",
|
133 |
+
"Bosnia and Herzegovina",
|
134 |
+
"Cyprus",
|
135 |
+
"Turkey",
|
136 |
+
"Greenland",
|
137 |
+
"Faroe Islands",
|
138 |
],
|
139 |
"North America": [
|
140 |
+
"Canada",
|
141 |
+
"Dominican Republic",
|
142 |
+
"Guatemala",
|
143 |
+
"Mexico",
|
144 |
+
"United States",
|
145 |
+
"Bahamas",
|
146 |
+
"Cuba",
|
147 |
+
"Panama",
|
148 |
+
"Puerto Rico",
|
149 |
+
"Bermuda",
|
150 |
+
"Greenland",
|
151 |
],
|
152 |
"Oceania": [
|
153 |
+
"Australia",
|
154 |
+
"New Zealand",
|
155 |
+
"Fiji",
|
156 |
+
"Papua New Guinea",
|
157 |
+
"Solomon Islands",
|
158 |
+
"Vanuatu",
|
159 |
],
|
160 |
"South America": [
|
161 |
+
"Argentina",
|
162 |
+
"Bolivia",
|
163 |
+
"Brazil",
|
164 |
+
"Chile",
|
165 |
+
"Colombia",
|
166 |
+
"Ecuador",
|
167 |
+
"Paraguay",
|
168 |
+
"Peru",
|
169 |
+
"Uruguay",
|
170 |
+
],
|
171 |
}
|
172 |
+
countries = list(set(itertools.chain.from_iterable(countries_per_continent.values())))
|
|
|
173 |
|
174 |
country_to_center_coords = {
|
175 |
"Indonesia": (-2.4833826, 117.8902853),
|
|
|
287 |
"Djibouti": (11.8145966, 42.8453061),
|
288 |
"Senegal": (14.4750607, -14.4529612),
|
289 |
"Bermuda": (32.3040273, -64.7563086),
|
290 |
+
"United States": (39.7837304, -100.445882),
|
291 |
}
|
292 |
|
293 |
INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
|
|
|
297 |
"country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
|
298 |
"lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
|
299 |
"lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
|
300 |
+
"score": {"HUMAN": 0, "AI": 0},
|
301 |
+
"idx": 0,
|
|
|
|
|
|
|
302 |
}
|
303 |
|
304 |
|
305 |
def predict(input_img):
|
306 |
+
inputs = processor(
|
307 |
+
text=[f"A photo from {geo}." for geo in continents],
|
308 |
+
images=input_img,
|
309 |
+
return_tensors="pt",
|
310 |
+
padding=True,
|
311 |
+
)
|
312 |
inputs = inputs.to(device)
|
313 |
with torch.no_grad():
|
314 |
outputs = continent_model(**inputs)
|
315 |
logits_per_image = outputs.logits_per_image
|
316 |
probs = logits_per_image.softmax(dim=-1)
|
317 |
pred_id = probs.argmax().cpu().item()
|
318 |
+
continent_probs = {
|
319 |
+
label: prob for label, prob in zip(continents, probs.tolist()[0])
|
320 |
+
}
|
321 |
model_continent = continents[pred_id]
|
322 |
predicted_continent_countries = countries_per_continent[model_continent]
|
323 |
+
inputs = processor(
|
324 |
+
text=[f"A photo from {geo}." for geo in predicted_continent_countries],
|
325 |
+
images=input_img,
|
326 |
+
return_tensors="pt",
|
327 |
+
padding=True,
|
328 |
+
)
|
329 |
inputs = inputs.to(device)
|
330 |
with torch.no_grad():
|
331 |
outputs = country_model(**inputs)
|
|
|
333 |
probs = logits_per_image.softmax(dim=-1)
|
334 |
pred_id = probs.argmax().cpu().item()
|
335 |
model_country = predicted_continent_countries[pred_id]
|
336 |
+
country_probs = {
|
337 |
+
label: prob for label, prob in zip(predicted_continent_countries, probs.tolist()[0])
|
338 |
+
}
|
339 |
+
|
340 |
hash = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
|
341 |
metadata_block = gr.Accordion(visible=False)
|
342 |
metadata_map = None
|
343 |
if hash in EXAMPLE_METADATA.keys():
|
344 |
model_result = ""
|
345 |
+
if (
|
346 |
+
model_continent == EXAMPLE_METADATA[hash]["continent"]
|
347 |
+
and model_country == EXAMPLE_METADATA[hash]["country"]
|
348 |
+
):
|
349 |
model_result = "The AI π€ correctly guessed continent and country β
β
."
|
350 |
+
elif model_continent == EXAMPLE_METADATA[hash]["continent"]:
|
351 |
model_result = "The AI π€ only guessed the correct continent β β
."
|
352 |
+
elif (
|
353 |
+
model_country == EXAMPLE_METADATA[hash]["country"]
|
354 |
+
and model_continent != EXAMPLE_METADATA[hash]["continent"]
|
355 |
+
):
|
356 |
model_result = "The AI π€ only guessed the correct country β
β."
|
357 |
else:
|
358 |
model_result = "The AI π€ failed to guess country and continent β β."
|
359 |
+
metadata_block = gr.Accordion(
|
360 |
+
visible=True,
|
361 |
+
label=f"This photo was taken in {EXAMPLE_METADATA[hash]['country']}, {EXAMPLE_METADATA[hash]['continent']}.\n{model_result}",
|
362 |
+
)
|
363 |
metadata_map = make_versus_map(None, model_country, EXAMPLE_METADATA[hash])
|
364 |
return continent_probs, country_probs, metadata_block, metadata_map
|
365 |
|
366 |
+
|
367 |
def make_versus_map(human_country, model_country, versus_state):
|
368 |
if human_country:
|
369 |
human_coordinates = country_to_center_coords[human_country]
|
|
|
371 |
human_coordinates = (None, None)
|
372 |
model_coordinates = country_to_center_coords[model_country]
|
373 |
fig = go.Figure()
|
374 |
+
fig.add_trace(
|
375 |
+
go.Scattermapbox(
|
376 |
+
lon=[versus_state["lon"]],
|
377 |
+
lat=[versus_state["lat"]],
|
378 |
+
text=[f"π· Photo taken in {versus_state['country']}, {versus_state['continent']}"],
|
379 |
+
mode="markers",
|
380 |
+
hoverinfo="text",
|
381 |
+
marker=dict(size=14, color="#0C5DA5"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
showlegend=True,
|
383 |
+
name="π· Photo Location",
|
384 |
+
)
|
385 |
+
)
|
386 |
+
if human_country == model_country:
|
387 |
+
fig.add_trace(
|
388 |
+
go.Scattermapbox(
|
389 |
+
lat=[human_coordinates[0], model_coordinates[0]],
|
390 |
+
lon=[human_coordinates[1], model_coordinates[1]],
|
391 |
+
text=f"π§ π€ Human & AI guess {human_country}",
|
392 |
+
mode="markers",
|
393 |
+
hoverinfo="text",
|
394 |
+
marker=dict(size=14, color="#FF9500"),
|
395 |
+
showlegend=True,
|
396 |
+
name="π§ π€ Human & AI Guess",
|
397 |
+
)
|
398 |
+
)
|
399 |
else:
|
400 |
if human_country:
|
401 |
+
fig.add_trace(
|
402 |
+
go.Scattermapbox(
|
403 |
+
lat=[human_coordinates[0]],
|
404 |
+
lon=[human_coordinates[1]],
|
405 |
+
text=[f"π§ Human guesses {human_country}"],
|
406 |
+
mode="markers",
|
407 |
+
hoverinfo="text",
|
408 |
+
marker=dict(size=14, color="#FF9500"),
|
409 |
+
showlegend=True,
|
410 |
+
name="π§ Human Guess",
|
411 |
+
)
|
412 |
+
)
|
413 |
+
fig.add_trace(
|
414 |
+
go.Scattermapbox(
|
415 |
+
lat=[model_coordinates[0]],
|
416 |
+
lon=[model_coordinates[1]],
|
417 |
+
text=[f"π€ AI guesses {model_country}"],
|
418 |
+
mode="markers",
|
419 |
+
hoverinfo="text",
|
420 |
+
marker=dict(size=14, color="#474747"),
|
421 |
showlegend=True,
|
422 |
+
name="π€ AI Guess",
|
423 |
+
)
|
424 |
+
)
|
425 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
fig.update_layout(
|
427 |
mapbox=dict(
|
428 |
style="carto-positron",
|
429 |
center=dict(lat=float(versus_state["lat"]), lon=float(versus_state["lon"])),
|
430 |
+
zoom=2,
|
431 |
),
|
432 |
margin={"r": 0, "t": 0, "l": 0, "b": 0},
|
433 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
|
|
|
|
|
|
|
|
|
|
434 |
)
|
435 |
return fig
|
436 |
|
|
|
448 |
human_points += 1
|
449 |
else:
|
450 |
continent_result = "β"
|
451 |
+
human_result = f"The photo is from **{versus_state['country']}** {country_result} in **{versus_state['continent']}** {continent_result}"
|
452 |
+
human_score_update = (
|
453 |
+
f"+{human_points} points" if human_points > 0 else "0 Points..."
|
454 |
+
)
|
455 |
+
versus_state["score"]["HUMAN"] += human_points
|
456 |
|
457 |
+
continent_probs, country_probs, _, _ = predict(input_img)
|
458 |
model_country = max(country_probs, key=country_probs.get)
|
459 |
model_continent = max(continent_probs, key=continent_probs.get)
|
460 |
if model_country == versus_state["country"]:
|
|
|
467 |
model_points += 1
|
468 |
else:
|
469 |
model_continent_result = "β"
|
470 |
+
model_score_update = (
|
471 |
+
f"+{model_points} points"
|
472 |
+
if model_points > 0
|
473 |
+
else "0 Points... The model was completely wrong, it seems the world is not doomed yet."
|
474 |
+
)
|
475 |
+
versus_state["score"]["AI"] += model_points
|
476 |
|
477 |
map = make_versus_map(human_country, model_country, versus_state)
|
478 |
+
return (
|
479 |
+
f"""
|
480 |
## {human_result}
|
481 |
### The AI π€ thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
|
482 |
|
|
|
484 |
π€ {model_score_update}
|
485 |
|
486 |
### Score π§ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} π€
|
487 |
+
""",
|
488 |
+
continent_probs,
|
489 |
+
country_probs,
|
490 |
+
map,
|
491 |
+
versus_state,
|
492 |
+
)
|
493 |
|
494 |
|
495 |
def get_example_images(dir):
|
|
|
529 |
|
530 |
demo = gr.Blocks(title="Thesis Demo")
|
531 |
with demo:
|
532 |
+
gr.HTML(
|
533 |
+
"""
|
534 |
<h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
|
535 |
|
536 |
<h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
|
537 |
<p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example image the ground truth is available and will be displayed together with the model predictions.</p>
|
538 |
<p>In the <b>"Versus Mode"</b> tab to play against the AI, guessing the country and continent where images where taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?
|
539 |
|
540 |
+
"""
|
541 |
+
)
|
542 |
+
with gr.Accordion(
|
543 |
+
label="The demo currently encompasses 116 countries from 6 continents π",
|
544 |
+
open=False,
|
545 |
+
):
|
546 |
+
gr.Code(
|
547 |
+
json.dumps(countries_per_continent, indent=2, ensure_ascii=False),
|
548 |
+
label="countries_per_continent.json",
|
549 |
+
language="json",
|
550 |
+
interactive=False,
|
551 |
+
)
|
552 |
with gr.Tab("Image Geolocation Demo"):
|
553 |
with gr.Row():
|
554 |
with gr.Column():
|
555 |
+
image = gr.Image(
|
556 |
+
label="Image", type="pil", sources=["upload", "clipboard"]
|
557 |
+
)
|
558 |
predict_btn = gr.Button("Predict")
|
559 |
example_images = get_example_images("kerger-test-images")
|
560 |
# example_images.extend(get_example_images("versus_images"))
|
561 |
+
gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
|
|
|
562 |
with gr.Column():
|
563 |
with gr.Accordion(visible=False) as metadata_block:
|
564 |
map = gr.Plot(label="Locations")
|
565 |
with gr.Group():
|
566 |
continents_label = gr.Label(label="Continents")
|
567 |
+
country_label = gr.Label(num_top_classes=5, label="Top countries")
|
568 |
+
predict_btn.click(
|
569 |
+
predict,
|
570 |
+
inputs=image,
|
571 |
+
outputs=[continents_label, country_label, metadata_block, map],
|
572 |
+
)
|
573 |
|
574 |
with gr.Tab("Versus Mode"):
|
575 |
versus_state = gr.State(value=INITAL_VERSUS_STATE)
|
576 |
with gr.Row():
|
577 |
with gr.Column():
|
578 |
+
versus_image = gr.Image(INITAL_VERSUS_STATE["image"], interactive=False)
|
|
|
579 |
continent_selection = gr.Radio(
|
580 |
+
continents,
|
581 |
+
label="Continents",
|
582 |
+
info="Where was this image taken? (1 Point)",
|
583 |
+
)
|
584 |
+
country_selection = (
|
585 |
+
gr.Dropdown(
|
586 |
+
countries,
|
587 |
+
label="Countries",
|
588 |
+
info="Can you guess the exact country? (2 Points)",
|
589 |
+
),
|
590 |
+
)
|
591 |
with gr.Row():
|
592 |
next_img_btn = gr.Button("Try new image")
|
593 |
versus_btn = gr.Button("Submit guess")
|
|
|
599 |
with gr.Group():
|
600 |
continents_label = gr.Label(label="Continents")
|
601 |
country_label = gr.Label(
|
602 |
+
num_top_classes=5, label="Top countries"
|
603 |
+
)
|
604 |
+
next_img_btn.click(
|
605 |
+
next_versus_image,
|
606 |
+
inputs=[versus_state],
|
607 |
+
outputs=[
|
608 |
+
versus_image,
|
609 |
+
versus_state,
|
610 |
+
continent_selection,
|
611 |
+
country_selection[0],
|
612 |
+
],
|
613 |
+
)
|
614 |
+
versus_btn.click(
|
615 |
+
versus_mode_inputs,
|
616 |
+
inputs=[
|
617 |
+
versus_image,
|
618 |
+
continent_selection,
|
619 |
+
country_selection[0],
|
620 |
+
versus_state,
|
621 |
+
],
|
622 |
+
outputs=[versus_output, continents_label, country_label, map, versus_state],
|
623 |
+
)
|
624 |
|
625 |
|
626 |
if __name__ == "__main__":
|