Spaces:
Runtime error
Runtime error
import numpy as np | |
import gradio as gr | |
from transformers import CLIPProcessor, CLIPModel | |
import torch | |
import itertools | |
import os | |
import plotly.graph_objects as go | |
import hashlib | |
from PIL import Image | |
import json | |
import random | |
# Runtime setup: choose the compute device and load the fine-tuned CLIP
# checkpoints that the prediction functions rely on.

# NOTE(review): assigning PYTHONHASHSEED at runtime is only inherited by child
# processes; it does not re-seed hashing in this already-running interpreter.
os.environ["PYTHONHASHSEED"] = "42"

CUDA_AVAILABLE = torch.cuda.is_available()
print(f"CUDA={CUDA_AVAILABLE}")
if CUDA_AVAILABLE:
    device = "cuda"
    print(f"count={torch.cuda.device_count()}")
    print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    device = "cpu"

# The Hub access token is read from the "token" environment variable.
# Each model is moved to the selected device right after loading.
continent_model = CLIPModel.from_pretrained(
    "jrheiner/thesis-clip-geoloc-continent",
    token=os.getenv("token"),
).to(device)
country_model = CLIPModel.from_pretrained(
    "jrheiner/thesis-clip-geoloc-country",
    token=os.getenv("token"),
).to(device)
# A single processor (taken from the continent checkpoint) prepares the
# inputs for both models.
processor = CLIPProcessor.from_pretrained(
    "jrheiner/thesis-clip-geoloc-continent",
    token=os.getenv("token"),
)
# Continent labels offered to the continent model, in prompt order.
continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]

# Candidate country labels per continent. The order inside each list is the
# order in which prompts are built for the country model. Some countries are
# deliberately listed under two continents (Russia, Turkey, Cyprus, Greenland).
countries_per_continent = {
    "Africa": [
        "Botswana", "Eswatini", "Ghana", "Kenya", "Lesotho",
        "Nigeria", "Senegal", "South Africa", "Rwanda", "Uganda",
        "Tanzania", "Madagascar", "Djibouti", "Mali", "Libya",
        "Morocco", "Somalia", "Tunisia", "Egypt", "Réunion",
    ],
    "Asia": [
        "Bangladesh", "Bhutan", "Cambodia", "China", "India",
        "Indonesia", "Israel", "Japan", "Jordan", "Kyrgyzstan",
        "Laos", "Malaysia", "Mongolia", "Nepal", "Palestine",
        "Philippines", "Singapore", "South Korea", "Sri Lanka", "Taiwan",
        "Thailand", "United Arab Emirates", "Vietnam", "Afghanistan", "Azerbaijan",
        "Cyprus", "Iran", "Syria", "Tajikistan", "Turkey",
        "Russia", "Pakistan", "Hong Kong",
    ],
    "Europe": [
        "Albania", "Andorra", "Austria", "Belgium", "Bulgaria",
        "Croatia", "Czechia", "Denmark", "Estonia", "Finland",
        "France", "Germany", "Greece", "Hungary", "Iceland",
        "Ireland", "Italy", "Latvia", "Lithuania", "Luxembourg",
        "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland",
        "Portugal", "Romania", "Russia", "Serbia", "Slovakia",
        "Slovenia", "Spain", "Sweden", "Switzerland", "Ukraine",
        "United Kingdom", "Bosnia and Herzegovina", "Cyprus", "Turkey", "Greenland",
        "Faroe Islands",
    ],
    "North America": [
        "Canada", "Dominican Republic", "Guatemala", "Mexico", "United States",
        "Bahamas", "Cuba", "Panama", "Puerto Rico", "Bermuda",
        "Greenland",
    ],
    "Oceania": [
        "Australia", "New Zealand", "Fiji", "Papua New Guinea", "Solomon Islands",
        "Vanuatu",
    ],
    "South America": [
        "Argentina", "Bolivia", "Brazil", "Chile", "Colombia",
        "Ecuador", "Paraguay", "Peru", "Uruguay",
    ],
}

