# thesis-demo / app.py
# jrheiner's picture
# Fix typo
# ae79cd2 verified
import numpy as np
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
import torch
import itertools
import os
import plotly.graph_objects as go
import hashlib
from PIL import Image
import json
import random
# Module start-up: pick the inference device and load both CLIP models plus
# their shared processor from the Hugging Face Hub.

# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so assigning
# it here presumably only reaches child processes, not this one - confirm intent.
os.environ["PYTHONHASHSEED"] = "42"

CUDA_AVAILABLE = torch.cuda.is_available()
device = "cuda" if CUDA_AVAILABLE else "cpu"
print(f"CUDA={CUDA_AVAILABLE}")
if CUDA_AVAILABLE:
    print(f"count={torch.cuda.device_count()}")
    print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")

# The Hub access token is read from the "token" environment variable and
# reused for every download below.
_hub_token = os.getenv("token")
_continent_repo = "jrheiner/thesis-clip-geoloc-continent"
_country_repo = "jrheiner/thesis-clip-geoloc-country"

continent_model = CLIPModel.from_pretrained(_continent_repo, token=_hub_token)
country_model = CLIPModel.from_pretrained(_country_repo, token=_hub_token)
# The processor is loaded from the continent repository and used for both models.
processor = CLIPProcessor.from_pretrained(_continent_repo, token=_hub_token)

continent_model = continent_model.to(device)
country_model = country_model.to(device)
# Label vocabulary for the two-stage classifier: the continent model picks one
# of `continents`, then the country model chooses among that continent's
# entries in `countries_per_continent`.
continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]

# Some countries are deliberately listed under two continents
# (Russia, Cyprus, Turkey, Greenland).
countries_per_continent = {
    "Africa": [
        "Botswana",
        "Eswatini",
        "Ghana",
        "Kenya",
        "Lesotho",
        "Nigeria",
        "Senegal",
        "South Africa",
        "Rwanda",
        "Uganda",
        "Tanzania",
        "Madagascar",
        "Djibouti",
        "Mali",
        "Libya",
        "Morocco",
        "Somalia",
        "Tunisia",
        "Egypt",
        "Réunion",
    ],
    "Asia": [
        "Bangladesh",
        "Bhutan",
        "Cambodia",
        "China",
        "India",
        "Indonesia",
        "Israel",
        "Japan",
        "Jordan",
        "Kyrgyzstan",
        "Laos",
        "Malaysia",
        "Mongolia",
        "Nepal",
        "Palestine",
        "Philippines",
        "Singapore",
        "South Korea",
        "Sri Lanka",
        "Taiwan",
        "Thailand",
        "United Arab Emirates",
        "Vietnam",
        "Afghanistan",
        "Azerbaijan",
        "Cyprus",
        "Iran",
        "Syria",
        "Tajikistan",
        "Turkey",
        "Russia",
        "Pakistan",
        "Hong Kong",
    ],
    "Europe": [
        "Albania",
        "Andorra",
        "Austria",
        "Belgium",
        "Bulgaria",
        "Croatia",
        "Czechia",
        "Denmark",
        "Estonia",
        "Finland",
        "France",
        "Germany",
        "Greece",
        "Hungary",
        "Iceland",
        "Ireland",
        "Italy",
        "Latvia",
        "Lithuania",
        "Luxembourg",
        "Montenegro",
        "Netherlands",
        "North Macedonia",
        "Norway",
        "Poland",
        "Portugal",
        "Romania",
        "Russia",
        "Serbia",
        "Slovakia",
        "Slovenia",
        "Spain",
        "Sweden",
        "Switzerland",
        "Ukraine",
        "United Kingdom",
        "Bosnia and Herzegovina",
        "Cyprus",
        "Turkey",
        "Greenland",
        "Faroe Islands",
    ],
    "North America": [
        "Canada",
        "Dominican Republic",
        "Guatemala",
        "Mexico",
        "United States",
        "Bahamas",
        "Cuba",
        "Panama",
        "Puerto Rico",
        "Bermuda",
        "Greenland",
    ],
    "Oceania": [
        "Australia",
        "New Zealand",
        "Fiji",
        "Papua New Guinea",
        "Solomon Islands",
        "Vanuatu",
    ],
    "South America": [
        "Argentina",
        "Bolivia",
        "Brazil",
        "Chile",
        "Colombia",
        "Ecuador",
        "Paraguay",
        "Peru",
        "Uruguay",
    ],
}

