Jonas Rheiner commited on
Commit
4b77aea
·
1 Parent(s): 710c658
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +206 -64
  2. dataset_examples/africa/3962011747224020_1024.jpg +0 -0
  3. dataset_examples/africa/3973126792769679_1024.jpg +0 -0
  4. dataset_examples/africa/4471109009586514_1024.jpg +0 -0
  5. dataset_examples/asia/106261888221766_1024.jpg +0 -0
  6. dataset_examples/asia/138321512570044_1024.jpg +0 -0
  7. dataset_examples/asia/147206360658971_1024.jpg +0 -0
  8. dataset_examples/europe/1423684677989158_1024.jpg +0 -0
  9. dataset_examples/europe/1483506425323136_1024.jpg +0 -0
  10. dataset_examples/europe/150428493699453_1024.jpg +0 -0
  11. dataset_examples/north america/1000276020376482_1024.jpg +0 -0
  12. dataset_examples/north america/757125001639938_1024.jpg +0 -0
  13. dataset_examples/north america/843371313196684_1024.jpg +0 -0
  14. dataset_examples/oceania/100141636075718_1024.jpg +0 -0
  15. dataset_examples/oceania/1899604010244250_1024.jpg +0 -0
  16. dataset_examples/oceania/821397982117703_1024.jpg +0 -0
  17. dataset_examples/south america/103386512798674_1024.jpg +0 -0
  18. dataset_examples/south america/1677652415776082_1024.jpg +0 -0
  19. dataset_examples/south america/327973242016483_1024.jpg +0 -0
  20. examples/1000276020376482_1024.jpg +0 -0
  21. examples/100141636075718_1024.jpg +0 -0
  22. examples/103386512798674_1024.jpg +0 -0
  23. examples/106261888221766_1024.jpg +0 -0
  24. examples/138321512570044_1024.jpg +0 -0
  25. examples/1423684677989158_1024.jpg +0 -0
  26. examples/147206360658971_1024.jpg +0 -0
  27. examples/1483506425323136_1024.jpg +0 -0
  28. examples/150428493699453_1024.jpg +0 -0
  29. examples/1677652415776082_1024.jpg +0 -0
  30. examples/1899604010244250_1024.jpg +0 -0
  31. examples/327973242016483_1024.jpg +0 -0
  32. examples/3962011747224020_1024.jpg +0 -0
  33. examples/3973126792769679_1024.jpg +0 -0
  34. examples/4471109009586514_1024.jpg +0 -0
  35. examples/757125001639938_1024.jpg +0 -0
  36. examples/821397982117703_1024.jpg +0 -0
  37. examples/843371313196684_1024.jpg +0 -0
  38. kerger-test-images/Africa_Botswana_-24.358520377382_23.5184910801.jpg +0 -0
  39. kerger-test-images/Africa_Kenya_-0.21870999999999_37.023791.jpg +0 -0
  40. kerger-test-images/Africa_Madagascar_-16.078452454738_46.73369803641.jpg +0 -0
  41. kerger-test-images/Africa_South Africa_-23.590135077274_28.785944164821.jpg +0 -0
  42. kerger-test-images/Africa_Tanzania_-3.3676537025657_36.716512872377.jpg +0 -0
  43. kerger-test-images/Africa_Uganda_1.1212866787272_33.915204986261.jpg +0 -0
  44. kerger-test-images/Asia_Israel_31.708865303742_34.94966916063.jpg +0 -0
  45. kerger-test-images/Asia_Japan_35.381304970616_134.65860211972.jpg +0 -0
  46. kerger-test-images/Asia_Pakistan_24.910493840503_69.506229024537.jpg +0 -0
  47. kerger-test-images/Asia_Russia_54.597757883015_48.163689656865.jpg +0 -0
  48. kerger-test-images/Asia_Russia_56.018311493214_38.359778952407.jpg +0 -0
  49. kerger-test-images/Asia_Russia_60.27835356798_29.754665851696.jpg +0 -0
  50. kerger-test-images/Asia_Thailand_19.824843951089_99.694080339609.jpg +0 -0
app.py CHANGED
@@ -4,92 +4,208 @@ from transformers import CLIPProcessor, CLIPModel
4
  import torch
