Jonas Rheiner commited on
Commit
0a2704b
·
1 Parent(s): 69171b5

initial commit

Browse files
Files changed (43) hide show
  1. .gitignore +3 -0
  2. app.py +138 -0
  3. dataset_examples/africa/3962011747224020_1024.jpg +0 -0
  4. dataset_examples/africa/3973126792769679_1024.jpg +0 -0
  5. dataset_examples/africa/4471109009586514_1024.jpg +0 -0
  6. dataset_examples/asia/106261888221766_1024.jpg +0 -0
  7. dataset_examples/asia/138321512570044_1024.jpg +0 -0
  8. dataset_examples/asia/147206360658971_1024.jpg +0 -0
  9. dataset_examples/europe/1423684677989158_1024.jpg +0 -0
  10. dataset_examples/europe/1483506425323136_1024.jpg +0 -0
  11. dataset_examples/europe/150428493699453_1024.jpg +0 -0
  12. dataset_examples/north america/1000276020376482_1024.jpg +0 -0
  13. dataset_examples/north america/757125001639938_1024.jpg +0 -0
  14. dataset_examples/north america/843371313196684_1024.jpg +0 -0
  15. dataset_examples/oceania/100141636075718_1024.jpg +0 -0
  16. dataset_examples/oceania/1899604010244250_1024.jpg +0 -0
  17. dataset_examples/oceania/821397982117703_1024.jpg +0 -0
  18. dataset_examples/south america/103386512798674_1024.jpg +0 -0
  19. dataset_examples/south america/1677652415776082_1024.jpg +0 -0
  20. dataset_examples/south america/327973242016483_1024.jpg +0 -0
  21. examples/1000276020376482_1024.jpg +0 -0
  22. examples/100141636075718_1024.jpg +0 -0
  23. examples/103386512798674_1024.jpg +0 -0
  24. examples/106261888221766_1024.jpg +0 -0
  25. examples/138321512570044_1024.jpg +0 -0
  26. examples/1423684677989158_1024.jpg +0 -0
  27. examples/147206360658971_1024.jpg +0 -0
  28. examples/1483506425323136_1024.jpg +0 -0
  29. examples/150428493699453_1024.jpg +0 -0
  30. examples/1677652415776082_1024.jpg +0 -0
  31. examples/1899604010244250_1024.jpg +0 -0
  32. examples/327973242016483_1024.jpg +0 -0
  33. examples/3962011747224020_1024.jpg +0 -0
  34. examples/3973126792769679_1024.jpg +0 -0
  35. examples/4471109009586514_1024.jpg +0 -0
  36. examples/757125001639938_1024.jpg +0 -0
  37. examples/821397982117703_1024.jpg +0 -0
  38. examples/843371313196684_1024.jpg +0 -0
  39. versus_images/[email protected] +0 -0
  40. versus_images/[email protected] +0 -0
  41. versus_images/[email protected] +0 -0
  42. versus_images/[email protected] +0 -0
  43. versus_images/[email protected] +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv/
