Spaces:
Runtime error
Runtime error
import numpy as np | |
import gradio as gr | |
from transformers import CLIPProcessor, CLIPModel | |
import torch | |
import itertools | |
import os | |
import plotly.graph_objects as go | |
import hashlib | |
from PIL import Image | |
import json | |
# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so assigning it
# here does not make str hashing (and thus set iteration order) reproducible
# in this process; it only propagates to child processes.
os.environ["PYTHONHASHSEED"] = "42"

CUDA_AVAILABLE = torch.cuda.is_available()
print(f"CUDA={CUDA_AVAILABLE}")
device = "cuda" if CUDA_AVAILABLE else "cpu"
print(f"count={torch.cuda.device_count()}")
if CUDA_AVAILABLE:
    # current_device()/get_device_name() initialise CUDA and raise on hosts
    # without a usable GPU, so only query the name when CUDA is available.
    print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")

# Access token for the model repositories, read once from the environment.
HF_TOKEN = os.getenv("token")

# Two fine-tuned CLIP models: one ranks continents, the other ranks countries.
# Both share the processor stored with the continent model.
continent_model = CLIPModel.from_pretrained(
    "jrheiner/thesis-clip-geoloc-continent", token=HF_TOKEN
).to(device)
country_model = CLIPModel.from_pretrained(
    "jrheiner/thesis-clip-geoloc-country", token=HF_TOKEN
).to(device)
processor = CLIPProcessor.from_pretrained(
    "jrheiner/thesis-clip-geoloc-continent", token=HF_TOKEN
)
# Continent labels used as text prompts for the continent model.
continents = ["Africa", "Asia", "Europe",
              "North America", "Oceania", "South America"]

# Country labels per continent. The country model only ranks the countries of
# the continent predicted first. Russia, Turkey, Cyprus and Greenland are
# listed under two continents each.
countries_per_continent = {
    "Africa": [
        "Botswana", "Eswatini", "Ghana", "Kenya", "Lesotho", "Nigeria", "Senegal",
        "South Africa", "Rwanda", "Uganda", "Tanzania", "Madagascar", "Djibouti",
        "Mali", "Libya", "Morocco", "Somalia", "Tunisia", "Egypt", "Réunion",
    ],
    "Asia": [
        "Bangladesh", "Bhutan", "Cambodia", "China", "India", "Indonesia", "Israel",
        "Japan", "Jordan", "Kyrgyzstan", "Laos", "Malaysia", "Mongolia", "Nepal",
        "Palestine", "Philippines", "Singapore", "South Korea", "Sri Lanka",
        "Taiwan", "Thailand", "United Arab Emirates", "Vietnam", "Afghanistan",
        "Azerbaijan", "Cyprus", "Iran", "Syria", "Tajikistan", "Turkey", "Russia",
        "Pakistan", "Hong Kong",
    ],
    "Europe": [
        "Albania", "Andorra", "Austria", "Belgium", "Bulgaria", "Croatia", "Czechia",
        "Denmark", "Estonia", "Finland", "France", "Germany", "Greece", "Hungary",
        "Iceland", "Ireland", "Italy", "Latvia", "Lithuania", "Luxembourg",
        "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland",
        "Portugal", "Romania", "Russia", "Serbia", "Slovakia", "Slovenia", "Spain",
        "Sweden", "Switzerland", "Ukraine", "United Kingdom", "Bosnia and Herzegovina",
        "Cyprus", "Turkey", "Greenland", "Faroe Islands",
    ],
    "North America": [
        "Canada", "Dominican Republic", "Guatemala", "Mexico", "United States",
        "Bahamas", "Cuba", "Panama", "Puerto Rico", "Bermuda", "Greenland",
    ],
    "Oceania": [
        "Australia", "New Zealand", "Fiji", "Papua New Guinea", "Solomon Islands", "Vanuatu",
    ],
    "South America": [
        "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Paraguay",
        "Peru", "Uruguay",
    ],
}

# Unique country names across all continents. Sorted so that the order (e.g.
# in the versus-mode dropdown) is stable across runs; plain set iteration
# order of str varies per process with hash randomisation.
countries = sorted(set(itertools.chain.from_iterable(
    countries_per_continent.values())))