# Unique country names across all continents. Sorted so the order is stable
# between runs (plain set iteration order depends on string hashing) and easy
# to scan wherever the list is presented to a user.
countries = sorted(set(itertools.chain.from_iterable(countries_per_continent.values())))
# Lookup table from country name to a coordinate pair. make_versus_map reads
# index 0 as latitude and index 1 as longitude when placing a guess marker.
# NOTE(review): values are presumably an approximate centre point per country
# (Mexico's entry looks like its capital instead) - confirm the data source.
country_to_center_coords = {
"Indonesia": (-2.4833826, 117.8902853),
"Egypt": (26.2540493, 29.2675469),
"Dominican Republic": (19.0974031, -70.3028026),
"Russia": (64.6863136, 97.7453061),
"Denmark": (55.670249, 10.3333283),
"Latvia": (56.8406494, 24.7537645),
"Hong Kong": (22.350627, 114.1849161),
"Brazil": (-10.3333333, -53.2),
"Turkey": (38.9597594, 34.9249653),
"Paraguay": (-23.3165935, -58.1693445),
"Nigeria": (9.6000359, 7.9999721),
"United Kingdom": (54.7023545, -3.2765753),
"Argentina": (-34.9964963, -64.9672817),
"United Arab Emirates": (24.0002488, 53.9994829),
"Estonia": (58.7523778, 25.3319078),
"Greenland": (69.6354163, -42.1736914),
"Canada": (61.0666922, -107.991707),
"Andorra": (42.5407167, 1.5732033),
"Czechia": (49.7439047, 15.3381061),
"Australia": (-24.7761086, 134.755),
"Azerbaijan": (40.3936294, 47.7872508),
"Cambodia": (12.5433216, 104.8144914),
"Peru": (-6.8699697, -75.0458515),
"Slovakia": (48.7411522, 19.4528646),
"Réunion": (-21.130737949999997, 55.536480112992315),
"France": (46.603354, 1.8883335),
"Israel": (30.8124247, 34.8594762),
"China": (35.000074, 104.999927),
"Ecuador": (-1.3397668, -79.3666965),
"Poland": (52.215933, 19.134422),
"Switzerland": (46.7985624, 8.2319736),
"Singapore": (1.357107, 103.8194992),
"Kenya": (1.4419683, 38.4313975),
"Bhutan": (27.549511, 90.5119273),
"Laos": (20.0171109, 103.378253),
"Vietnam": (15.9266657, 107.9650855),
"Puerto Rico": (18.2247706, -66.4858295),
"Germany": (51.1638175, 10.4478313),
"Tanzania": (-6.5247123, 35.7878438),
"Colombia": (4.099917, -72.9088133),
"Italy": (42.6384261, 12.674297),
"Bahamas": (24.7736546, -78.0000547),
"Panama": (8.559559, -81.1308434),
"Bulgaria": (42.6073975, 25.4856617),
"Solomon Islands": (-8.7053941, 159.1070693851845),
"Afghanistan": (33.7680065, 66.2385139),
"Tajikistan": (38.6281733, 70.8156541),
"Portugal": (39.6621648, -8.1353519),
"Tunisia": (36.8002068, 10.1857757),
"Bolivia": (-17.0568696, -64.9912286),
"Malaysia": (4.5693754, 102.2656823),
"Lithuania": (55.3500003, 23.7499997),
"Sweden": (59.6749712, 14.5208584),
"Belgium": (50.6402809, 4.6667145),
"Libya": (26.8234472, 18.1236723),
"Guatemala": (15.5855545, -90.345759),
"India": (22.3511148, 78.6677428),
"Sri Lanka": (7.5554942, 80.7137847),
"New Zealand": (-41.5000831, 172.8344077),
"Iceland": (64.9841821, -18.1059013),
"Somalia": (8.3676771, 49.083416),
"Croatia": (45.3658443, 15.6575209),
"Bosnia and Herzegovina": (44.3053476, 17.5961467),
"Greece": (38.9953683, 21.9877132),
"Rwanda": (-1.9646631, 30.0644358),
"Hungary": (47.1817585, 19.5060937),
"Eswatini": (-26.5624806, 31.3991317),
"Kyrgyzstan": (41.5089324, 74.724091),
"Bangladesh": (23.6943117, 90.344352),
"Morocco": (28.3347722, -10.371337908392647),
"Finland": (63.2467777, 25.9209164),
"Luxembourg": (49.6112768, 6.129799),
"North Macedonia": (41.6171214, 21.7168387),
"Uruguay": (-32.8755548, -56.0201525),
"Chile": (-31.7613365, -71.3187697),
"Spain": (39.3260685, -4.8379791),
"South Korea": (36.638392, 127.6961188),
"Botswana": (-23.1681782, 24.5928742),
"Uganda": (1.5333554, 32.2166578),
"Papua New Guinea": (-5.6816069, 144.2489081),
"Mali": (16.3700359, -2.2900239),
"Philippines": (12.7503486, 122.7312101),
"Norway": (64.5731537, 11.52803643954819),
"Thailand": (14.8971921, 100.83273),
"Mongolia": (46.8651082, 103.8347844),
"Japan": (36.5748441, 139.2394179),
"Montenegro": (42.7044223, 19.3957785),
"Austria": (47.59397, 14.12456),
"Taiwan": (23.6978, 120.9605),
"Netherlands": (52.2434979, 5.6343227),
"Ukraine": (49.4871968, 31.2718321),
"Fiji": (-18.1239696, 179.0122737),
"Ghana": (8.0300284, -1.0800271),
"Cuba": (23.0131338, -80.8328748),
"Nepal": (28.3780464, 83.9999901),
"Faroe Islands": (62.0448724, -7.0322972),
"Slovenia": (46.1199444, 14.8153333),
"Cyprus": (34.9174159, 32.889902651331866),
"Serbia": (44.024322850000004, 21.07657433209902),
"Madagascar": (-18.9249604, 46.4416422),
"Pakistan": (30.3308401, 71.247499),
"Syria": (34.6401861, 39.0494106),
"Iran": (32.6475314, 54.5643516),
"Ireland": (52.865196, -7.9794599),
"South Africa": (-28.8166236, 24.991639),
"Albania": (41.1529058, 20.1605717),
"Lesotho": (-29.6039267, 28.3350193),
"Romania": (45.9852129, 24.6859225),
"Palestine": (31.947351, 35.227163),
"Vanuatu": (-16.5255069, 168.1069154),
"Mexico": (19.4326296, -99.1331785),
"Jordan": (31.279862, 37.1297454),
"Djibouti": (11.8145966, 42.8453061),
"Senegal": (14.4750607, -14.4529612),
"Bermuda": (32.3040273, -64.7563086),
"United States": (39.7837304, -100.445882),
}
def _rank_labels(model, labels, image):
    """Score ``image`` against one "A photo from <label>." prompt per label.

    Returns a ``(best_label, {label: probability})`` pair, where the
    probabilities are the softmax over the model's image-to-text logits.
    """
    inputs = processor(
        text=[f"A photo from {geo}." for geo in labels],
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=-1)
    best_index = probs.argmax().cpu().item()
    return labels[best_index], dict(zip(labels, probs.tolist()[0]))


