Jonas Rheiner commited on
Commit
e20beac
Β·
1 Parent(s): 8b18a0c
Files changed (1) hide show
  1. app.py +299 -126
app.py CHANGED
@@ -18,52 +18,158 @@ device = "cuda" if CUDA_AVAILABLE else "cpu"
18
  print(f"count={torch.cuda.device_count()}")
19
  print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
20
 
21
- continent_model = CLIPModel.from_pretrained("jrheiner/thesis-clip-geoloc-continent", token=os.getenv("token"))
22
- country_model = CLIPModel.from_pretrained("jrheiner/thesis-clip-geoloc-country", token=os.getenv("token"))
23
- processor = CLIPProcessor.from_pretrained("jrheiner/thesis-clip-geoloc-continent", token=os.getenv("token"))
 
 
 
 
 
 
 
 
 
24
  continent_model = continent_model.to(device)
25
  country_model = country_model.to(device)
26
 
27
 
28
- continents = ["Africa", "Asia", "Europe",
29
- "North America", "Oceania", "South America"]
30
  countries_per_continent = {
31
  "Africa": [
32
- "Botswana", "Eswatini", "Ghana", "Kenya", "Lesotho", "Nigeria", "Senegal",
33
- "South Africa", "Rwanda", "Uganda", "Tanzania", "Madagascar", "Djibouti",
34
- "Mali", "Libya", "Morocco", "Somalia", "Tunisia", "Egypt", "RΓ©union"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ],
36
  "Asia": [
37
- "Bangladesh", "Bhutan", "Cambodia", "China", "India", "Indonesia", "Israel",
38
- "Japan", "Jordan", "Kyrgyzstan", "Laos", "Malaysia", "Mongolia", "Nepal",
39
- "Palestine", "Philippines", "Singapore", "South Korea", "Sri Lanka",
40
- "Taiwan", "Thailand", "United Arab Emirates", "Vietnam", "Afghanistan",
41
- "Azerbaijan", "Cyprus", "Iran", "Syria", "Tajikistan", "Turkey", "Russia",
42
- "Pakistan", "Hong Kong"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ],
44
  "Europe": [
45
- "Albania", "Andorra", "Austria", "Belgium", "Bulgaria", "Croatia", "Czechia",
46
- "Denmark", "Estonia", "Finland", "France", "Germany", "Greece", "Hungary",
47
- "Iceland", "Ireland", "Italy", "Latvia", "Lithuania", "Luxembourg",
48
- "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland",
49
- "Portugal", "Romania", "Russia", "Serbia", "Slovakia", "Slovenia", "Spain",
50
- "Sweden", "Switzerland", "Ukraine", "United Kingdom", "Bosnia and Herzegovina",
51
- "Cyprus", "Turkey", "Greenland", "Faroe Islands"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  ],
53
  "North America": [
54
- "Canada", "Dominican Republic", "Guatemala", "Mexico", "United States",
55
- "Bahamas", "Cuba", "Panama", "Puerto Rico", "Bermuda", "Greenland"
 
 
 
 
 
 
 
 
 
56
  ],
57
  "Oceania": [
58
- "Australia", "New Zealand", "Fiji", "Papua New Guinea", "Solomon Islands", "Vanuatu"
 
 
 
 
 
59
  ],
60
  "South America": [
61
- "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Paraguay",
62
- "Peru", "Uruguay"
63
- ]
 
 
 
 
 
 
 
64
  }
65
- countries = list(set(itertools.chain.from_iterable(
66
- countries_per_continent.values())))
67
 