5
  import itertools
6
  import os
 
7
 
8
- model = CLIPModel.from_pretrained("model-checkpoint")
 
 
 
 
 
 
 
 
9
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
 
 
 
10
 
11
- continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]
 
12
  countries_per_continent = {
13
  "Africa": [
14
- "Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
15
- "Central African Republic", "Chad", "Comoros", "Congo", "Democratic Republic of the Congo",
16
- "Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
17
- "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
18
- "Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
19
- "Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
20
- "Sierra Leone", "Somalia", "South Africa", "South Sudan", "Sudan", "Tanzania", "Togo",
21
  "Tunisia", "Uganda", "Zambia", "Zimbabwe"
22
  ],
23
  "Asia": [
24
- "Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
25
- "Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
26
- "Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
27
- "Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
28
- "Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
29
- "Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
30
  "Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
31
  ],
32
  "Europe": [
33
- "Albania", "Andorra", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
34
- "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
35
- "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
36
- "Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
37
- "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
38
- "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
39
- "Turkey", "Ukraine", "United Kingdom", "Vatican City"
40
  ],
41
  "North America": [
42
- "Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
43
- "Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
44
- "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
45
  "Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
46
  ],
47
  "Oceania": [
48
- "Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
49
  "Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
50
  ],
51
  "South America": [
52
- "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
53
  "Peru", "Suriname", "Uruguay", "Venezuela"
54
  ]
55
  }
56
- countries = list(set(itertools.chain.from_iterable(countries_per_continent.values())))
 
57
 
58
- VERSUS_IMAGE = "versus_images/254762082_1a1b6d27d1_121_79267922@N00.jpg"
59
- VERSUS_GT = {
60
- "continent": "test",
61
- "country": "test"
 
 
 
 
 
 
 
 
62
  }
63
 
 
64
  def predict(input_img):
65
- inputs = processor(text=[f"A photo from {geo}." for geo in continents], images=input_img, return_tensors="pt", padding=True)
 
 
66
  with torch.no_grad():
67
- outputs = model(**inputs)
68
  logits_per_image = outputs.logits_per_image
69
  probs = logits_per_image.softmax(dim=-1)
70
  pred_id = probs.argmax().cpu().item()
71
- continent_probs = {label: prob for label, prob in zip(continents, probs.tolist()[0])}
72
- print(continent_probs)
73
-
74
  predicted_continent_countries = countries_per_continent[continents[pred_id]]
75
- inputs = processor(text=[f"A photo from {geo}." for geo in predicted_continent_countries], images=input_img, return_tensors="pt", padding=True)
 
 
76
  with torch.no_grad():
77
- outputs = model(**inputs)
78
  logits_per_image = outputs.logits_per_image
79
  probs = logits_per_image.softmax(dim=-1)
80
- country_probs = {label: prob for label, prob in zip(predicted_continent_countries, probs.tolist()[0])}
81
- print(country_probs)
82
  return continent_probs, country_probs
83
 
84
- def versus_mode_inputs(input_img, human_continent, human_country):
85
- print(human_continent, human_country)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  continent_probs, country_probs = predict(input_img)
87
- return f"human guessed {human_continent} {human_country}\ngroung truth {VERSUS_GT}", continent_probs, country_probs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- def next_versus_image():
90
- # VERSUS_GT["continent"] = "test"
91
- # VERSUS_GT["country"] = "test"
92
- return "versus_images/[email protected]"
93
 
94
  def get_example_images(dir):
95
  image_extensions = (".jpg", ".jpeg", ".png")
@@ -100,39 +216,65 @@ def get_example_images(dir):
100
  image_files.append(os.path.join(root, file))
101
  return image_files
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  demo = gr.Blocks()
104
  with demo:
105
  with gr.Tab("Image Geolocation Demo"):
106
  with gr.Row():
107
  with gr.Column():
108
- image = gr.Image(label="Image", type="pil", sources=["upload","clipboard"])
 
109
  predict_btn = gr.Button("Predict")
110
- example_images = get_example_images("examples")
111
- example_images.extend(get_example_images("versus_images"))
112
- gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
 
113
  with gr.Column():
114
  continents_label = gr.Label(label="Continents")