# (latitude, longitude) marker position for each country, used to plot guesses.
country_to_center_coords = {
    "Indonesia": (-2.4833826, 117.8902853),
    "Egypt": (26.2540493, 29.2675469),
    "Dominican Republic": (19.0974031, -70.3028026),
    "Russia": (64.6863136, 97.7453061),
    "Denmark": (55.670249, 10.3333283),
    "Latvia": (56.8406494, 24.7537645),
    "Hong Kong": (22.350627, 114.1849161),
    "Brazil": (-10.3333333, -53.2),
    "Turkey": (38.9597594, 34.9249653),
    "Paraguay": (-23.3165935, -58.1693445),
    "Nigeria": (9.6000359, 7.9999721),
    "United Kingdom": (54.7023545, -3.2765753),
    "Argentina": (-34.9964963, -64.9672817),
    "United Arab Emirates": (24.0002488, 53.9994829),
    "Estonia": (58.7523778, 25.3319078),
    "Greenland": (69.6354163, -42.1736914),
    "Canada": (61.0666922, -107.991707),
    "Andorra": (42.5407167, 1.5732033),
    "Czechia": (49.7439047, 15.3381061),
    "Australia": (-24.7761086, 134.755),
    "Azerbaijan": (40.3936294, 47.7872508),
    "Cambodia": (12.5433216, 104.8144914),
    "Peru": (-6.8699697, -75.0458515),
    "Slovakia": (48.7411522, 19.4528646),
    "Réunion": (-21.130737949999997, 55.536480112992315),
    "France": (46.603354, 1.8883335),
    "Israel": (30.8124247, 34.8594762),
    "China": (35.000074, 104.999927),
    "Ecuador": (-1.3397668, -79.3666965),
    "Poland": (52.215933, 19.134422),
    "Switzerland": (46.7985624, 8.2319736),
    "Singapore": (1.357107, 103.8194992),
    "Kenya": (1.4419683, 38.4313975),
    "Bhutan": (27.549511, 90.5119273),
    "Laos": (20.0171109, 103.378253),
    "Vietnam": (15.9266657, 107.9650855),
    "Puerto Rico": (18.2247706, -66.4858295),
    "Germany": (51.1638175, 10.4478313),
    "Tanzania": (-6.5247123, 35.7878438),
    "Colombia": (4.099917, -72.9088133),
    "Italy": (42.6384261, 12.674297),
    "Bahamas": (24.7736546, -78.0000547),
    "Panama": (8.559559, -81.1308434),
    "Bulgaria": (42.6073975, 25.4856617),
    "Solomon Islands": (-8.7053941, 159.1070693851845),
    "Afghanistan": (33.7680065, 66.2385139),
    "Tajikistan": (38.6281733, 70.8156541),
    "Portugal": (39.6621648, -8.1353519),
    "Tunisia": (36.8002068, 10.1857757),
    "Bolivia": (-17.0568696, -64.9912286),
    "Malaysia": (4.5693754, 102.2656823),
    "Lithuania": (55.3500003, 23.7499997),
    "Sweden": (59.6749712, 14.5208584),
    "Belgium": (50.6402809, 4.6667145),
    "Libya": (26.8234472, 18.1236723),
    "Guatemala": (15.5855545, -90.345759),
    "India": (22.3511148, 78.6677428),
    "Sri Lanka": (7.5554942, 80.7137847),
    "New Zealand": (-41.5000831, 172.8344077),
    "Iceland": (64.9841821, -18.1059013),
    "Somalia": (8.3676771, 49.083416),
    "Croatia": (45.3658443, 15.6575209),
    "Bosnia and Herzegovina": (44.3053476, 17.5961467),
    "Greece": (38.9953683, 21.9877132),
    "Rwanda": (-1.9646631, 30.0644358),
    "Hungary": (47.1817585, 19.5060937),
    "Eswatini": (-26.5624806, 31.3991317),
    "Kyrgyzstan": (41.5089324, 74.724091),
    "Bangladesh": (23.6943117, 90.344352),
    "Morocco": (28.3347722, -10.371337908392647),
    "Finland": (63.2467777, 25.9209164),
    "Luxembourg": (49.6112768, 6.129799),
    "North Macedonia": (41.6171214, 21.7168387),
    "Uruguay": (-32.8755548, -56.0201525),
    "Chile": (-31.7613365, -71.3187697),
    "Spain": (39.3260685, -4.8379791),
    "South Korea": (36.638392, 127.6961188),
    "Botswana": (-23.1681782, 24.5928742),
    "Uganda": (1.5333554, 32.2166578),
    "Papua New Guinea": (-5.6816069, 144.2489081),
    "Mali": (16.3700359, -2.2900239),
    "Philippines": (12.7503486, 122.7312101),
    "Norway": (64.5731537, 11.52803643954819),
    "Thailand": (14.8971921, 100.83273),
    "Mongolia": (46.8651082, 103.8347844),
    "Japan": (36.5748441, 139.2394179),
    "Montenegro": (42.7044223, 19.3957785),
    "Austria": (47.59397, 14.12456),
    "Taiwan": (23.6978, 120.9605),
    "Netherlands": (52.2434979, 5.6343227),
    "Ukraine": (49.4871968, 31.2718321),
    "Fiji": (-18.1239696, 179.0122737),
    "Ghana": (8.0300284, -1.0800271),
    "Cuba": (23.0131338, -80.8328748),
    "Nepal": (28.3780464, 83.9999901),
    "Faroe Islands": (62.0448724, -7.0322972),
    "Slovenia": (46.1199444, 14.8153333),
    "Cyprus": (34.9174159, 32.889902651331866),
    "Serbia": (44.024322850000004, 21.07657433209902),
    "Madagascar": (-18.9249604, 46.4416422),
    "Pakistan": (30.3308401, 71.247499),
    "Syria": (34.6401861, 39.0494106),
    "Iran": (32.6475314, 54.5643516),
    "Ireland": (52.865196, -7.9794599),
    "South Africa": (-28.8166236, 24.991639),
    "Albania": (41.1529058, 20.1605717),
    "Lesotho": (-29.6039267, 28.3350193),
    "Romania": (45.9852129, 24.6859225),
    "Palestine": (31.947351, 35.227163),
    "Vanuatu": (-16.5255069, 168.1069154),
    "Mexico": (19.4326296, -99.1331785),
    "Jordan": (31.279862, 37.1297454),
    "Djibouti": (11.8145966, 42.8453061),
    "Senegal": (14.4750607, -14.4529612),
    "Bermuda": (32.3040273, -64.7563086),
    "United States": (39.7837304, -100.445882),
}
# Image shown when the versus tab first loads.
INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"

