Spaces:
Runtime error
Runtime error
import numpy as np | |
import gradio as gr | |
from transformers import CLIPProcessor, CLIPModel | |
import torch | |
import itertools | |
import os | |
# Tokenizer + image preprocessing come from the public base CLIP checkpoint,
# while the weights come from the local "model-checkpoint" directory.
# NOTE(review): assumes "model-checkpoint" was fine-tuned from
# openai/clip-vit-large-patch14-336, so the processor matches it — confirm.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
model = CLIPModel.from_pretrained("model-checkpoint")
# Candidate countries per continent, used for the second (country-level)
# prediction pass. Transcontinental countries (Armenia, Azerbaijan, Cyprus,
# Georgia, Kazakhstan, Russia, Turkey) appear under both Asia and Europe, so
# they stay reachable whichever of the two continents wins the first pass.
countries_per_continent = {
    "Africa": [
        "Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
        "Central African Republic", "Chad", "Comoros", "Congo", "Democratic Republic of the Congo",
        "Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
        "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
        "Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
        "Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
        "Sierra Leone", "Somalia", "South Africa", "South Sudan", "Sudan", "Tanzania", "Togo",
        "Tunisia", "Uganda", "Zambia", "Zimbabwe",
    ],
    "Asia": [
        "Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
        "Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
        "Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
        "Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
        "Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
        "Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
        "Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen",
    ],
    "Europe": [
        "Albania", "Andorra", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
        "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
        "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
        "Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
        "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
        "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
        "Turkey", "Ukraine", "United Kingdom", "Vatican City",
    ],
    "North America": [
        "Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
        "Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
        "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
        "Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States",
    ],
    "Oceania": [
        "Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
        "Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu",
    ],
    "South America": [
        "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
        "Peru", "Suriname", "Uruguay", "Venezuela",
    ],
}

# Continent labels for the first prediction pass. Derived from the mapping
# (dict keys keep insertion order) so the two lists can never drift apart.
continents = list(countries_per_continent)

# Every distinct country, alphabetically. A bare list(set(...)) would change
# order between launches because str hashes are randomized per process,
# reshuffling the Versus Mode dropdown each time the app starts.
countries = sorted(set(itertools.chain.from_iterable(countries_per_continent.values())))
# Image shown in the "Versus Mode" tab when the app first loads.
VERSUS_IMAGE = "versus_images/254762082_1a1b6d27d1_121_79267922@N00.jpg"

# Ground-truth location reported back to the player after a guess.
# NOTE(review): both fields are still the placeholder "test"; nothing in this
# file assigns the real continent/country for VERSUS_IMAGE yet.
VERSUS_GT = dict(continent="test", country="test")
def _label_probs(image, labels):
    """Return CLIP's softmax probabilities of *image* over one prompt per label.

    Each label is wrapped in the prompt "A photo from {label}."; the result is
    a list of Python floats in the same order as *labels*.
    """
    prompts = [f"A photo from {label}." for label in labels]
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # Only one image is passed, so row 0 of logits_per_image holds its scores
    # against every prompt.
    return outputs.logits_per_image.softmax(dim=-1)[0].tolist()