def predict(input_img):
    """Run the two-stage geolocation prediction for one image.

    The continent model picks a continent first; the country model then
    chooses only among the countries listed for that continent.

    Args:
        input_img: The image to classify (whatever the processor accepts,
            e.g. a PIL image or a numpy array).

    Returns:
        A 4-tuple ``(continent_probs, country_probs, metadata_block,
        metadata_map)``. The two dicts map label -> probability. When the
        image's pixel hash matches a known example, ``metadata_block`` is a
        visible accordion describing the ground truth and ``metadata_map`` a
        figure from ``make_versus_map``; otherwise the accordion is hidden
        and the map is ``None``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        # Surface a readable message in the UI instead of an opaque
        # processor failure when "Predict" is clicked without an image.
        raise gr.Error("Please provide an image first.")

    model_continent, continent_probs = _rank_labels(
        continent_model, continents, input_img
    )
    candidate_countries = countries_per_continent[model_continent]
    model_country, country_probs = _rank_labels(
        country_model, candidate_countries, input_img
    )

    # Example images are recognised by the SHA-1 of their raw pixel buffer.
    image_digest = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
    truth = EXAMPLE_METADATA.get(image_digest)
    if truth is None:
        return continent_probs, country_probs, gr.Accordion(visible=False), None

    continent_ok = model_continent == truth["continent"]
    country_ok = model_country == truth["country"]
    if continent_ok and country_ok:
        model_result = "The AI 🤖 correctly guessed continent and country ✅ ✅."
    elif continent_ok:
        model_result = "The AI 🤖 only guessed the correct continent ❌ ✅."
    elif country_ok:
        model_result = "The AI 🤖 only guessed the correct country ✅ ❌."
    else:
        model_result = "The AI 🤖 failed to guess country and continent ❌ ❌."

    metadata_block = gr.Accordion(
        visible=True,
        label=f"This photo was taken in {truth['country']}, {truth['continent']}.\n{model_result}",
    )
    metadata_map = make_versus_map(None, model_country, truth)
    return continent_probs, country_probs, metadata_block, metadata_map
def _marker_trace(lat, lon, hover_text, color, legend_name):
    """Build a single-point map marker trace with a hover label and legend entry."""
    return go.Scattermapbox(
        lat=[lat],
        lon=[lon],
        text=[hover_text],
        mode="markers",
        hoverinfo="text",
        marker=dict(size=14, color=color),
        showlegend=True,
        name=legend_name,
    )


def make_versus_map(human_country, model_country, versus_state):
    """Plot the true photo location together with the human and AI guesses.

    Args:
        human_country: Country guessed by the human, or a falsy value when
            there is no human guess (then no human marker is drawn).
        model_country: Country guessed by the model; must be a key of
            ``country_to_center_coords``.
        versus_state: Mapping with the ground truth under the keys
            ``"lat"``, ``"lon"``, ``"country"`` and ``"continent"``.

    Returns:
        A plotly figure centred on the photo location.

    Raises:
        KeyError: If a guessed country has no entry in
            ``country_to_center_coords``.
    """
    if human_country:
        human_coordinates = country_to_center_coords[human_country]
    else:
        human_coordinates = (None, None)
    model_coordinates = country_to_center_coords[model_country]

    # Ground-truth coordinates may arrive as strings (they are parsed out of
    # file names elsewhere); convert once so the marker and the map centre
    # both receive numbers.
    photo_lat = float(versus_state["lat"])
    photo_lon = float(versus_state["lon"])

    fig = go.Figure()
    fig.add_trace(
        _marker_trace(
            photo_lat,
            photo_lon,
            f"📷 Photo taken in {versus_state['country']}, {versus_state['continent']}",
            "#0C5DA5",
            "📷 Photo Location",
        )
    )
    if human_country == model_country:
        # Both guesses land on the same country, so one shared marker suffices.
        fig.add_trace(
            _marker_trace(
                model_coordinates[0],
                model_coordinates[1],
                f"🧑 🤖 Human & AI guess {human_country}",
                "#FF9500",
                "🧑 🤖 Human & AI Guess",
            )
        )
    else:
        if human_country:
            fig.add_trace(
                _marker_trace(
                    human_coordinates[0],
                    human_coordinates[1],
                    f"🧑 Human guesses {human_country}",
                    "#FF9500",
                    "🧑 Human Guess",
                )
            )
        fig.add_trace(
            _marker_trace(
                model_coordinates[0],
                model_coordinates[1],
                f"🤖 AI guesses {model_country}",
                "#474747",
                "🤖 AI Guess",
            )
        )
    fig.update_layout(
        mapbox=dict(
            style="carto-positron",
            center=dict(lat=photo_lat, lon=photo_lon),
            zoom=2,
        ),
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
    )
    return fig
def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
    """Score one versus-mode round for both the human and the model.

    A correct country is worth 2 points and a correct continent 1 point.
    The running totals in ``versus_state["score"]`` are updated in place.

    Returns:
        A 5-tuple: the markdown summary of the round, the model's continent
        probabilities, its country probabilities, the comparison map figure
        and the updated ``versus_state``.
    """

    def grade(country_guess, continent_guess):
        # Returns (points, country mark, continent mark) for one player.
        country_hit = country_guess == versus_state["country"]
        continent_hit = continent_guess == versus_state["continent"]
        points = (2 if country_hit else 0) + (1 if continent_hit else 0)
        return (
            points,
            "✅" if country_hit else "❌",
            "✅" if continent_hit else "❌",
        )

    # Human side first; the score is booked before the model is queried.
    human_points, country_result, continent_result = grade(
        human_country, human_continent
    )
    human_result = f"The photo is from **{versus_state['country']}** {country_result} in **{versus_state['continent']}** {continent_result}"
    if human_points > 0:
        human_score_update = f"+{human_points} points"
    else:
        human_score_update = "0 Points..."
    versus_state["score"]["HUMAN"] += human_points

    # Model side: its guess is the highest-probability label of each stage.
    continent_probs, country_probs, _, _ = predict(input_img)
    model_country = max(country_probs, key=country_probs.get)
    model_continent = max(continent_probs, key=continent_probs.get)
    model_points, model_country_result, model_continent_result = grade(
        model_country, model_continent
    )
    if model_points > 0:
        model_score_update = f"+{model_points} points"
    else:
        model_score_update = "0 Points... The model was completely wrong, it seems the world is not doomed yet."
    versus_state["score"]["AI"] += model_points

    versus_figure = make_versus_map(human_country, model_country, versus_state)
    summary = f"""