# Initial game state. The ground truth is decoded from the image filename,
# whose first four "_"-separated fields are continent, country, latitude and
# longitude (kept as strings); any later fields are ignored.
INITAL_VERSUS_STATE = {
    "image": INTIAL_VERSUS_IMAGE,
    **dict(zip(
        ("continent", "country", "lat", "lon"),
        INTIAL_VERSUS_IMAGE.split("/")[-1].split("_"),
    )),
    # Running totals for the player and the model.
    "score": {"HUMAN": 0, "AI": 0},
    # Position in the versus image list.
    "idx": 0,
}
def _zero_shot_probs(model, labels, input_img):
    """Score one image against a "A photo from <label>." prompt per label.

    Args:
        model: A CLIP model (continent or country model) on ``device``.
        labels: Sequence of label strings to rank.
        input_img: The image passed through to ``processor``.

    Returns:
        ``(probs, best)`` where ``probs`` maps each label to its softmax
        probability over ``labels`` and ``best`` is the highest-scoring label.
    """
    prompts = [f"A photo from {label}." for label in labels]
    inputs = processor(text=prompts, images=input_img,
                       return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=-1)
    best = labels[probs.argmax().cpu().item()]
    # Row 0: a single image is scored per call.
    return dict(zip(labels, probs.tolist()[0])), best