115
- country_label = gr.Label(num_top_classes=5, label="Top countries")
 
116
  # continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
117
- predict_btn.click(predict, inputs=image, outputs=[continents_label, country_label])
 
118
 
119
  with gr.Tab("Versus Mode"):
 
120
  with gr.Row():
121
  with gr.Column():
122
- versus_image = gr.Image(VERSUS_IMAGE, interactive=False)
123
- continent_selection = gr.Radio(continents, label="Continents", info="Where was this image taken?")
124
- country_selection = gr.Dropdown(countries, label="Countries", info="Can you guess the exact country?"
125
- ),
 
 
126
  with gr.Row():
127
  next_img_btn = gr.Button("Try new image")
128
  versus_btn = gr.Button("Submit guess")
129
  with gr.Column():
130
- versus_output = gr.Text()
131
- continents_label = gr.Label(label="Continents")
132
- country_label = gr.Label(num_top_classes=5, label="Top countries")
133
- next_img_btn.click(next_versus_image, outputs=versus_image)
134
- versus_btn.click(versus_mode_inputs, inputs=[versus_image, continent_selection, country_selection[0]], outputs=[versus_output, continents_label, country_label])
 
 
 
 
 
135
 
136
 
137
  if __name__ == "__main__":
138
- demo.launch()
 
4
  import torch
5
  import itertools
6
  import os
7
+ import plotly.graph_objects as go
8
 
9
+
10
+ CUDA_AVAILABLE = torch.cuda.is_available()
11
+ print(f"CUDA={CUDA_AVAILABLE}")
12
+ device = "cuda" if CUDA_AVAILABLE else "cpu"
13
+ print(f"count={torch.cuda.device_count()}")
14
+ print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
15
+
16
+ continent_model = CLIPModel.from_pretrained("model-checkpoints/continent")
17
+ country_model = CLIPModel.from_pretrained("model-checkpoints/country")
18
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
19
+ continent_model = continent_model.to(device)
20
+ country_model = country_model.to(device)
21
+
22
 
23
+ continents = ["Africa", "Asia", "Europe",
24
+ "North America", "Oceania", "South America"]
25
  countries_per_continent = {
26
  "Africa": [
27
+ "Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
28
+ "Central African Republic", "Congo", "Democratic Republic of the Congo",
29
+ "Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
30
+ "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
31
+ "Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
32
+ "Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
33
+ "Sierra Leone", "Somalia", "South Africa", "Sudan", "Tanzania", "Togo",
34
  "Tunisia", "Uganda", "Zambia", "Zimbabwe"
35
  ],
36
  "Asia": [
37
+ "Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
38
+ "Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
39
+ "Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
40
+ "Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
41
+ "Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
42
+ "Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
43
  "Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
44
  ],
45
  "Europe": [
46
+ "Albania", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
47
+ "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
48
+ "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
49
+ "Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
50
+ "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
51
+ "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
52
+ "Turkey", "Ukraine", "United Kingdom"
53
  ],
54
  "North America": [
55
+ "Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
56
+ "Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
57
+ "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
58
  "Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
59
  ],
60
  "Oceania": [
61
+ "Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
62
  "Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
63
  ],
64
  "South America": [
65
+ "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
66
  "Peru", "Suriname", "Uruguay", "Venezuela"
67
  ]
68
  }
69
+ countries = list(set(itertools.chain.from_iterable(
70
+ countries_per_continent.values())))
71
 
72
+ INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
73
+ INITAL_VERSUS_STATE = {
74
+ "image": INTIAL_VERSUS_IMAGE,
75
+ "continent": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[0],
76
+ "country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
77
+ "lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
78
+ "lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
79
+ "score": {
80
+ "HUMAN": 0,
81
+ "AI": 0
82
+ },
83
+ "idx": 0
84
  }
85
 
86
+
87
  def predict(input_img):