## {human_result}
### The AI 🤖 thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
🧑 {human_score_update}
🤖 {model_score_update}
### Score 🧑 {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} 🤖
"""
    return summary, continent_probs, country_probs, versus_figure, versus_state
def get_example_images(dir):
    """Recursively collect the paths of all image files below ``dir``.

    A file counts as an image when its name ends, case-insensitively, with
    ``.jpg``, ``.jpeg`` or ``.png``. Paths are returned in ``os.walk`` order.
    """
    image_extensions = (".jpg", ".jpeg", ".png")
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(dir)
        for name in names
        if name.lower().endswith(image_extensions)
    ]
def next_versus_image(versus_state):
    """Draw a random versus image and load its ground truth into the state.

    The ground truth is encoded in the file name as
    ``<continent>_<country>_<lat>_<lon>_...``; the parsed fields (kept as
    strings) and the chosen path are written into ``versus_state`` in place.

    Args:
        versus_state: Mapping whose ``"images"`` entry lists candidate paths.

    Returns:
        ``(image_path, versus_state, None, None)``; the two ``None`` values
        clear the continent and country selections in the UI.

    Raises:
        ValueError: If there are no candidate images, or the chosen file
            name has fewer than four ``_``-separated fields.
    """
    candidates = versus_state["images"]
    if not candidates:
        raise ValueError("versus_state['images'] is empty; no image to draw")
    versus_image = random.choice(candidates)

    # Paths are built with os.path.join, so take the file name with os.path
    # rather than assuming "/" is the separator.
    continent, country, lat, lon = os.path.basename(versus_image).split("_")[:4]
    versus_state["continent"] = continent
    versus_state["country"] = country
    versus_state["lat"] = lat
    versus_state["lon"] = lon
    versus_state["image"] = versus_image
    return versus_image, versus_state, None, None
# Ground truth for the bundled example images, keyed by the SHA-1 of each
# image's raw pixel buffer so an example can be recognised again in predict().
# File names follow "<continent>_<country>_<lat>_<lon>_...".
example_images = get_example_images("kerger-test-images")
EXAMPLE_METADATA = {}
for img_path in example_images:
    # Close each file once its pixels are hashed instead of leaving every
    # handle open for the lifetime of the process.
    with Image.open(img_path) as _example_img:
        _pixel_digest = hashlib.sha1(
            np.asarray(_example_img).data.tobytes()
        ).hexdigest()
    # Paths come from os.path.join, so take the file name with os.path
    # rather than assuming "/" is the separator.
    _name_fields = os.path.basename(img_path).split("_")
    EXAMPLE_METADATA[_pixel_digest] = {
        "continent": _name_fields[0],
        "country": _name_fields[1],
        "lat": _name_fields[2],
        "lon": _name_fields[3],
    }
def set_up_intial_state():
    """Build the starting versus-mode state.

    The state begins on a fixed first image whose ground truth is parsed from
    its file name (``<continent>_<country>_<lat>_<lon>_...``), with both
    scores at zero and the pool of playable images read from
    ``versus_images``.
    """
    first_image = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
    name_fields = first_image.split("/")[-1].split("_")
    return {
        "image": first_image,
        "continent": name_fields[0],
        "country": name_fields[1],
        "lat": name_fields[2],
        "lon": name_fields[3],
        "score": {"HUMAN": 0, "AI": 0},
        "images": get_example_images("versus_images"),
    }
# Gradio UI: an intro header, the list of supported countries, and two tabs -
# single-image prediction and the human-vs-AI guessing game.
demo = gr.Blocks(title="Thesis Demo")
with demo:
    gr.HTML(
        """