# Unique country names across all continents. Sorted so the list (and the
# dropdown built from it) has the same alphabetical order on every launch;
# iterating a bare set of strings yields an order that can vary between runs.
countries = sorted(set(itertools.chain.from_iterable(countries_per_continent.values())))
# One reference point per country, stored as a (latitude, longitude) pair in
# decimal degrees. make_versus_map looks a guessed country up here and uses
# index 0 as the latitude and index 1 as the longitude of the guess marker.
# NOTE(review): the name suggests each point is the country's geographic
# centre; the source of the values is not visible here - confirm.
country_to_center_coords = {
    "Indonesia": (-2.4833826, 117.8902853),
    "Egypt": (26.2540493, 29.2675469),
    "Dominican Republic": (19.0974031, -70.3028026),
    "Russia": (64.6863136, 97.7453061),
    "Denmark": (55.670249, 10.3333283),
    "Latvia": (56.8406494, 24.7537645),
    "Hong Kong": (22.350627, 114.1849161),
    "Brazil": (-10.3333333, -53.2),
    "Turkey": (38.9597594, 34.9249653),
    "Paraguay": (-23.3165935, -58.1693445),
    "Nigeria": (9.6000359, 7.9999721),
    "United Kingdom": (54.7023545, -3.2765753),
    "Argentina": (-34.9964963, -64.9672817),
    "United Arab Emirates": (24.0002488, 53.9994829),
    "Estonia": (58.7523778, 25.3319078),
    "Greenland": (69.6354163, -42.1736914),
    "Canada": (61.0666922, -107.991707),
    "Andorra": (42.5407167, 1.5732033),
    "Czechia": (49.7439047, 15.3381061),
    "Australia": (-24.7761086, 134.755),
    "Azerbaijan": (40.3936294, 47.7872508),
    "Cambodia": (12.5433216, 104.8144914),
    "Peru": (-6.8699697, -75.0458515),
    "Slovakia": (48.7411522, 19.4528646),
    "Réunion": (-21.130737949999997, 55.536480112992315),
    "France": (46.603354, 1.8883335),
    "Israel": (30.8124247, 34.8594762),
    "China": (35.000074, 104.999927),
    "Ecuador": (-1.3397668, -79.3666965),
    "Poland": (52.215933, 19.134422),
    "Switzerland": (46.7985624, 8.2319736),
    "Singapore": (1.357107, 103.8194992),
    "Kenya": (1.4419683, 38.4313975),
    "Bhutan": (27.549511, 90.5119273),
    "Laos": (20.0171109, 103.378253),
    "Vietnam": (15.9266657, 107.9650855),
    "Puerto Rico": (18.2247706, -66.4858295),
    "Germany": (51.1638175, 10.4478313),
    "Tanzania": (-6.5247123, 35.7878438),
    "Colombia": (4.099917, -72.9088133),
    "Italy": (42.6384261, 12.674297),
    "Bahamas": (24.7736546, -78.0000547),
    "Panama": (8.559559, -81.1308434),
    "Bulgaria": (42.6073975, 25.4856617),
    "Solomon Islands": (-8.7053941, 159.1070693851845),
    "Afghanistan": (33.7680065, 66.2385139),
    "Tajikistan": (38.6281733, 70.8156541),
    "Portugal": (39.6621648, -8.1353519),
    "Tunisia": (36.8002068, 10.1857757),
    "Bolivia": (-17.0568696, -64.9912286),
    "Malaysia": (4.5693754, 102.2656823),
    "Lithuania": (55.3500003, 23.7499997),
    "Sweden": (59.6749712, 14.5208584),
    "Belgium": (50.6402809, 4.6667145),
    "Libya": (26.8234472, 18.1236723),
    "Guatemala": (15.5855545, -90.345759),
    "India": (22.3511148, 78.6677428),
    "Sri Lanka": (7.5554942, 80.7137847),
    "New Zealand": (-41.5000831, 172.8344077),
    "Iceland": (64.9841821, -18.1059013),
    "Somalia": (8.3676771, 49.083416),
    "Croatia": (45.3658443, 15.6575209),
    "Bosnia and Herzegovina": (44.3053476, 17.5961467),
    "Greece": (38.9953683, 21.9877132),
    "Rwanda": (-1.9646631, 30.0644358),
    "Hungary": (47.1817585, 19.5060937),
    "Eswatini": (-26.5624806, 31.3991317),
    "Kyrgyzstan": (41.5089324, 74.724091),
    "Bangladesh": (23.6943117, 90.344352),
    "Morocco": (28.3347722, -10.371337908392647),
    "Finland": (63.2467777, 25.9209164),
    "Luxembourg": (49.6112768, 6.129799),
    "North Macedonia": (41.6171214, 21.7168387),
    "Uruguay": (-32.8755548, -56.0201525),
    "Chile": (-31.7613365, -71.3187697),
    "Spain": (39.3260685, -4.8379791),
    "South Korea": (36.638392, 127.6961188),
    "Botswana": (-23.1681782, 24.5928742),
    "Uganda": (1.5333554, 32.2166578),
    "Papua New Guinea": (-5.6816069, 144.2489081),
    "Mali": (16.3700359, -2.2900239),
    "Philippines": (12.7503486, 122.7312101),
    "Norway": (64.5731537, 11.52803643954819),
    "Thailand": (14.8971921, 100.83273),
    "Mongolia": (46.8651082, 103.8347844),
    "Japan": (36.5748441, 139.2394179),
    "Montenegro": (42.7044223, 19.3957785),
    "Austria": (47.59397, 14.12456),
    "Taiwan": (23.6978, 120.9605),
    "Netherlands": (52.2434979, 5.6343227),
    "Ukraine": (49.4871968, 31.2718321),
    "Fiji": (-18.1239696, 179.0122737),
    "Ghana": (8.0300284, -1.0800271),
    "Cuba": (23.0131338, -80.8328748),
    "Nepal": (28.3780464, 83.9999901),
    "Faroe Islands": (62.0448724, -7.0322972),
    "Slovenia": (46.1199444, 14.8153333),
    "Cyprus": (34.9174159, 32.889902651331866),
    "Serbia": (44.024322850000004, 21.07657433209902),
    "Madagascar": (-18.9249604, 46.4416422),
    "Pakistan": (30.3308401, 71.247499),
    "Syria": (34.6401861, 39.0494106),
    "Iran": (32.6475314, 54.5643516),
    "Ireland": (52.865196, -7.9794599),
    "South Africa": (-28.8166236, 24.991639),
    "Albania": (41.1529058, 20.1605717),
    "Lesotho": (-29.6039267, 28.3350193),
    "Romania": (45.9852129, 24.6859225),
    "Palestine": (31.947351, 35.227163),
    "Vanuatu": (-16.5255069, 168.1069154),
    "Mexico": (19.4326296, -99.1331785),
    "Jordan": (31.279862, 37.1297454),
    "Djibouti": (11.8145966, 42.8453061),
    "Senegal": (14.4750607, -14.4529612),
    "Bermuda": (32.3040273, -64.7563086),
    "United States": (39.7837304, -100.445882),
}
def _rank_locations(model, labels, image):
    """Score *image* against one "A photo from <label>." prompt per label.

    Args:
        model: CLIP model to evaluate (already placed on ``device``).
        labels: Location names to build the text prompts from.
        image: Image accepted by the shared ``processor``.

    Returns:
        A ``(best_label, probabilities)`` pair, where ``probabilities`` maps
        every label to its softmax probability over the image-to-text logits.
    """
    inputs = processor(
        text=[f"A photo from {geo}." for geo in labels],
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        logits_per_image = model(**inputs).logits_per_image
    probs = logits_per_image.softmax(dim=-1)
    best_label = labels[probs.argmax().cpu().item()]
    return best_label, dict(zip(labels, probs.tolist()[0]))


def predict(input_img):
    """Predict the continent and then the country where a photo was taken.

    The continent model picks a continent first; the country model then only
    chooses among the countries listed for that continent. If the image's
    pixel hash matches one of the bundled examples, the known ground truth is
    reported alongside the prediction.

    Args:
        input_img: The image to geolocate.

    Returns:
        A tuple ``(continent_probs, country_probs, metadata_block,
        metadata_map)``: two label-to-probability dicts, an accordion update
        (visible only for known example images) and a map figure or ``None``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Surface a readable message instead of a failure deep inside the
        # processor when "Predict" is clicked without an image.
        raise gr.Error("Please provide an image first.")

    model_continent, continent_probs = _rank_locations(
        continent_model, continents, input_img
    )
    candidate_countries = countries_per_continent[model_continent]
    model_country, country_probs = _rank_locations(
        country_model, candidate_countries, input_img
    )

    # Example images are recognised by the SHA-1 of their raw pixel data.
    image_hash = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
    truth = EXAMPLE_METADATA.get(image_hash)
    if truth is None:
        return continent_probs, country_probs, gr.Accordion(visible=False), None

    continent_ok = model_continent == truth["continent"]
    country_ok = model_country == truth["country"]
    if continent_ok and country_ok:
        model_result = "The AI 🤖 correctly guessed continent and country ✅ ✅."
    elif continent_ok:
        model_result = "The AI 🤖 only guessed the correct continent ❌ ✅."
    elif country_ok:
        # Possible for countries that are listed under two continents.
        model_result = "The AI 🤖 only guessed the correct country ✅ ❌."
    else:
        model_result = "The AI 🤖 failed to guess country and continent ❌ ❌."

    metadata_block = gr.Accordion(
        visible=True,
        label=f"This photo was taken in {truth['country']}, {truth['continent']}.\n{model_result}",
    )
    metadata_map = make_versus_map(None, model_country, truth)
    return continent_probs, country_probs, metadata_block, metadata_map
def _guess_marker(lat, lon, text, color, name):
    """Return a single-point map trace with the styling shared by all markers."""
    return go.Scattermapbox(
        lat=[lat],
        lon=[lon],
        text=[text],
        mode="markers",
        hoverinfo="text",
        marker=dict(size=14, color=color),
        showlegend=True,
        name=name,
    )


def make_versus_map(human_country, model_country, versus_state):
    """Build a map comparing the true photo location with the guesses.

    Args:
        human_country: Country guessed by the player, or a falsy value
            (e.g. ``None``) when there is no human guess to show.
        model_country: Country guessed by the model; must be a key of
            ``country_to_center_coords``.
        versus_state: Mapping with the photo's ``"lat"``, ``"lon"``,
            ``"country"`` and ``"continent"``. ``lat``/``lon`` may be numeric
            strings (they are parsed from file names elsewhere in this file).

    Returns:
        A ``plotly.graph_objects.Figure`` centred on the photo location.

    Raises:
        KeyError: If a guessed country has no entry in
            ``country_to_center_coords``.
    """
    # Parse the photo position once so the marker and the map centre use the
    # same numeric values.
    photo_lat = float(versus_state["lat"])
    photo_lon = float(versus_state["lon"])
    model_lat, model_lon = country_to_center_coords[model_country]

    fig = go.Figure()
    fig.add_trace(
        _guess_marker(
            photo_lat,
            photo_lon,
            f"📷 Photo taken in {versus_state['country']}, {versus_state['continent']}",
            "#0C5DA5",
            "📷 Photo Location",
        )
    )
    if human_country == model_country:
        # Both guesses share one location, so a single marker is enough.
        fig.add_trace(
            _guess_marker(
                model_lat,
                model_lon,
                f"🧑 🤖 Human & AI guess {human_country}",
                "#FF9500",
                "🧑 🤖 Human & AI Guess",
            )
        )
    else:
        if human_country:
            human_lat, human_lon = country_to_center_coords[human_country]
            fig.add_trace(
                _guess_marker(
                    human_lat,
                    human_lon,
                    f"🧑 Human guesses {human_country}",
                    "#FF9500",
                    "🧑 Human Guess",
                )
            )
        fig.add_trace(
            _guess_marker(
                model_lat,
                model_lon,
                f"🤖 AI guesses {model_country}",
                "#474747",
                "🤖 AI Guess",
            )
        )
    fig.update_layout(
        mapbox=dict(
            style="carto-positron",
            center=dict(lat=photo_lat, lon=photo_lon),
            zoom=2,
        ),
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
    )
    return fig
def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
    """Score one Versus Mode round for the player and the model.

    A correct country is worth 2 points and a correct continent 1 point, for
    both sides. The running totals in ``versus_state["score"]`` are updated
    in place.

    Args:
        input_img: The image of the current round, passed on to ``predict``.
        human_continent: Continent chosen by the player.
        human_country: Country chosen by the player.
        versus_state: Round state holding the true ``"country"`` and
            ``"continent"`` and the ``"score"`` totals.

    Returns:
        A tuple of the markdown summary, the continent probabilities, the
        country probabilities, the comparison map and the updated state.
    """
    true_country = versus_state["country"]
    true_continent = versus_state["continent"]

    # Player's guess.
    human_country_hit = human_country == true_country
    human_continent_hit = human_continent == true_continent
    human_points = (2 if human_country_hit else 0) + (1 if human_continent_hit else 0)
    country_result = "✅" if human_country_hit else "❌"
    continent_result = "✅" if human_continent_hit else "❌"
    human_result = f"The photo is from **{true_country}** {country_result} in **{true_continent}** {continent_result}"
    if human_points > 0:
        human_score_update = f"+{human_points} points"
    else:
        human_score_update = "0 Points..."
    versus_state["score"]["HUMAN"] += human_points

    # Model's guess: the most probable label of each stage.
    continent_probs, country_probs, _, _ = predict(input_img)
    model_country = max(country_probs, key=country_probs.get)
    model_continent = max(continent_probs, key=continent_probs.get)
    model_country_hit = model_country == true_country
    model_continent_hit = model_continent == true_continent
    model_points = (2 if model_country_hit else 0) + (1 if model_continent_hit else 0)
    model_country_result = "✅" if model_country_hit else "❌"
    model_continent_result = "✅" if model_continent_hit else "❌"
    if model_points > 0:
        model_score_update = f"+{model_points} points"
    else:
        model_score_update = "0 Points... The model was completely wrong, it seems the world is not doomed yet."
    versus_state["score"]["AI"] += model_points

    versus_map = make_versus_map(human_country, model_country, versus_state)

    # The empty first and last items give the leading and trailing newline.
    summary = "\n".join(
        [
            "",
            f"## {human_result}",
            f"### The AI 🤖 thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}",
            f"🧑 {human_score_update}",
            f"🤖 {model_score_update}",
            f"### Score 🧑 {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} 🤖",
            "",
        ]
    )
    return summary, continent_probs, country_probs, versus_map, versus_state
def get_example_images(dir):
    """Recursively collect the image files below a directory.

    Args:
        dir: Directory to search. A missing directory yields an empty list.

    Returns:
        A sorted list of paths (``dir`` joined with the relative location) of
        all files whose name ends in ``.jpg``, ``.jpeg`` or ``.png``, compared
        case-insensitively. Sorting keeps the result independent of the order
        in which the file system happens to list entries.
    """
    image_extensions = (".jpg", ".jpeg", ".png")
    return sorted(
        os.path.join(root, file)
        for root, _dirs, files in os.walk(dir)
        for file in files
        if file.lower().endswith(image_extensions)
    )
def next_versus_image(versus_state):
    """Pick a random image for the next Versus Mode round.

    The ground truth is encoded in the file name as
    ``<continent>_<country>_<lat>_<lon>_...``; the four values are stored in
    ``versus_state`` as strings, together with the chosen path.

    Args:
        versus_state: Round state; ``versus_state["images"]`` holds the
            candidate image paths.

    Returns:
        ``(image_path, versus_state, None, None)``. The two ``None`` values
        clear the continent and country inputs they are wired to.

    Raises:
        IndexError: If there are no candidate images or the file name has
            fewer than four ``_``-separated parts.
    """
    versus_image = random.choice(versus_state["images"])
    # basename() isolates the file name whatever separator os.path.join
    # produced, so underscores in directory names cannot shift the fields.
    name_parts = os.path.basename(versus_image).split("_")
    # Update the state only once every field has been parsed successfully.
    versus_state.update(
        continent=name_parts[0],
        country=name_parts[1],
        lat=name_parts[2],
        lon=name_parts[3],
        image=versus_image,
    )
    return versus_image, versus_state, None, None
# Ground truth for the bundled example images, keyed by the SHA-1 of each
# image's raw pixel data. The values are parsed from the file name
# "<continent>_<country>_<lat>_<lon>_..." and kept as strings.
example_images = get_example_images("kerger-test-images")
EXAMPLE_METADATA = {}
for img_path in example_images:
    # Hash inside the context manager so the file handle is closed as soon as
    # the pixels have been read.
    with Image.open(img_path) as example_img:
        image_hash = hashlib.sha1(np.asarray(example_img).data.tobytes()).hexdigest()
    # basename() isolates the file name regardless of the path separator.
    name_parts = os.path.basename(img_path).split("_")
    EXAMPLE_METADATA[image_hash] = {
        "continent": name_parts[0],
        "country": name_parts[1],
        "lat": name_parts[2],
        "lon": name_parts[3],
    }
def set_up_intial_state():
    """Create the starting state for Versus Mode.

    The first round always uses a fixed image whose ground truth is parsed
    from its file name (``<continent>_<country>_<lat>_<lon>_...``).

    Returns:
        A dict with the keys ``"image"``, ``"continent"``, ``"country"``,
        ``"lat"``, ``"lon"`` (all strings), ``"score"`` (both totals at 0)
        and ``"images"`` (the image paths found under ``versus_images``).
    """
    first_image = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
    continent, country, lat, lon = first_image.split("/")[-1].split("_")[:4]
    return {
        "image": first_image,
        "continent": continent,
        "country": country,
        "lat": lat,
        "lon": lon,
        "score": {"HUMAN": 0, "AI": 0},
        "images": get_example_images("versus_images"),
    }
# ---------------------------------------------------------------------------
# Gradio UI: an intro section, a tab for free-form prediction and a tab for
# playing against the model.
# ---------------------------------------------------------------------------
demo = gr.Blocks(title="Thesis Demo")
with demo:
    gr.HTML(
        """
<h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
<h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
<p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example image the ground truth is available and will be displayed together with the model predictions.</p>
<p>In the <b>"Versus Mode"</b> tab you can play against the AI, guessing the country and continent where images where taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/" target="_blank" rel="noopener noreferrer"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838" target="_blank" rel="noopener noreferrer"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?
<div style="font-style: italic; font-size: smaller;">Note that inference in this publicly hosted version is very slow due to the limited and shared hardware. This demo runs on the <a href="https://huggingface.co/pricing#spaces" style="color: inherit;" target="_blank" rel="noopener noreferrer">Hugging Face free tier</a> without GPU acceleration. Running the demo with a GPU allows for inference times between 0.5-2 seconds per image.</div>
"""
    )
    # The counts are derived from the data so the label cannot drift out of
    # sync when countries are added or removed.
    with gr.Accordion(
        label=f"The demo currently encompasses {len(countries)} countries from {len(continents)} continents 🌍",
        open=False,
    ):
        gr.Code(
            json.dumps(countries_per_continent, indent=2, ensure_ascii=False),
            label="countries_per_continent.json",
            language="json",
            interactive=False,
        )

    with gr.Tab("Image Geolocation Demo"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(
                    label="Image",
                    type="pil",
                    sources=["upload", "clipboard"],
                    show_fullscreen_button=True,
                )
                predict_btn = gr.Button("Predict")
                example_images = get_example_images("kerger-test-images")
                gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
            with gr.Column():
                # Hidden until predict() recognises a bundled example image.
                with gr.Accordion(visible=False) as metadata_block:
                    demo_map = gr.Plot(label="Locations")
                with gr.Group():
                    demo_continents_label = gr.Label(label="Continents")
                    demo_country_label = gr.Label(num_top_classes=5, label="Top countries")
        predict_btn.click(
            predict,
            inputs=image,
            outputs=[demo_continents_label, demo_country_label, metadata_block, demo_map],
        )

    with gr.Tab("Versus Mode"):
        versus_state = gr.State(value=set_up_intial_state())
        with gr.Row():
            with gr.Column():
                versus_image = gr.Image(
                    versus_state.value["image"],
                    interactive=False,
                    show_download_button=False,
                    show_share_button=False,
                    show_fullscreen_button=True,
                )
                continent_selection = gr.Radio(
                    continents,
                    label="Continents",
                    info="Where was this image taken? (1 Point)",
                )
                country_selection = gr.Dropdown(
                    countries,
                    label="Countries",
                    info="Can you guess the exact country? (2 Points)",
                )
                with gr.Row():
                    next_img_btn = gr.Button("Try new image")
                    versus_btn = gr.Button("Submit guess")
            with gr.Column():
                versus_output = gr.Markdown()
                versus_map = gr.Plot(label="Locations")
                with gr.Accordion("Full Model Output", open=False):
                    with gr.Group():
                        versus_continents_label = gr.Label(label="Continents")
                        versus_country_label = gr.Label(
                            num_top_classes=5, label="Top countries"
                        )
        # Loading a new image also clears both guess inputs.
        next_img_btn.click(
            next_versus_image,
            inputs=[versus_state],
            outputs=[
                versus_image,
                versus_state,
                continent_selection,
                country_selection,
            ],
        )
        versus_btn.click(
            versus_mode_inputs,
            inputs=[
                versus_image,
                continent_selection,
                country_selection,
                versus_state,
            ],
            outputs=[
                versus_output,
                versus_continents_label,
                versus_country_label,
                versus_map,
                versus_state,
            ],
        )

if __name__ == "__main__":
    demo.launch(show_api=False)