def predict(input_img):
    """Predict the continent, then the country, in which a photo was taken.

    The continent model ranks all ``continents``. The country model then
    ranks only the countries of the predicted continent. If the image's pixel
    hash matches a known example in ``EXAMPLE_METADATA``, the ground truth
    and a map of the guess are returned as well.

    Args:
        input_img: Image from a gradio ``gr.Image`` component (PIL in the demo
            tab; presumably a numpy array from the versus-tab component,
            which uses gradio's default type — confirm).

    Returns:
        ``(continent_probs, country_probs, metadata_block, metadata_map)``:
        label-to-probability dicts for continents and for the candidate
        countries, a ``gr.Accordion`` update (visible only for known
        examples), and a plotly figure for known examples or ``None``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        raise gr.Error("Please provide an image first.")

    continent_probs, model_continent = _zero_shot_probs(
        continent_model, continents, input_img)
    candidate_countries = countries_per_continent[model_continent]
    country_probs, model_country = _zero_shot_probs(
        country_model, candidate_countries, input_img)

    metadata_block = gr.Accordion(visible=False)
    metadata_map = None
    image_digest = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
    truth = EXAMPLE_METADATA.get(image_digest)
    if truth is not None:
        continent_ok = model_continent == truth["continent"]
        country_ok = model_country == truth["country"]
        if continent_ok and country_ok:
            model_result = "The AI π€ correctly guessed continent and country β β ."
        elif continent_ok:
            model_result = "The AI π€ only guessed the correct continent β β ."
        elif country_ok:
            # Possible for countries listed under two continents.
            model_result = "The AI π€ only guessed the correct country β β."
        else:
            model_result = "The AI π€ failed to guess country and continent β β."
        metadata_block = gr.Accordion(
            visible=True,
            label=f"This photo was taken in {truth['country']}, {truth['continent']}.\n{model_result}",
        )
        metadata_map = make_versus_map(None, model_country, truth)
    return continent_probs, country_probs, metadata_block, metadata_map
def make_versus_map(human_country, model_country, versus_state):
    """Build a map showing the photo location and the human/AI guesses.

    Args:
        human_country: Country guessed by the player, or a falsy value when
            there is no human guess (the demo tab passes ``None``).
        model_country: Country predicted by the model; must be a key of
            ``country_to_center_coords``.
        versus_state: Mapping with the ground truth keys ``"lat"``, ``"lon"``
            (numbers or numeric strings, e.g. parsed from a filename),
            ``"country"`` and ``"continent"``.

    Returns:
        A ``plotly.graph_objects.Figure`` centred on the photo location.
        Guesses are plotted at the country coordinates from
        ``country_to_center_coords``. When both guesses match, a single
        combined marker is drawn.
    """
    photo_lat = float(versus_state["lat"])
    photo_lon = float(versus_state["lon"])
    fig = go.Figure()

    def add_marker(lat, lon, text, color, name):
        """Add a single hoverable, legend-visible marker trace."""
        fig.add_trace(go.Scattermapbox(
            lat=[lat],
            lon=[lon],
            text=[text],
            mode="markers",
            hoverinfo="text",
            marker=dict(size=14, color=color),
            showlegend=True,
            name=name,
        ))

    add_marker(photo_lat, photo_lon,
               f"π· Photo taken in {versus_state['country']}, {versus_state['continent']}",
               "#0C5DA5", "π· Photo Location")

    model_lat, model_lon = country_to_center_coords[model_country]
    if human_country and human_country == model_country:
        add_marker(model_lat, model_lon,
                   f"π§ π€ Human & AI guess {human_country}",
                   "#FF9500", "π§ π€ Human & AI Guess")
    else:
        if human_country:
            human_lat, human_lon = country_to_center_coords[human_country]
            add_marker(human_lat, human_lon,
                       f"π§ Human guesses {human_country}",
                       "#FF9500", "π§ Human Guess")
        add_marker(model_lat, model_lon,
                   f"π€ AI guesses {model_country}",
                   "#474747", "π€ AI Guess")

    fig.update_layout(
        mapbox=dict(
            style="carto-positron",
            center=dict(lat=photo_lat, lon=photo_lon),
            zoom=2,
        ),
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        legend=dict(
            yanchor="bottom",
            y=0.01,
            xanchor="left",
            x=0.01,
        ),
    )
    return fig
def _grade(guess, truth, points):
    """Return ``(mark, awarded_points)`` for one guess against the ground truth."""
    if guess == truth:
        return "β ", points
    return "β", 0


def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
    """Score a player's guess against the model on the current versus image.

    A correct country is worth 2 points and a correct continent 1 point, for
    both the player and the model. The model's guess is the top entry of each
    probability dict returned by ``predict``.

    Args:
        input_img: Image currently shown in the versus tab.
        human_continent: Continent picked by the player (may be ``None``).
        human_country: Country picked by the player (may be ``None``).
        versus_state: Game state dict with the ground truth (``"country"``,
            ``"continent"``, ``"lat"``, ``"lon"``) and ``"score"`` totals.
            It is updated in place.

    Returns:
        ``(markdown, continent_probs, country_probs, figure, versus_state)``.
    """
    truth_country = versus_state["country"]
    truth_continent = versus_state["continent"]

    # Run inference before touching the score so that a failure leaves the
    # state unchanged.
    continent_probs, country_probs, _, _ = predict(input_img)
    model_country = max(country_probs, key=country_probs.get)
    model_continent = max(continent_probs, key=continent_probs.get)

    country_result, human_country_pts = _grade(human_country, truth_country, 2)
    continent_result, human_continent_pts = _grade(human_continent, truth_continent, 1)
    human_points = human_country_pts + human_continent_pts
    human_result = (f"The photo is from **{truth_country}** {country_result} "
                    f"in **{truth_continent}** {continent_result}")
    human_score_update = f"+{human_points} points" if human_points > 0 else "0 Points..."

    model_country_result, model_country_pts = _grade(model_country, truth_country, 2)
    model_continent_result, model_continent_pts = _grade(model_continent, truth_continent, 1)
    model_points = model_country_pts + model_continent_pts
    model_score_update = (
        f"+{model_points} points" if model_points > 0
        else "0 Points... The model was completely wrong, it seems the world is not doomed yet."
    )

    versus_state["score"]["HUMAN"] += human_points
    versus_state["score"]["AI"] += model_points

    versus_map = make_versus_map(human_country, model_country, versus_state)
    summary = "\n".join([
        "",
        f"## {human_result}",
        f"### The AI π€ thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}",
        f"π§ {human_score_update}",
        f"π€ {model_score_update}",
        f"### Score π§ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} π€",
        "",
    ])
    return summary, continent_probs, country_probs, versus_map, versus_state
def get_example_images(dir):
    """Return the paths of all image files found recursively under ``dir``.

    Files are matched case-insensitively on the ``.jpg``/``.jpeg``/``.png``
    extensions. The result is sorted, so its order does not depend on the
    filesystem's directory-listing order.

    Args:
        dir: Root directory to search. A missing directory yields ``[]``
            (``os.walk`` ignores the error by default).

    Returns:
        Sorted list of paths, each built as ``os.path.join(root, file)``.
    """
    image_extensions = (".jpg", ".jpeg", ".png")
    return sorted(
        os.path.join(root, file)
        for root, _dirs, files in os.walk(dir)
        for file in files
        if file.lower().endswith(image_extensions)
    )
def next_versus_image(versus_state):
    """Advance the versus game to the next image in ``versus_images``.

    The index wraps around to the first image after the last one. The
    ground truth (continent, country, lat, lon) is decoded from the first
    four ``_``-separated fields of the image filename.

    Args:
        versus_state: Game state dict; ``"idx"``, ``"image"`` and the ground
            truth keys are updated in place.

    Returns:
        ``(image_path, versus_state, None, None)``. The two ``None`` values
        clear the continent and country selections in the UI. If no versus
        images are found, the current image is kept.
    """
    images = get_example_images("versus_images")
    if not images:
        return versus_state["image"], versus_state, None, None

    # Modulo wrap: the previous `idx > len(images)` check let idx reach
    # len(images) and raised IndexError at the end of the list.
    versus_state["idx"] = (versus_state["idx"] + 1) % len(images)
    versus_image = images[versus_state["idx"]]

    # basename() also handles OS-specific separators produced by os.path.join.
    continent, country, lat, lon = os.path.basename(versus_image).split("_")[:4]
    versus_state["continent"] = continent
    versus_state["country"] = country
    versus_state["lat"] = lat
    versus_state["lon"] = lon
    versus_state["image"] = versus_image
    return versus_image, versus_state, None, None
def _build_example_metadata(paths):
    """Map the SHA-1 of each image's pixel array to the ground truth in its name.

    Filenames are expected to start with ``<continent>_<country>_<lat>_<lon>``
    (later fields are ignored; lat/lon stay strings). The key hashes
    ``np.asarray(image).data`` just as ``predict`` does for its input, so
    example images selected in the UI can be recognised. This assumes the UI
    delivers the same pixel array as PIL's ``Image.open`` — TODO confirm for
    non-RGB files.

    Args:
        paths: Iterable of image file paths.

    Returns:
        Dict from hex digest to ``{"continent", "country", "lat", "lon"}``.
    """
    metadata = {}
    for img_path in paths:
        # Close each file after hashing instead of leaking the handle.
        with Image.open(img_path) as img:
            digest = hashlib.sha1(np.asarray(img).data.tobytes()).hexdigest()
        continent, country, lat, lon = os.path.basename(img_path).split("_")[:4]
        metadata[digest] = {
            "continent": continent,
            "country": country,
            "lat": lat,
            "lon": lon,
        }
    return metadata


example_images = get_example_images("kerger-test-images")
EXAMPLE_METADATA = _build_example_metadata(example_images)
# Gradio UI: an "Image Geolocation Demo" tab for free-form predictions and a
# "Versus Mode" tab where the player competes with the model.
demo = gr.Blocks(title="Thesis Demo")
with demo:
    gr.HTML("""