<h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
<h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
<p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example images the ground truth is available and will be displayed together with the model predictions.</p>
<p>In the <b>"Versus Mode"</b> tab you can play against the AI, guessing the country and continent where images were taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/" target="_blank" rel="noopener noreferrer"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838" target="_blank" rel="noopener noreferrer"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?</p>
<div style="font-style: italic; font-size: smaller;">Note that inference in this publicly hosted version is very slow due to the limited and shared hardware. This demo runs on the <a href="https://huggingface.co/pricing#spaces" style="color: inherit;" target="_blank" rel="noopener noreferrer">Hugging Face free tier</a> without GPU acceleration. Running the demo with a GPU allows for inference times between 0.5-2 seconds per image.</div>
"""
    )
    with gr.Accordion(
        # Counts are derived from the label tables so the text cannot drift
        # out of date when countries are added or removed.
        label=f"The demo currently encompasses {len(countries)} countries from {len(continents)} continents 🌍",
        open=False,
    ):
        gr.Code(
            json.dumps(countries_per_continent, indent=2, ensure_ascii=False),
            label="countries_per_continent.json",
            language="json",
            interactive=False,
        )

    with gr.Tab("Image Geolocation Demo"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(
                    label="Image",
                    type="pil",
                    sources=["upload", "clipboard"],
                    show_fullscreen_button=True,
                )
                predict_btn = gr.Button("Predict")
                example_images = get_example_images("kerger-test-images")
                gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
            with gr.Column():
                # Hidden until predict() recognises one of the example images.
                with gr.Accordion(visible=False) as metadata_block:
                    demo_map = gr.Plot(label="Locations")
                with gr.Group():
                    demo_continents_label = gr.Label(label="Continents")
                    demo_country_label = gr.Label(
                        num_top_classes=5, label="Top countries"
                    )
        predict_btn.click(
            predict,
            inputs=image,
            outputs=[
                demo_continents_label,
                demo_country_label,
                metadata_block,
                demo_map,
            ],
        )

    with gr.Tab("Versus Mode"):
        versus_state = gr.State(value=set_up_intial_state())
        with gr.Row():
            with gr.Column():
                versus_image = gr.Image(
                    versus_state.value["image"],
                    interactive=False,
                    show_download_button=False,
                    show_share_button=False,
                    show_fullscreen_button=True,
                )
                continent_selection = gr.Radio(
                    continents,
                    label="Continents",
                    info="Where was this image taken? (1 Point)",
                )
                country_selection = gr.Dropdown(
                    countries,
                    label="Countries",
                    info="Can you guess the exact country? (2 Points)",
                )
                with gr.Row():
                    next_img_btn = gr.Button("Try new image")
                    versus_btn = gr.Button("Submit guess")
            with gr.Column():
                versus_output = gr.Markdown()
                versus_map = gr.Plot(label="Locations")
                with gr.Accordion("Full Model Output", open=False):
                    with gr.Group():
                        versus_continents_label = gr.Label(label="Continents")
                        versus_country_label = gr.Label(
                            num_top_classes=5, label="Top countries"
                        )
        # Loading a new image also clears both guess inputs
        # (next_versus_image returns None for each of them).
        next_img_btn.click(
            next_versus_image,
            inputs=[versus_state],
            outputs=[
                versus_image,
                versus_state,
                continent_selection,
                country_selection,
            ],
        )
        versus_btn.click(
            versus_mode_inputs,
            inputs=[
                versus_image,
                continent_selection,
                country_selection,
                versus_state,
            ],
            outputs=[
                versus_output,
                versus_continents_label,
                versus_country_label,
                versus_map,
                versus_state,
            ],
        )

if __name__ == "__main__":
    demo.launch(show_api=False)