88
+ inputs = processor(text=[f"A photo from {
89
+ geo}." for geo in continents], images=input_img, return_tensors="pt", padding=True)
90
+ inputs = inputs.to(device)
91
  with torch.no_grad():
92
+ outputs = continent_model(**inputs)
93
  logits_per_image = outputs.logits_per_image
94
  probs = logits_per_image.softmax(dim=-1)
95
  pred_id = probs.argmax().cpu().item()
96
+ continent_probs = {label: prob for label,
97
+ prob in zip(continents, probs.tolist()[0])}
98
+
99
  predicted_continent_countries = countries_per_continent[continents[pred_id]]
100
+ inputs = processor(text=[f"A photo from {
101
+ geo}." for geo in predicted_continent_countries], images=input_img, return_tensors="pt", padding=True)
102
+ inputs = inputs.to(device)
103
  with torch.no_grad():
104
+ outputs = country_model(**inputs)
105
  logits_per_image = outputs.logits_per_image
106
  probs = logits_per_image.softmax(dim=-1)
107
+ country_probs = {label: prob for label, prob in zip(
108
+ predicted_continent_countries, probs.tolist()[0])}
109
  return continent_probs, country_probs
110
 
111
+
112
+ def make_versus_map(human_country, model_country, versus_state):
113
+ fig = go.Figure()
114
+ fig.add_trace(go.Scattergeo(
115
+ lon=[versus_state["lon"]],
116
+ lat=[versus_state["lat"]],
117
+ text=["📷"],
118
+ mode='text+markers',
119
+ hoverinfo='text',
120
+ hovertext=f"Photo taken in {versus_state['country']}, {
121
+ versus_state['continent']}",
122
+ marker=dict(size=14, color='#00B945'),
123
+ showlegend=False
124
+
125
+ ))
126
+ if human_country == model_country:
127
+ fig.add_trace(go.Scattergeo(
128
+ locations=[human_country],
129
+ locationmode='country names',
130
+ text=["🧑🤖"],
131
+ mode='text',
132
+ hoverinfo='location',
133
+ showlegend=False
134
+
135
+ ))
136
+ else:
137
+ fig.add_trace(go.Scattergeo(
138
+ locations=[human_country],
139
+ locationmode='country names',
140
+ text=["🧑"],
141
+ mode='text',
142
+ hoverinfo='location',
143
+ showlegend=False
144
+
145
+ ))
146
+ fig.add_trace(go.Scattergeo(
147
+ locations=[model_country],
148
+ locationmode='country names',
149
+ text=["🤖"],
150
+ mode='text',
151
+ hoverinfo='location',
152
+ showlegend=False
153
+
154
+ ))
155
+ fig.update_geos(
156
+ visible=True, resolution=110,
157
+ showcountries=True, countrycolor="grey", fitbounds="locations", projection_type="natural earth",
158
+ )
159
+ return fig
160
+
161
+
162
+ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
163
+ human_points = 0
164
+ model_points = 0
165
+ if human_country == versus_state["country"]:
166
+ country_result = "✅"
167
+ human_points += 2
168
+ else:
169
+ country_result = "❌"
170
+ if human_continent == versus_state["continent"]:
171
+ continent_result = "✅"
172
+ human_points += 1
173
+ else:
174
+ continent_result = "❌"
175
+ human_result = f"The photo is from **{versus_state['country']}** {
176
+ country_result} in **{versus_state['continent']}** {continent_result}"
177
+ human_score_update = f"+{
178
+ human_points} points" if human_points > 0 else "Wrong guess, try a new image."
179
+ versus_state['score']['HUMAN'] += human_points
180
+
181
  continent_probs, country_probs = predict(input_img)
182
+ model_country = max(country_probs, key=country_probs.get)
183
+ model_continent = max(continent_probs, key=continent_probs.get)
184
+ if model_country == versus_state["country"]:
185
+ model_country_result = "✅"
186
+ model_points += 2
187
+ else:
188
+ model_country_result = "❌"
189
+ if model_continent == versus_state["continent"]:
190
+ model_continent_result = "✅"
191
+ model_points += 1
192
+ else:
193
+ model_continent_result = "❌"
194
+ model_score_update = f"+{
195
+ model_points} points" if model_points > 0 else "The model was wrong, seems the world is not yet doomed."
196
+ versus_state['score']['AI'] += model_points
197
+
198
+ map = make_versus_map(human_country, model_country, versus_state)
199
+ return f"""
200
+ ## {human_result}
201
+ ### The AI 🤖 thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
202
+
203
+ 🧑 {human_score_update}
204
+ 🤖 {model_score_update}
205
+
206
+ ### Score 🧑 {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} 🤖
207
+ """, continent_probs, country_probs, map, versus_state
208
 
 
 
 
 