2
+ model-checkpoint/
3
+ __pycache__/
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ import torch
5
+ import itertools
6
+ import os
7
+
8
+ model = CLIPModel.from_pretrained("model-checkpoint")
9
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
10
+
11
+ continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]
12
+ countries_per_continent = {
13
+ "Africa": [
14
+ "Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
15
+ "Central African Republic", "Chad", "Comoros", "Congo", "Democratic Republic of the Congo",
16
+ "Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
17
+ "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
18
+ "Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
19
+ "Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
20
+ "Sierra Leone", "Somalia", "South Africa", "South Sudan", "Sudan", "Tanzania", "Togo",
21
+ "Tunisia", "Uganda", "Zambia", "Zimbabwe"
22
+ ],
23
+ "Asia": [
24
+ "Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
25
+ "Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
26
+ "Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
27
+ "Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
28
+ "Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
29
+ "Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
30
+ "Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
31
+ ],
32
+ "Europe": [
33
+ "Albania", "Andorra", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
34
+ "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
35
+ "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
36
+ "Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
37
+ "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
38
+ "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
39
+ "Turkey", "Ukraine", "United Kingdom", "Vatican City"
40
+ ],
41
+ "North America": [
42
+ "Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
43
+ "Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
44
+ "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
45
+ "Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
46
+ ],
47
+ "Oceania": [
48
+ "Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
49
+ "Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
50
+ ],
51
+ "South America": [
52
+ "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
53
+ "Peru", "Suriname", "Uruguay", "Venezuela"
54
+ ]
55
+ }
56
+ countries = list(set(itertools.chain.from_iterable(countries_per_continent.values())))
57
+
58
+ VERSUS_IMAGE = "versus_images/[email protected]"
59
+ VERSUS_GT = {
60
+ "continent": "test",
61
+ "country": "test"
62
+ }
63
+
64
+ def predict(input_img):
65
+ inputs = processor(text=[f"A photo from {geo}." for geo in continents], images=input_img, return_tensors="pt", padding=True)
66
+ with torch.no_grad():
67
+ outputs = model(**inputs)
68
+ logits_per_image = outputs.logits_per_image
69
+ probs = logits_per_image.softmax(dim=-1)
70
+ pred_id = probs.argmax().cpu().item()
71
+ continent_probs = {label: prob for label, prob in zip(continents, probs.tolist()[0])}
72
+ print(continent_probs)
73
+
74
+ predicted_continent_countries = countries_per_continent[continents[pred_id]]
75
+ inputs = processor(text=[f"A photo from {geo}." for geo in predicted_continent_countries], images=input_img, return_tensors="pt", padding=True)
76
+ with torch.no_grad():
77
+ outputs = model(**inputs)
78
+ logits_per_image = outputs.logits_per_image
79
+ probs = logits_per_image.softmax(dim=-1)
80
+ country_probs = {label: prob for label, prob in zip(predicted_continent_countries, probs.tolist()[0])}
81
+ print(country_probs)
82
+ return continent_probs, country_probs
83
+
84
+ def versus_mode_inputs(input_img, human_continent, human_country):
85
+ print(human_continent, human_country)
86
+ continent_probs, country_probs = predict(input_img)
87
+ return f"human guessed {human_continent} {human_country}\ngroung truth {VERSUS_GT}", continent_probs, country_probs
88
+
89
+ def next_versus_image():
90
+ # VERSUS_GT["continent"] = "test"
91
+ # VERSUS_GT["country"] = "test"
92
+ return "versus_images/[email protected]"
93
+
94
+ def get_example_images(dir):
95
+ image_extensions = (".jpg", ".jpeg", ".png")
96
+ image_files = []
97
+ for root, dirs, files in os.walk(dir):
98
+ for file in files:
99
+ if file.lower().endswith(image_extensions):
100
+ image_files.append(os.path.join(root, file))
101
+ return image_files
102
+
103
+ demo = gr.Blocks()
104
+ with demo:
105
+ with gr.Tab("Image Geolocation Demo"):
106
+ with gr.Row():
107
+ with gr.Column():
108
+ image = gr.Image(label="Image", type="pil", sources=["upload","clipboard"])
109
+ predict_btn = gr.Button("Predict")
110
+ example_images = get_example_images("examples")
111
+ example_images.extend(get_example_images("versus_images"))
112
+ gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
113
+ with gr.Column():
114
+ continents_label = gr.Label(label="Continents")
115
+ country_label = gr.Label(num_top_classes=5, label="Top countries")
116
+ # continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
117
+ predict_btn.click(predict, inputs=image, outputs=[continents_label, country_label])
118
+
119
+ with gr.Tab("Versus Mode"):
120
+ with gr.Row():
121
+ with gr.Column():
122
+ versus_image = gr.Image(VERSUS_IMAGE, interactive=False)
123
+ continent_selection = gr.Radio(continents, label="Continents", info="Where was this image taken?")
124
+ country_selection = gr.Dropdown(countries, label="Countries", info="Can you guess the exact country?"
125
+ ),
126
+ with gr.Row():
127
+ next_img_btn = gr.Button("Try new image")
128
+ versus_btn = gr.Button("Submit guess")
129
+ with gr.Column():
130
+ versus_output = gr.Text()
131
+ continents_label = gr.Label(label="Continents")
132
+ country_label = gr.Label(num_top_classes=5, label="Top countries")
133
+ next_img_btn.click(next_versus_image, outputs=versus_image)
134
+ versus_btn.click(versus_mode_inputs, inputs=[versus_image, continent_selection, country_selection[0]], outputs=[versus_output, continents_label, country_label])
135
+
136
+
137
+ if __name__ == "__main__":
138
+ demo.launch()
dataset_examples/africa/3962011747224020_1024.jpg ADDED
dataset_examples/africa/3973126792769679_1024.jpg ADDED
dataset_examples/africa/4471109009586514_1024.jpg ADDED
dataset_examples/asia/106261888221766_1024.jpg ADDED
dataset_examples/asia/138321512570044_1024.jpg ADDED
dataset_examples/asia/147206360658971_1024.jpg ADDED
dataset_examples/europe/1423684677989158_1024.jpg ADDED
dataset_examples/europe/1483506425323136_1024.jpg ADDED
dataset_examples/europe/150428493699453_1024.jpg ADDED
dataset_examples/north america/1000276020376482_1024.jpg ADDED
dataset_examples/north america/757125001639938_1024.jpg ADDED
dataset_examples/north america/843371313196684_1024.jpg ADDED
dataset_examples/oceania/100141636075718_1024.jpg ADDED
dataset_examples/oceania/1899604010244250_1024.jpg ADDED
dataset_examples/oceania/821397982117703_1024.jpg ADDED
dataset_examples/south america/103386512798674_1024.jpg ADDED
dataset_examples/south america/1677652415776082_1024.jpg ADDED
dataset_examples/south america/327973242016483_1024.jpg ADDED
examples/1000276020376482_1024.jpg ADDED
examples/100141636075718_1024.jpg ADDED
examples/103386512798674_1024.jpg ADDED
examples/106261888221766_1024.jpg ADDED
examples/138321512570044_1024.jpg ADDED
examples/1423684677989158_1024.jpg ADDED
examples/147206360658971_1024.jpg ADDED
examples/1483506425323136_1024.jpg ADDED
examples/150428493699453_1024.jpg ADDED
examples/1677652415776082_1024.jpg ADDED
examples/1899604010244250_1024.jpg ADDED
examples/327973242016483_1024.jpg ADDED
examples/3962011747224020_1024.jpg ADDED
examples/3973126792769679_1024.jpg ADDED
examples/4471109009586514_1024.jpg ADDED
examples/757125001639938_1024.jpg ADDED
examples/821397982117703_1024.jpg ADDED
examples/843371313196684_1024.jpg ADDED
versus_images/[email protected] ADDED
versus_images/[email protected] ADDED
versus_images/[email protected] ADDED
versus_images/[email protected] ADDED
versus_images/[email protected] ADDED