Spaces:
Runtime error
Runtime error
Jonas Rheiner
commited on
Commit
·
0a2704b
1
Parent(s):
69171b5
initial commit
Browse files- .gitignore +3 -0
- app.py +138 -0
- dataset_examples/africa/3962011747224020_1024.jpg +0 -0
- dataset_examples/africa/3973126792769679_1024.jpg +0 -0
- dataset_examples/africa/4471109009586514_1024.jpg +0 -0
- dataset_examples/asia/106261888221766_1024.jpg +0 -0
- dataset_examples/asia/138321512570044_1024.jpg +0 -0
- dataset_examples/asia/147206360658971_1024.jpg +0 -0
- dataset_examples/europe/1423684677989158_1024.jpg +0 -0
- dataset_examples/europe/1483506425323136_1024.jpg +0 -0
- dataset_examples/europe/150428493699453_1024.jpg +0 -0
- dataset_examples/north america/1000276020376482_1024.jpg +0 -0
- dataset_examples/north america/757125001639938_1024.jpg +0 -0
- dataset_examples/north america/843371313196684_1024.jpg +0 -0
- dataset_examples/oceania/100141636075718_1024.jpg +0 -0
- dataset_examples/oceania/1899604010244250_1024.jpg +0 -0
- dataset_examples/oceania/821397982117703_1024.jpg +0 -0
- dataset_examples/south america/103386512798674_1024.jpg +0 -0
- dataset_examples/south america/1677652415776082_1024.jpg +0 -0
- dataset_examples/south america/327973242016483_1024.jpg +0 -0
- examples/1000276020376482_1024.jpg +0 -0
- examples/100141636075718_1024.jpg +0 -0
- examples/103386512798674_1024.jpg +0 -0
- examples/106261888221766_1024.jpg +0 -0
- examples/138321512570044_1024.jpg +0 -0
- examples/1423684677989158_1024.jpg +0 -0
- examples/147206360658971_1024.jpg +0 -0
- examples/1483506425323136_1024.jpg +0 -0
- examples/150428493699453_1024.jpg +0 -0
- examples/1677652415776082_1024.jpg +0 -0
- examples/1899604010244250_1024.jpg +0 -0
- examples/327973242016483_1024.jpg +0 -0
- examples/3962011747224020_1024.jpg +0 -0
- examples/3973126792769679_1024.jpg +0 -0
- examples/4471109009586514_1024.jpg +0 -0
- examples/757125001639938_1024.jpg +0 -0
- examples/821397982117703_1024.jpg +0 -0
- examples/843371313196684_1024.jpg +0 -0
- versus_images/[email protected] +0 -0
- versus_images/[email protected] +0 -0
- versus_images/[email protected] +0 -0
- versus_images/[email protected] +0 -0
- versus_images/[email protected] +0 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.venv/
|
2 |
+
model-checkpoint/
|
3 |
+
__pycache__/
|
app.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
from transformers import CLIPProcessor, CLIPModel
|
4 |
+
import torch
|
5 |
+
import itertools
|
6 |
+
import os
|
7 |
+
|
8 |
+
model = CLIPModel.from_pretrained("model-checkpoint")
|
9 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
|
10 |
+
|
11 |
+
continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]
|
12 |
+
countries_per_continent = {
|
13 |
+
"Africa": [
|
14 |
+
"Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
|
15 |
+
"Central African Republic", "Chad", "Comoros", "Congo", "Democratic Republic of the Congo",
|
16 |
+
"Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
|
17 |
+
"Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
|
18 |
+
"Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
|
19 |
+
"Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
|
20 |
+
"Sierra Leone", "Somalia", "South Africa", "South Sudan", "Sudan", "Tanzania", "Togo",
|
21 |
+
"Tunisia", "Uganda", "Zambia", "Zimbabwe"
|
22 |
+
],
|
23 |
+
"Asia": [
|
24 |
+
"Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
|
25 |
+
"Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
|
26 |
+
"Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
|
27 |
+
"Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
|
28 |
+
"Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
|
29 |
+
"Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
|
30 |
+
"Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
|
31 |
+
],
|
32 |
+
"Europe": [
|
33 |
+
"Albania", "Andorra", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
|
34 |
+
"Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
|
35 |
+
"Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
|
36 |
+
"Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
|
37 |
+
"Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
|
38 |
+
"Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
|
39 |
+
"Turkey", "Ukraine", "United Kingdom", "Vatican City"
|
40 |
+
],
|
41 |
+
"North America": [
|
42 |
+
"Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
|
43 |
+
"Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
|
44 |
+
"Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
|
45 |
+
"Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
|
46 |
+
],
|
47 |
+
"Oceania": [
|
48 |
+
"Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
|
49 |
+
"Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
|
50 |
+
],
|
51 |
+
"South America": [
|
52 |
+
"Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
|
53 |
+
"Peru", "Suriname", "Uruguay", "Venezuela"
|
54 |
+
]
|
55 |
+
}
|
56 |
+
countries = list(set(itertools.chain.from_iterable(countries_per_continent.values())))
|
57 |
+
|
58 |
+
VERSUS_IMAGE = "versus_images/[email protected]"
|
59 |
+
VERSUS_GT = {
|
60 |
+
"continent": "test",
|
61 |
+
"country": "test"
|
62 |
+
}
|
63 |
+
|
64 |
+
def predict(input_img):
|
65 |
+
inputs = processor(text=[f"A photo from {geo}." for geo in continents], images=input_img, return_tensors="pt", padding=True)
|
66 |
+
with torch.no_grad():
|
67 |
+
outputs = model(**inputs)
|
68 |
+
logits_per_image = outputs.logits_per_image
|
69 |
+
probs = logits_per_image.softmax(dim=-1)
|
70 |
+
pred_id = probs.argmax().cpu().item()
|
71 |
+
continent_probs = {label: prob for label, prob in zip(continents, probs.tolist()[0])}
|
72 |
+
print(continent_probs)
|
73 |
+
|
74 |
+
predicted_continent_countries = countries_per_continent[continents[pred_id]]
|
75 |
+
inputs = processor(text=[f"A photo from {geo}." for geo in predicted_continent_countries], images=input_img, return_tensors="pt", padding=True)
|
76 |
+
with torch.no_grad():
|
77 |
+
outputs = model(**inputs)
|
78 |
+
logits_per_image = outputs.logits_per_image
|
79 |
+
probs = logits_per_image.softmax(dim=-1)
|
80 |
+
country_probs = {label: prob for label, prob in zip(predicted_continent_countries, probs.tolist()[0])}
|
81 |
+
print(country_probs)
|
82 |
+
return continent_probs, country_probs
|
83 |
+
|
84 |
+
def versus_mode_inputs(input_img, human_continent, human_country):
|
85 |
+
print(human_continent, human_country)
|
86 |
+
continent_probs, country_probs = predict(input_img)
|
87 |
+
return f"human guessed {human_continent} {human_country}\ngroung truth {VERSUS_GT}", continent_probs, country_probs
|
88 |
+
|
89 |
+
def next_versus_image():
|
90 |
+
# VERSUS_GT["continent"] = "test"
|
91 |
+
# VERSUS_GT["country"] = "test"
|
92 |
+
return "versus_images/[email protected]"
|
93 |
+
|
94 |
+
def get_example_images(dir):
|
95 |
+
image_extensions = (".jpg", ".jpeg", ".png")
|
96 |
+
image_files = []
|
97 |
+
for root, dirs, files in os.walk(dir):
|
98 |
+
for file in files:
|
99 |
+
if file.lower().endswith(image_extensions):
|
100 |
+
image_files.append(os.path.join(root, file))
|
101 |
+
return image_files
|
102 |
+
|
103 |
+
demo = gr.Blocks()
|
104 |
+
with demo:
|
105 |
+
with gr.Tab("Image Geolocation Demo"):
|
106 |
+
with gr.Row():
|
107 |
+
with gr.Column():
|
108 |
+
image = gr.Image(label="Image", type="pil", sources=["upload","clipboard"])
|
109 |
+
predict_btn = gr.Button("Predict")
|
110 |
+
example_images = get_example_images("examples")
|
111 |
+
example_images.extend(get_example_images("versus_images"))
|
112 |
+
gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
|
113 |
+
with gr.Column():
|
114 |
+
continents_label = gr.Label(label="Continents")
|
115 |
+
country_label = gr.Label(num_top_classes=5, label="Top countries")
|
116 |
+
# continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
|
117 |
+
predict_btn.click(predict, inputs=image, outputs=[continents_label, country_label])
|
118 |
+
|
119 |
+
with gr.Tab("Versus Mode"):
|
120 |
+
with gr.Row():
|
121 |
+
with gr.Column():
|
122 |
+
versus_image = gr.Image(VERSUS_IMAGE, interactive=False)
|
123 |
+
continent_selection = gr.Radio(continents, label="Continents", info="Where was this image taken?")
|
124 |
+
country_selection = gr.Dropdown(countries, label="Countries", info="Can you guess the exact country?"
|
125 |
+
),
|
126 |
+
with gr.Row():
|
127 |
+
next_img_btn = gr.Button("Try new image")
|
128 |
+
versus_btn = gr.Button("Submit guess")
|
129 |
+
with gr.Column():
|
130 |
+
versus_output = gr.Text()
|
131 |
+
continents_label = gr.Label(label="Continents")
|
132 |
+
country_label = gr.Label(num_top_classes=5, label="Top countries")
|
133 |
+
next_img_btn.click(next_versus_image, outputs=versus_image)
|
134 |
+
versus_btn.click(versus_mode_inputs, inputs=[versus_image, continent_selection, country_selection[0]], outputs=[versus_output, continents_label, country_label])
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
demo.launch()
|
dataset_examples/africa/3962011747224020_1024.jpg
ADDED
![]() |
dataset_examples/africa/3973126792769679_1024.jpg
ADDED
![]() |
dataset_examples/africa/4471109009586514_1024.jpg
ADDED
![]() |
dataset_examples/asia/106261888221766_1024.jpg
ADDED
![]() |
dataset_examples/asia/138321512570044_1024.jpg
ADDED
![]() |
dataset_examples/asia/147206360658971_1024.jpg
ADDED
![]() |
dataset_examples/europe/1423684677989158_1024.jpg
ADDED
![]() |
dataset_examples/europe/1483506425323136_1024.jpg
ADDED
![]() |
dataset_examples/europe/150428493699453_1024.jpg
ADDED
![]() |
dataset_examples/north america/1000276020376482_1024.jpg
ADDED
![]() |
dataset_examples/north america/757125001639938_1024.jpg
ADDED
![]() |
dataset_examples/north america/843371313196684_1024.jpg
ADDED
![]() |
dataset_examples/oceania/100141636075718_1024.jpg
ADDED
![]() |
dataset_examples/oceania/1899604010244250_1024.jpg
ADDED
![]() |
dataset_examples/oceania/821397982117703_1024.jpg
ADDED
![]() |
dataset_examples/south america/103386512798674_1024.jpg
ADDED
![]() |
dataset_examples/south america/1677652415776082_1024.jpg
ADDED
![]() |
dataset_examples/south america/327973242016483_1024.jpg
ADDED
![]() |
examples/1000276020376482_1024.jpg
ADDED
![]() |
examples/100141636075718_1024.jpg
ADDED
![]() |
examples/103386512798674_1024.jpg
ADDED
![]() |
examples/106261888221766_1024.jpg
ADDED
![]() |
examples/138321512570044_1024.jpg
ADDED
![]() |
examples/1423684677989158_1024.jpg
ADDED
![]() |
examples/147206360658971_1024.jpg
ADDED
![]() |
examples/1483506425323136_1024.jpg
ADDED
![]() |
examples/150428493699453_1024.jpg
ADDED
![]() |
examples/1677652415776082_1024.jpg
ADDED
![]() |
examples/1899604010244250_1024.jpg
ADDED
![]() |
examples/327973242016483_1024.jpg
ADDED
![]() |
examples/3962011747224020_1024.jpg
ADDED
![]() |
examples/3973126792769679_1024.jpg
ADDED
![]() |
examples/4471109009586514_1024.jpg
ADDED
![]() |
examples/757125001639938_1024.jpg
ADDED
![]() |
examples/821397982117703_1024.jpg
ADDED
![]() |
examples/843371313196684_1024.jpg
ADDED
![]() |
versus_images/[email protected]
ADDED
![]() |
versus_images/[email protected]
ADDED
![]() |
versus_images/[email protected]
ADDED
![]() |
versus_images/[email protected]
ADDED
![]() |
versus_images/[email protected]
ADDED
![]() |