209
 
210
  def get_example_images(dir):
211
  image_extensions = (".jpg", ".jpeg", ".png")
 
216
  image_files.append(os.path.join(root, file))
217
  return image_files
218
 
219
+
220
+ def next_versus_image(versus_state):
221
+ images = get_example_images("versus_images")
222
+ versus_state["idx"] += 1
223
+ if versus_state["idx"] > len(images):
224
+ versus_state["idx"] = 0
225
+ versus_image = images[versus_state["idx"]]
226
+ versus_state["continent"] = versus_image.split("/")[-1].split("_")[0]
227
+ versus_state["country"] = versus_image.split("/")[-1].split("_")[1]
228
+ versus_state["lat"] = versus_image.split("/")[-1].split("_")[2]
229
+ versus_state["lon"] = versus_image.split("/")[-1].split("_")[3]
230
+ versus_state["image"] = versus_image
231
+ return versus_image, versus_state, None, None
232
+
233
  demo = gr.Blocks()
234
  with demo:
235
  with gr.Tab("Image Geolocation Demo"):
236
  with gr.Row():
237
  with gr.Column():
238
+ image = gr.Image(label="Image", type="pil",
239
+ sources=["upload", "clipboard"])
240
  predict_btn = gr.Button("Predict")
241
+ example_images = get_example_images("kerger-test-images")
242
+ # example_images.extend(get_example_images("versus_images"))
243
+ gr.Examples(examples=example_images,
244
+ inputs=image, examples_per_page=24)
245
  with gr.Column():
246
  continents_label = gr.Label(label="Continents")
247
+ country_label = gr.Label(
248
+ num_top_classes=5, label="Top countries")
249
  # continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
250
+ predict_btn.click(predict, inputs=image, outputs=[
251
+ continents_label, country_label])
252
 
253
  with gr.Tab("Versus Mode"):
254
+ versus_state = gr.State(value=INITAL_VERSUS_STATE)
255
  with gr.Row():
256
  with gr.Column():
257
+ versus_image = gr.Image(
258
+ INITAL_VERSUS_STATE["image"], interactive=False)
259
+ continent_selection = gr.Radio(
260
+ continents, label="Continents", info="Where was this image taken? (1 Point)")
261
+ country_selection = gr.Dropdown(countries, label="Countries", info="Can you guess the exact country? (2 Points)"
262
+ ),
263
  with gr.Row():
264
  next_img_btn = gr.Button("Try new image")
265
  versus_btn = gr.Button("Submit guess")
266
  with gr.Column():
267
+ versus_output = gr.Markdown()
268
+ # with gr.Accordion("View Map", open=False):
269
+ map = gr.Plot(label="Locations")
270
+ with gr.Accordion("Full Model Output", open=False):
271
+ continents_label = gr.Label(label="Continents")
272
+ country_label = gr.Label(
273
+ num_top_classes=5, label="Top countries")
274
+ next_img_btn.click(next_versus_image, inputs=[versus_state], outputs=[versus_image, versus_state, continent_selection, country_selection[0]])
275
+ versus_btn.click(versus_mode_inputs, inputs=[versus_image, continent_selection, country_selection[0], versus_state], outputs=[
276
+ versus_output, continents_label, country_label, map, versus_state])
277
 
278
 
279
  if __name__ == "__main__":
280
+ demo.launch()
dataset_examples/africa/3962011747224020_1024.jpg DELETED
Binary file (145 kB)
 
dataset_examples/africa/3973126792769679_1024.jpg DELETED
Binary file (69.5 kB)
 
dataset_examples/africa/4471109009586514_1024.jpg DELETED
Binary file (147 kB)
 
dataset_examples/asia/106261888221766_1024.jpg DELETED
Binary file (131 kB)
 
dataset_examples/asia/138321512570044_1024.jpg DELETED
Binary file (78.1 kB)
 
dataset_examples/asia/147206360658971_1024.jpg DELETED
Binary file (96.2 kB)
 