68
  country_to_center_coords = {
69
  "Indonesia": (-2.4833826, 117.8902853),
@@ -181,7 +287,7 @@ country_to_center_coords = {
181
  "Djibouti": (11.8145966, 42.8453061),
182
  "Senegal": (14.4750607, -14.4529612),
183
  "Bermuda": (32.3040273, -64.7563086),
184
- "United States": (39.7837304, -100.445882)
185
  }
186
 
187
  INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
@@ -191,29 +297,35 @@ INITAL_VERSUS_STATE = {
191
  "country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
192
  "lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
193
  "lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
194
- "score": {
195
- "HUMAN": 0,
196
- "AI": 0
197
- },
198
- "idx": 0
199
  }
200
 
201
 
202
  def predict(input_img):
203
- inputs = processor(text=[f"A photo from {
204
- geo}." for geo in continents], images=input_img, return_tensors="pt", padding=True)
 
 
 
 
205
  inputs = inputs.to(device)
206
  with torch.no_grad():
207
  outputs = continent_model(**inputs)
208
  logits_per_image = outputs.logits_per_image
209
  probs = logits_per_image.softmax(dim=-1)
210
  pred_id = probs.argmax().cpu().item()
211
- continent_probs = {label: prob for label,
212
- prob in zip(continents, probs.tolist()[0])}
 
213
  model_continent = continents[pred_id]
214
  predicted_continent_countries = countries_per_continent[model_continent]
215
- inputs = processor(text=[f"A photo from {
216
- geo}." for geo in predicted_continent_countries], images=input_img, return_tensors="pt", padding=True)
 
 
 
 
217
  inputs = inputs.to(device)
218
  with torch.no_grad():
219
  outputs = country_model(**inputs)
@@ -221,26 +333,37 @@ def predict(input_img):
221
  probs = logits_per_image.softmax(dim=-1)
222
  pred_id = probs.argmax().cpu().item()
223
  model_country = predicted_continent_countries[pred_id]
224
- country_probs = {label: prob for label, prob in zip(
225
- predicted_continent_countries, probs.tolist()[0])}
226
-
 
227
  hash = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
228
  metadata_block = gr.Accordion(visible=False)
229
  metadata_map = None
230
  if hash in EXAMPLE_METADATA.keys():
231
  model_result = ""
232
- if model_continent == EXAMPLE_METADATA[hash]['continent'] and model_country == EXAMPLE_METADATA[hash]['country']:
 
 
 
233
  model_result = "The AI πŸ€– correctly guessed continent and country βœ… βœ…."
234
- elif model_continent == EXAMPLE_METADATA[hash]['continent']:
235
  model_result = "The AI πŸ€– only guessed the correct continent ❌ βœ…."
236
- elif model_country == EXAMPLE_METADATA[hash]['country'] and model_continent != EXAMPLE_METADATA[hash]['continent']:
 
 
 
237
  model_result = "The AI πŸ€– only guessed the correct country βœ… ❌."
238
  else:
239
  model_result = "The AI πŸ€– failed to guess country and continent ❌ ❌."
240
- metadata_block = gr.Accordion(visible=True, label=f"This photo was taken in {EXAMPLE_METADATA[hash]['country']}, {EXAMPLE_METADATA[hash]['continent']}.\n{model_result}")
 
 
 
241
  metadata_map = make_versus_map(None, model_country, EXAMPLE_METADATA[hash])
242
  return continent_probs, country_probs, metadata_block, metadata_map
243
 
 
244
  def make_versus_map(human_country, model_country, versus_state):
245
  if human_country:
246
  human_coordinates = country_to_center_coords[human_country]
@@ -248,64 +371,66 @@ def make_versus_map(human_country, model_country, versus_state):
248
  human_coordinates = (None, None)
249
  model_coordinates = country_to_center_coords[model_country]
250
  fig = go.Figure()
251
- fig.add_trace(go.Scattermapbox(
252
- lon=[versus_state["lon"]],
253
- lat=[versus_state["lat"]],
254
- text=[f"πŸ“· Photo taken in {versus_state['country']}, {
255
- versus_state['continent']}"],
256
- mode='markers',
257
- hoverinfo='text',
258
- marker=dict(size=14, color='#0C5DA5'),
259
- showlegend=True,
260
- name="πŸ“· Photo Location"
261
- ))
262
- if human_country == model_country:
263
- fig.add_trace(go.Scattermapbox(
264
- lat=[human_coordinates[0], model_coordinates[0]],
265
- lon=[human_coordinates[1], model_coordinates[1]],
266
- text=f"πŸ§‘ πŸ€– Human & AI guess {human_country}",
267
- mode='markers',
268
- hoverinfo='text',
269
- marker=dict(size=14, color='#FF9500'),
270
  showlegend=True,
271
- name="πŸ§‘ πŸ€– Human & AI Guess"
272
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  else:
274
  if human_country:
275
- fig.add_trace(go.Scattermapbox(
276
- lat=[human_coordinates[0]],
277
- lon=[human_coordinates[1]],
278
- text=[f"πŸ§‘ Human guesses {human_country}"],
279
- mode='markers',
280
- hoverinfo='text',
281
- marker=dict(size=14, color='#FF9500'),
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  showlegend=True,
283
- name="πŸ§‘ Human Guess"
284
- ))
285
- fig.add_trace(go.Scattermapbox(
286
- lat=[model_coordinates[0]],
287
- lon=[model_coordinates[1]],
288
- text=[f"πŸ€– AI guesses {model_country}"],
289
- mode='markers',
290
- hoverinfo='text',
291
- marker=dict(size=14, color='#474747'),
292
- showlegend=True,
293
- name="πŸ€– AI Guess"
294
- ))
295
-
296
  fig.update_layout(
297
  mapbox=dict(
298
  style="carto-positron",
299
  center=dict(lat=float(versus_state["lat"]), lon=float(versus_state["lon"])),
300
- zoom=2
301
  ),
302
  margin={"r": 0, "t": 0, "l": 0, "b": 0},
303
- legend=dict(
304
- yanchor="bottom",
305
- y=0.01,
306
- xanchor="left",
307
- x=0.01
308
- )
309
  )
310
  return fig
311
 
@@ -323,12 +448,13 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
323
  human_points += 1
324
  else:
325
  continent_result = "❌"
326
- human_result = f"The photo is from **{versus_state['country']}** {
327
- country_result} in **{versus_state['continent']}** {continent_result}"
328
- human_score_update = f"+{human_points} points" if human_points > 0 else "0 Points..."
329
- versus_state['score']['HUMAN'] += human_points
 
330
 
331
- continent_probs, country_probs, _,_ = predict(input_img)
332
  model_country = max(country_probs, key=country_probs.get)
333
  model_continent = max(continent_probs, key=continent_probs.get)
334
  if model_country == versus_state["country"]:
@@ -341,11 +467,16 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
341
  model_points += 1
342
  else:
343
  model_continent_result = "❌"
344
- model_score_update = f"+{model_points} points" if model_points > 0 else "0 Points... The model was completely wrong, it seems the world is not doomed yet."
345
- versus_state['score']['AI'] += model_points
 
 
 
 
346
 
347
  map = make_versus_map(human_country, model_country, versus_state)
348
- return f"""
 
349
  ## {human_result}
350
  ### The AI πŸ€– thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
351
 
@@ -353,7 +484,12 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
353
  πŸ€– {model_score_update}
354
 
355
  ### Score πŸ§‘ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} πŸ€–
356
- """, continent_probs, country_probs, map, versus_state
 
 
 
 
 
357
 
358
 
359
  def get_example_images(dir):
@@ -393,45 +529,65 @@ for img_path in example_images:
393
 
394
  demo = gr.Blocks(title="Thesis Demo")
395
  with demo:
396
- gr.HTML("""
 
397
  <h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
398
 
399
  <h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
400
  <p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example image the ground truth is available and will be displayed together with the model predictions.</p>
401
  <p>In the <b>"Versus Mode"</b> tab to play against the AI, guessing the country and continent where images where taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?
402
 
403
- """)
404
- with gr.Accordion(label="The demo currently encompasses 116 countries from 6 continents 🌍", open=False):
405
- gr.Code(json.dumps(countries_per_continent, indent=2, ensure_ascii=False), label="countries_per_continent.json", language="json", interactive=False)
 
 
 
 
 
 
 
 
 
406
  with gr.Tab("Image Geolocation Demo"):
407
  with gr.Row():
408
  with gr.Column():
409
- image = gr.Image(label="Image", type="pil",
410
- sources=["upload", "clipboard"])
 
411
  predict_btn = gr.Button("Predict")
412
  example_images = get_example_images("kerger-test-images")
413
  # example_images.extend(get_example_images("versus_images"))
414
- gr.Examples(examples=example_images,
415
- inputs=image, examples_per_page=24)
416
  with gr.Column():
417
  with gr.Accordion(visible=False) as metadata_block:
418
  map = gr.Plot(label="Locations")
419
  with gr.Group():
420
  continents_label = gr.Label(label="Continents")
421
- country_label = gr.Label(
422
- num_top_classes=5, label="Top countries")
423
- predict_btn.click(predict, inputs=image, outputs=[
424
- continents_label, country_label, metadata_block, map])
 
 
425
 
426
  with gr.Tab("Versus Mode"):
427
  versus_state = gr.State(value=INITAL_VERSUS_STATE)
428
  with gr.Row():
429
  with gr.Column():
430
- versus_image = gr.Image(
431
- INITAL_VERSUS_STATE["image"], interactive=False)
432
  continent_selection = gr.Radio(
433
- continents, label="Continents", info="Where was this image taken? (1 Point)")
434
- country_selection = gr.Dropdown(countries, label="Countries", info="Can you guess the exact country? (2 Points)"),
 
 
 
 
 
 
 
 
 
435
  with gr.Row():
436
  next_img_btn = gr.Button("Try new image")
437
  versus_btn = gr.Button("Submit guess")
@@ -443,11 +599,28 @@ with demo:
443
  with gr.Group():
444
  continents_label = gr.Label(label="Continents")
445
  country_label = gr.Label(
446
- num_top_classes=5, label="Top countries")
447
- next_img_btn.click(next_versus_image, inputs=[versus_state], outputs=[
448
- versus_image, versus_state, continent_selection, country_selection[0]])
449
- versus_btn.click(versus_mode_inputs, inputs=[versus_image, continent_selection, country_selection[0], versus_state], outputs=[
450
- versus_output, continents_label, country_label, map, versus_state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
 
453
  if __name__ == "__main__":
 
18
  print(f"count={torch.cuda.device_count()}")
19
  print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
20
 
21
+ continent_model = CLIPModel.from_pretrained(
22
+ "jrheiner/thesis-clip-geoloc-continent",
23
+ token=os.getenv("token"),
24
+ )
25
+ country_model = CLIPModel.from_pretrained(
26
+ "jrheiner/thesis-clip-geoloc-country",
27
+ token=os.getenv("token"),
28
+ )
29
+ processor = CLIPProcessor.from_pretrained(
30
+ "jrheiner/thesis-clip-geoloc-continent",
31
+ token=os.getenv("token"),
32
+ )
33
  continent_model = continent_model.to(device)
34
  country_model = country_model.to(device)
35
 
36
 
37
+ continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]
 
38
  countries_per_continent = {
39
  "Africa": [
40
+ "Botswana",
41
+ "Eswatini",
42
+ "Ghana",
43
+ "Kenya",
44
+ "Lesotho",
45
+ "Nigeria",
46
+ "Senegal",
47
+ "South Africa",
48
+ "Rwanda",
49
+ "Uganda",
50
+ "Tanzania",
51
+ "Madagascar",
52
+ "Djibouti",
53
+ "Mali",
54
+ "Libya",
55
+ "Morocco",
56
+ "Somalia",
57
+ "Tunisia",
58
+ "Egypt",
59
+ "RΓ©union",
60
  ],
61
  "Asia": [
62
+ "Bangladesh",
63
+ "Bhutan",
64
+ "Cambodia",
65
+ "China",
66
+ "India",
67
+ "Indonesia",
68
+ "Israel",
69
+ "Japan",
70
+ "Jordan",
71
+ "Kyrgyzstan",
72
+ "Laos",
73
+ "Malaysia",
74
+ "Mongolia",
75
+ "Nepal",
76
+ "Palestine",
77
+ "Philippines",
78
+ "Singapore",
79
+ "South Korea",
80
+ "Sri Lanka",
81
+ "Taiwan",
82
+ "Thailand",
83
+ "United Arab Emirates",
84
+ "Vietnam",
85
+ "Afghanistan",
86
+ "Azerbaijan",
87
+ "Cyprus",
88
+ "Iran",
89
+ "Syria",
90
+ "Tajikistan",
91
+ "Turkey",
92
+ "Russia",
93
+ "Pakistan",
94
+ "Hong Kong",
95
  ],
96
  "Europe": [
97
+ "Albania",
98
+ "Andorra",
99
+ "Austria",
100
+ "Belgium",
101
+ "Bulgaria",
102
+ "Croatia",
103
+ "Czechia",
104
+ "Denmark",
105
+ "Estonia",
106
+ "Finland",
107
+ "France",
108
+ "Germany",
109
+ "Greece",
110
+ "Hungary",
111
+ "Iceland",
112
+ "Ireland",
113
+ "Italy",
114
+ "Latvia",
115
+ "Lithuania",
116
+ "Luxembourg",
117
+ "Montenegro",
118
+ "Netherlands",
119
+ "North Macedonia",
120
+ "Norway",
121
+ "Poland",
122
+ "Portugal",
123
+ "Romania",
124
+ "Russia",
125
+ "Serbia",
126
+ "Slovakia",
127
+ "Slovenia",
128
+ "Spain",
129
+ "Sweden",
130
+ "Switzerland",
131
+ "Ukraine",
132
+ "United Kingdom",
133
+ "Bosnia and Herzegovina",
134
+ "Cyprus",
135
+ "Turkey",
136
+ "Greenland",
137
+ "Faroe Islands",
138
  ],
139
  "North America": [
140
+ "Canada",
141
+ "Dominican Republic",
142
+ "Guatemala",
143
+ "Mexico",
144
+ "United States",
145
+ "Bahamas",
146
+ "Cuba",
147
+ "Panama",
148
+ "Puerto Rico",
149
+ "Bermuda",
150
+ "Greenland",
151
  ],
152
  "Oceania": [
153
+ "Australia",
154
+ "New Zealand",
155
+ "Fiji",
156
+ "Papua New Guinea",
157
+ "Solomon Islands",
158
+ "Vanuatu",
159
  ],
160
  "South America": [
161
+ "Argentina",
162
+ "Bolivia",
163
+ "Brazil",
164
+ "Chile",
165
+ "Colombia",
166
+ "Ecuador",
167
+ "Paraguay",
168
+ "Peru",
169
+ "Uruguay",
170
+ ],
171
  }
172
+ countries = list(set(itertools.chain.from_iterable(countries_per_continent.values())))
 
173
 
174
  country_to_center_coords = {
175
  "Indonesia": (-2.4833826, 117.8902853),
 
287
  "Djibouti": (11.8145966, 42.8453061),
288
  "Senegal": (14.4750607, -14.4529612),
289
  "Bermuda": (32.3040273, -64.7563086),
290
+ "United States": (39.7837304, -100.445882),
291
  }
292
 
293
  INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
 
297
  "country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
298
  "lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
299
  "lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
300
+ "score": {"HUMAN": 0, "AI": 0},
301
+ "idx": 0,
 
 
 
302
  }
303
 
304
 
305
  def predict(input_img):
306
+ inputs = processor(
307
+ text=[f"A photo from {geo}." for geo in continents],
308
+ images=input_img,
309
+ return_tensors="pt",
310
+ padding=True,
311
+ )
312
  inputs = inputs.to(device)
313
  with torch.no_grad():
314
  outputs = continent_model(**inputs)
315
  logits_per_image = outputs.logits_per_image
316
  probs = logits_per_image.softmax(dim=-1)
317
  pred_id = probs.argmax().cpu().item()
318
+ continent_probs = {
319
+ label: prob for label, prob in zip(continents, probs.tolist()[0])
320
+ }
321
  model_continent = continents[pred_id]
322
  predicted_continent_countries = countries_per_continent[model_continent]
323
+ inputs = processor(
324
+ text=[f"A photo from {geo}." for geo in predicted_continent_countries],
325
+ images=input_img,
326
+ return_tensors="pt",
327
+ padding=True,
328
+ )
329
  inputs = inputs.to(device)
330
  with torch.no_grad():
331
  outputs = country_model(**inputs)
 
333
  probs = logits_per_image.softmax(dim=-1)
334
  pred_id = probs.argmax().cpu().item()
335
  model_country = predicted_continent_countries[pred_id]
336
+ country_probs = {
337
+ label: prob for label, prob in zip(predicted_continent_countries, probs.tolist()[0])
338
+ }
339
+
340
  hash = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
341
  metadata_block = gr.Accordion(visible=False)
342
  metadata_map = None
343
  if hash in EXAMPLE_METADATA.keys():
344
  model_result = ""
345
+ if (
346
+ model_continent == EXAMPLE_METADATA[hash]["continent"]
347
+ and model_country == EXAMPLE_METADATA[hash]["country"]
348
+ ):
349
  model_result = "The AI πŸ€– correctly guessed continent and country βœ… βœ…."
350
+ elif model_continent == EXAMPLE_METADATA[hash]["continent"]:
351
  model_result = "The AI πŸ€– only guessed the correct continent ❌ βœ…."
352
+ elif (
353
+ model_country == EXAMPLE_METADATA[hash]["country"]
354
+ and model_continent != EXAMPLE_METADATA[hash]["continent"]
355
+ ):
356
  model_result = "The AI πŸ€– only guessed the correct country βœ… ❌."
357
  else:
358
  model_result = "The AI πŸ€– failed to guess country and continent ❌ ❌."
359
+ metadata_block = gr.Accordion(
360
+ visible=True,
361
+ label=f"This photo was taken in {EXAMPLE_METADATA[hash]['country']}, {EXAMPLE_METADATA[hash]['continent']}.\n{model_result}",
362
+ )
363
  metadata_map = make_versus_map(None, model_country, EXAMPLE_METADATA[hash])
364
  return continent_probs, country_probs, metadata_block, metadata_map
365
 
366
+
367
  def make_versus_map(human_country, model_country, versus_state):
368
  if human_country:
369
  human_coordinates = country_to_center_coords[human_country]
 
371
  human_coordinates = (None, None)
372
  model_coordinates = country_to_center_coords[model_country]
373
  fig = go.Figure()
374
+ fig.add_trace(
375
+ go.Scattermapbox(
376
+ lon=[versus_state["lon"]],
377
+ lat=[versus_state["lat"]],
378
+ text=[f"πŸ“· Photo taken in {versus_state['country']}, {versus_state['continent']}"],
379
+ mode="markers",
380
+ hoverinfo="text",
381
+ marker=dict(size=14, color="#0C5DA5"),
 
 
 
 
 
 
 
 
 
 
 
382
  showlegend=True,
383
+ name="πŸ“· Photo Location",
384
+ )
385
+ )
386
+ if human_country == model_country:
387
+ fig.add_trace(
388
+ go.Scattermapbox(
389
+ lat=[human_coordinates[0], model_coordinates[0]],
390
+ lon=[human_coordinates[1], model_coordinates[1]],
391
+ text=f"πŸ§‘ πŸ€– Human & AI guess {human_country}",
392
+ mode="markers",
393
+ hoverinfo="text",
394
+ marker=dict(size=14, color="#FF9500"),
395
+ showlegend=True,
396
+ name="πŸ§‘ πŸ€– Human & AI Guess",
397
+ )
398
+ )
399
  else:
400
  if human_country:
401
+ fig.add_trace(
402
+ go.Scattermapbox(
403
+ lat=[human_coordinates[0]],
404
+ lon=[human_coordinates[1]],
405
+ text=[f"πŸ§‘ Human guesses {human_country}"],
406
+ mode="markers",
407
+ hoverinfo="text",
408
+ marker=dict(size=14, color="#FF9500"),
409
+ showlegend=True,
410
+ name="πŸ§‘ Human Guess",
411
+ )
412
+ )
413
+ fig.add_trace(
414
+ go.Scattermapbox(
415
+ lat=[model_coordinates[0]],
416
+ lon=[model_coordinates[1]],
417
+ text=[f"πŸ€– AI guesses {model_country}"],
418
+ mode="markers",
419
+ hoverinfo="text",
420
+ marker=dict(size=14, color="#474747"),
421
  showlegend=True,
422
+ name="πŸ€– AI Guess",
423
+ )
424
+ )
425
+
 
 
 
 
 
 
 
 
 
426
  fig.update_layout(
427
  mapbox=dict(
428
  style="carto-positron",
429
  center=dict(lat=float(versus_state["lat"]), lon=float(versus_state["lon"])),
430
+ zoom=2,
431
  ),
432
  margin={"r": 0, "t": 0, "l": 0, "b": 0},
433
+ legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
 
 
 
 
 
434
  )
435
  return fig
436
 
 
448
  human_points += 1
449
  else:
450
  continent_result = "❌"
451
+ human_result = f"The photo is from **{versus_state['country']}** {country_result} in **{versus_state['continent']}** {continent_result}"
452
+ human_score_update = (
453
+ f"+{human_points} points" if human_points > 0 else "0 Points..."
454
+ )
455
+ versus_state["score"]["HUMAN"] += human_points
456
 
457
+ continent_probs, country_probs, _, _ = predict(input_img)
458
  model_country = max(country_probs, key=country_probs.get)
459
  model_continent = max(continent_probs, key=continent_probs.get)
460
  if model_country == versus_state["country"]:
 
467
  model_points += 1
468
  else:
469
  model_continent_result = "❌"
470
+ model_score_update = (
471
+ f"+{model_points} points"
472
+ if model_points > 0
473
+ else "0 Points... The model was completely wrong, it seems the world is not doomed yet."
474
+ )
475
+ versus_state["score"]["AI"] += model_points
476
 
477
  map = make_versus_map(human_country, model_country, versus_state)
478
+ return (
479
+ f"""
480
  ## {human_result}
481
  ### The AI πŸ€– thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
482
 
 
484
  πŸ€– {model_score_update}
485
 
486
  ### Score πŸ§‘ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} πŸ€–
487
+ """,
488
+ continent_probs,
489
+ country_probs,
490
+ map,
491
+ versus_state,
492
+ )
493
 
494
 
495
  def get_example_images(dir):
 
529
 
530
  demo = gr.Blocks(title="Thesis Demo")
531
  with demo:
532
+ gr.HTML(
533
+ """
534
  <h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
535
 
536
  <h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
537
  <p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example image the ground truth is available and will be displayed together with the model predictions.</p>
538
  <p>In the <b>"Versus Mode"</b> tab to play against the AI, guessing the country and continent where images where taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?
539
 
540
+ """
541
+ )
542
+ with gr.Accordion(
543
+ label="The demo currently encompasses 116 countries from 6 continents 🌍",
544
+ open=False,
545
+ ):
546
+ gr.Code(
547
+ json.dumps(countries_per_continent, indent=2, ensure_ascii=False),
548
+ label="countries_per_continent.json",
549
+ language="json",
550
+ interactive=False,
551
+ )
552
  with gr.Tab("Image Geolocation Demo"):
553
  with gr.Row():
554
  with gr.Column():
555
+ image = gr.Image(
556
+ label="Image", type="pil", sources=["upload", "clipboard"]
557
+ )
558
  predict_btn = gr.Button("Predict")
559
  example_images = get_example_images("kerger-test-images")
560
  # example_images.extend(get_example_images("versus_images"))
561
+ gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
 
562
  with gr.Column():
563
  with gr.Accordion(visible=False) as metadata_block:
564
  map = gr.Plot(label="Locations")
565
  with gr.Group():
566
  continents_label = gr.Label(label="Continents")
567
+ country_label = gr.Label(num_top_classes=5, label="Top countries")
568
+ predict_btn.click(
569
+ predict,
570
+ inputs=image,
571
+ outputs=[continents_label, country_label, metadata_block, map],
572
+ )
573
 
574
  with gr.Tab("Versus Mode"):
575
  versus_state = gr.State(value=INITAL_VERSUS_STATE)
576
  with gr.Row():
577
  with gr.Column():
578
+ versus_image = gr.Image(INITAL_VERSUS_STATE["image"], interactive=False)
 
579
  continent_selection = gr.Radio(
580
+ continents,
581
+ label="Continents",
582
+ info="Where was this image taken? (1 Point)",
583
+ )
584
+ country_selection = (
585
+ gr.Dropdown(
586
+ countries,
587
+ label="Countries",
588
+ info="Can you guess the exact country? (2 Points)",
589
+ ),
590
+ )
591
  with gr.Row():
592
  next_img_btn = gr.Button("Try new image")
593
  versus_btn = gr.Button("Submit guess")
 
599
  with gr.Group():
600
  continents_label = gr.Label(label="Continents")
601
  country_label = gr.Label(
602
+ num_top_classes=5, label="Top countries"
603
+ )
604
+ next_img_btn.click(
605
+ next_versus_image,
606
+ inputs=[versus_state],
607
+ outputs=[
608
+ versus_image,
609
+ versus_state,
610
+ continent_selection,
611
+ country_selection[0],
612
+ ],
613
+ )
614
+ versus_btn.click(
615
+ versus_mode_inputs,
616
+ inputs=[
617
+ versus_image,
618
+ continent_selection,
619
+ country_selection[0],
620
+ versus_state,
621
+ ],
622
+ outputs=[versus_output, continents_label, country_label, map, versus_state],
623
+ )
624
 
625
 
626
  if __name__ == "__main__":