<h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
<h3> This demo showcases the developed models and allows interacting with the optimized prototype.</h3>
<p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example images the ground truth is available and will be displayed together with the model predictions.</p>
<p>In the <b>"Versus Mode"</b> tab you can play against the AI by guessing the country and continent where images were taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?</p>
""")
    # Counts are derived from the data so the label stays correct if lists change.
    with gr.Accordion(
        label=f"The demo currently encompasses {len(countries)} countries from {len(continents)} continents π",
        open=False,
    ):
        gr.Code(json.dumps(countries_per_continent, indent=2, ensure_ascii=False),
                label="countries_per_continent.json", language="json", interactive=False)

    with gr.Tab("Image Geolocation Demo"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Image", type="pil",
                                 sources=["upload", "clipboard"])
                predict_btn = gr.Button("Predict")
                # `example_images` was already collected at module level.
                gr.Examples(examples=example_images,
                            inputs=image, examples_per_page=24)
            with gr.Column():
                with gr.Accordion(visible=False) as metadata_block:
                    geo_map = gr.Plot(label="Locations")
                with gr.Group():
                    continents_label = gr.Label(label="Continents")
                    country_label = gr.Label(
                        num_top_classes=5, label="Top countries")
        predict_btn.click(predict, inputs=image, outputs=[
            continents_label, country_label, metadata_block, geo_map])

    with gr.Tab("Versus Mode"):
        versus_state = gr.State(value=INITAL_VERSUS_STATE)
        with gr.Row():
            with gr.Column():
                versus_image = gr.Image(
                    INITAL_VERSUS_STATE["image"], interactive=False)
                continent_selection = gr.Radio(
                    continents, label="Continents", info="Where was this image taken? (1 Point)")
                # No trailing comma: it would turn this into a 1-tuple.
                country_selection = gr.Dropdown(
                    countries, label="Countries", info="Can you guess the exact country? (2 Points)")
                with gr.Row():
                    next_img_btn = gr.Button("Try new image")
                    versus_btn = gr.Button("Submit guess")
            with gr.Column():
                versus_output = gr.Markdown()
                versus_map = gr.Plot(label="Locations")
                with gr.Accordion("Full Model Output", open=False):
                    with gr.Group():
                        versus_continents_label = gr.Label(label="Continents")
                        versus_country_label = gr.Label(
                            num_top_classes=5, label="Top countries")
        next_img_btn.click(next_versus_image, inputs=[versus_state], outputs=[
            versus_image, versus_state, continent_selection, country_selection])
        versus_btn.click(
            versus_mode_inputs,
            inputs=[versus_image, continent_selection, country_selection, versus_state],
            outputs=[versus_output, versus_continents_label, versus_country_label,
                     versus_map, versus_state],
        )

if __name__ == "__main__":
    demo.launch(show_api=False)