dataset_examples/europe/1423684677989158_1024.jpg DELETED
Binary file (84.3 kB)
 
dataset_examples/europe/1483506425323136_1024.jpg DELETED
Binary file (46.8 kB)
 
dataset_examples/europe/150428493699453_1024.jpg DELETED
Binary file (140 kB)
 
dataset_examples/north america/1000276020376482_1024.jpg DELETED
Binary file (66 kB)
 
dataset_examples/north america/757125001639938_1024.jpg DELETED
Binary file (162 kB)
 
dataset_examples/north america/843371313196684_1024.jpg DELETED
Binary file (257 kB)
 
dataset_examples/oceania/100141636075718_1024.jpg DELETED
Binary file (84.2 kB)
 
dataset_examples/oceania/1899604010244250_1024.jpg DELETED
Binary file (46.2 kB)
 
dataset_examples/oceania/821397982117703_1024.jpg DELETED
Binary file (75.7 kB)
 
dataset_examples/south america/103386512798674_1024.jpg DELETED
Binary file (135 kB)
 
dataset_examples/south america/1677652415776082_1024.jpg DELETED
Binary file (94.5 kB)
 
dataset_examples/south america/327973242016483_1024.jpg DELETED
Binary file (92.1 kB)
 
examples/1000276020376482_1024.jpg DELETED
Binary file (66 kB)
 
examples/100141636075718_1024.jpg DELETED
Binary file (84.2 kB)
 
examples/103386512798674_1024.jpg DELETED
Binary file (135 kB)
 
examples/106261888221766_1024.jpg DELETED
Binary file (131 kB)
 
examples/138321512570044_1024.jpg DELETED
Binary file (78.1 kB)
 
examples/1423684677989158_1024.jpg DELETED
Binary file (84.3 kB)
 
examples/147206360658971_1024.jpg DELETED
Binary file (96.2 kB)
 
examples/1483506425323136_1024.jpg DELETED
Binary file (46.8 kB)
 
examples/150428493699453_1024.jpg DELETED
Binary file (140 kB)
 
examples/1677652415776082_1024.jpg DELETED
Binary file (94.5 kB)
 
examples/1899604010244250_1024.jpg DELETED
Binary file (46.2 kB)
 
examples/327973242016483_1024.jpg DELETED
Binary file (92.1 kB)
 
examples/3962011747224020_1024.jpg DELETED
Binary file (145 kB)
 
examples/3973126792769679_1024.jpg DELETED
Binary file (69.5 kB)
 
examples/4471109009586514_1024.jpg DELETED
Binary file (147 kB)
 
examples/757125001639938_1024.jpg DELETED
Binary file (162 kB)
 
examples/821397982117703_1024.jpg DELETED
Binary file (75.7 kB)
 
examples/843371313196684_1024.jpg DELETED
Binary file (257 kB)
 
kerger-test-images/Africa_Botswana_-24.358520377382_23.5184910801.jpg ADDED
kerger-test-images/Africa_Kenya_-0.21870999999999_37.023791.jpg ADDED
kerger-test-images/Africa_Madagascar_-16.078452454738_46.73369803641.jpg ADDED
kerger-test-images/Africa_South Africa_-23.590135077274_28.785944164821.jpg ADDED
kerger-test-images/Africa_Tanzania_-3.3676537025657_36.716512872377.jpg ADDED
kerger-test-images/Africa_Uganda_1.1212866787272_33.915204986261.jpg ADDED
kerger-test-images/Asia_Israel_31.708865303742_34.94966916063.jpg ADDED
kerger-test-images/Asia_Japan_35.381304970616_134.65860211972.jpg ADDED
kerger-test-images/Asia_Pakistan_24.910493840503_69.506229024537.jpg ADDED
kerger-test-images/Asia_Russia_54.597757883015_48.163689656865.jpg ADDED
kerger-test-images/Asia_Russia_56.018311493214_38.359778952407.jpg ADDED
kerger-test-images/Asia_Russia_60.27835356798_29.754665851696.jpg ADDED
kerger-test-images/Asia_Thailand_19.824843951089_99.694080339609.jpg ADDED