def predict(input_img):
    """Predict where *input_img* was taken: first the continent, then the country.

    The most probable continent restricts the candidate countries scored in
    the second pass.

    Args:
        input_img: The image value delivered by a gr.Image component, or None
            when the user has not provided an image.

    Returns:
        (continent_probs, country_probs): dicts mapping label -> probability,
        in the form gr.Label consumes.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_img is None:
        # gr.Image hands over None when it is empty; stop here with a
        # user-visible message instead of passing None to the processor.
        raise gr.Error("Please upload an image first.")

    continent_probs = dict(zip(continents, _label_probs(input_img, continents)))
    print(continent_probs)

    # Ties resolve to the earliest continent in `continents` order.
    best_continent = max(continent_probs, key=continent_probs.get)
    candidate_countries = countries_per_continent[best_continent]

    country_probs = dict(zip(candidate_countries, _label_probs(input_img, candidate_countries)))
    print(country_probs)
    return continent_probs, country_probs
def versus_mode_inputs(input_img, human_continent, human_country):
    """Handle a Versus Mode guess: run the model and report both answers.

    Args:
        input_img: The current Versus Mode image, as delivered by gr.Image.
        human_continent: The continent picked by the player (may be None if
            nothing was selected).
        human_country: The country picked by the player (may be None).

    Returns:
        (summary, continent_probs, country_probs): a text summary of the
        player's guess and the ground truth, plus the model's probabilities
        as returned by predict().
    """
    print(human_continent, human_country)
    continent_probs, country_probs = predict(input_img)
    summary = (
        f"human guessed {human_continent} {human_country}\n"
        f"ground truth {VERSUS_GT['continent']} {VERSUS_GT['country']}"
    )
    return summary, continent_probs, country_probs
def next_versus_image():
    """Return the path of the image to show after "Try new image" is clicked.

    NOTE(review): still a stub. It always returns the same file and does not
    update VERSUS_GT, so the ground truth reported after a guess stays the
    placeholder. TODO: pick a fresh image and set VERSUS_GT["continent"] /
    VERSUS_GT["country"] to its true location.
    """
    upcoming = "versus_images/780091415_c803d82672_1332_88971695@N00.jpg"
    return upcoming
def get_example_images(dir, extensions=(".jpg", ".jpeg", ".png")):
    """Recursively collect image file paths under *dir*.

    Args:
        dir: Root directory to search (name kept for caller compatibility even
            though it shadows the builtin). A missing directory yields [],
            because os.walk reports nothing for it.
        extensions: Filename suffix, or tuple of suffixes, to accept. Matching
            is case-insensitive.

    Returns:
        A list of os.path.join(root, filename) paths in a deterministic order:
        directories are visited top-down in sorted order and files within
        each directory are sorted, so the example gallery is stable between runs.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_files = []
    for root, dirs, files in os.walk(dir):
        # os.walk lists entries in arbitrary filesystem order; sorting `dirs`
        # in place also fixes the order in which subdirectories are visited.
        dirs.sort()
        image_files.extend(
            os.path.join(root, name)
            for name in sorted(files)
            if name.lower().endswith(suffixes)
        )
    return image_files
# Two-tab Gradio UI: a free-form geolocation demo and a "Versus Mode" where a
# player's guess is compared against the model.
with gr.Blocks() as demo:
    with gr.Tab("Image Geolocation Demo"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Image", type="pil", sources=["upload", "clipboard"])
                predict_btn = gr.Button("Predict")
                example_images = get_example_images("examples")
                example_images.extend(get_example_images("versus_images"))
                # Skip the gallery entirely when neither folder has images.
                if example_images:
                    gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
            with gr.Column():
                continents_label = gr.Label(label="Continents")
                country_label = gr.Label(num_top_classes=5, label="Top countries")
        predict_btn.click(predict, inputs=image, outputs=[continents_label, country_label])

    with gr.Tab("Versus Mode"):
        with gr.Row():
            with gr.Column():
                # type="pil" so predict() receives the same image type as in
                # the first tab.
                versus_image = gr.Image(VERSUS_IMAGE, type="pil", interactive=False)
                continent_selection = gr.Radio(
                    continents, label="Continents", info="Where was this image taken?"
                )
                # No trailing comma here: it previously turned this into a
                # 1-tuple, which the click handler had to unwrap with [0].
                country_selection = gr.Dropdown(
                    countries, label="Countries", info="Can you guess the exact country?"
                )
                with gr.Row():
                    next_img_btn = gr.Button("Try new image")
                    versus_btn = gr.Button("Submit guess")
            with gr.Column():
                versus_output = gr.Text()
                # Distinct names so the first tab's labels are not shadowed.
                versus_continents_label = gr.Label(label="Continents")
                versus_country_label = gr.Label(num_top_classes=5, label="Top countries")
        next_img_btn.click(next_versus_image, outputs=versus_image)
        versus_btn.click(
            versus_mode_inputs,
            inputs=[versus_image, continent_selection, country_selection],
            outputs=[versus_output, versus_continents_label, versus_country_label],
        )

if __name__ == "__main__":
    demo.launch()