Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import numpy as np | |
| import gradio as gr | |
| from transformers import CLIPProcessor, CLIPModel | |
| import torch | |
| import itertools | |
| import os | |
| import plotly.graph_objects as go | |
# Pick the compute device once; everything below moves models/inputs onto it.
CUDA_AVAILABLE = torch.cuda.is_available()
print(f"CUDA={CUDA_AVAILABLE}")
device = "cuda" if CUDA_AVAILABLE else "cpu"
print(f"count={torch.cuda.device_count()}")
# Only ask for the device name when CUDA exists: current_device() /
# get_device_name() raise on CPU-only hosts and would crash the app at import.
if CUDA_AVAILABLE:
    print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")

# Two fine-tuned CLIP checkpoints: one ranks continents, the other countries.
continent_model = CLIPModel.from_pretrained("model-checkpoints/continent")
country_model = CLIPModel.from_pretrained("model-checkpoints/country")
# Shared text/image preprocessor from the base CLIP model.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
continent_model = continent_model.to(device)
country_model = country_model.to(device)
# Continent labels scored by the continent model (order = output order).
continents = ["Africa", "Asia", "Europe",
              "North America", "Oceania", "South America"]

# Candidate countries per continent, scored by the country model once a
# continent has been chosen. Transcontinental countries (e.g. Russia, Turkey,
# Georgia) deliberately appear under both Asia and Europe.
countries_per_continent = {
    "Africa": [
        "Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
        "Central African Republic", "Congo", "Democratic Republic of the Congo",
        "Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
        "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
        "Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
        "Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
        "Sierra Leone", "Somalia", "South Africa", "Sudan", "Tanzania", "Togo",
        "Tunisia", "Uganda", "Zambia", "Zimbabwe",
    ],
    "Asia": [
        "Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
        "Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
        "Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
        "Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
        "Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
        "Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
        "Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen",
    ],
    "Europe": [
        "Albania", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
        "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
        "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
        "Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
        "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
        "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
        "Turkey", "Ukraine", "United Kingdom",
    ],
    "North America": [
        "Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
        "Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
        "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
        "Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States",
    ],
    "Oceania": [
        "Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
        "Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu",
    ],
    "South America": [
        "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
        "Peru", "Suriname", "Uruguay", "Venezuela",
    ],
}

# De-duplicated list of every country (used for the guess dropdown). Sorted so
# the order is stable: a plain list(set(...)) follows string-hash order, which
# is randomized per process and reshuffles the dropdown on every restart.
countries = sorted(set(itertools.chain.from_iterable(
    countries_per_continent.values())))
# First image shown in Versus Mode. Its file name encodes the ground truth as
# "<continent>_<country>_<lat>_<lon>_<source>.jpg".
INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"

# Initial value of the Versus Mode gr.State: ground truth for the current
# image (lat/lon kept as the raw strings from the file name), running scores
# for the human and the AI, and the index of the current image.
INITAL_VERSUS_STATE = {
    "image": INTIAL_VERSUS_IMAGE,
    # zip() pairs the first four underscore-separated fields with their keys
    # and ignores the trailing "<source>.jpg" field.
    **dict(zip(
        ("continent", "country", "lat", "lon"),
        INTIAL_VERSUS_IMAGE.split("/")[-1].split("_"),
    )),
    "score": {"HUMAN": 0, "AI": 0},
    "idx": 0,
}
def _zero_shot_probs(model, labels, input_img):
    """Return CLIP softmax probabilities of ``input_img`` over ``labels``.

    Each label becomes the prompt "A photo from {label}."; the returned list of
    floats is aligned with ``labels``.
    """
    # Keep the f-string on one line: a replacement field split across lines is
    # a SyntaxError before Python 3.12.
    prompts = [f"A photo from {label}." for label in labels]
    inputs = processor(text=prompts, images=input_img,
                       return_tensors="pt", padding=True)
    inputs = inputs.to(device)
    with torch.no_grad():
        logits_per_image = model(**inputs).logits_per_image
    # Take the row for the first (presumably only) image.
    return logits_per_image.softmax(dim=-1).tolist()[0]


def predict(input_img):
    """Predict the continent, then the country within it, for one image.

    Two-stage prediction: the continent model ranks ``continents``; the
    country model then ranks only the countries of the top continent.

    Args:
        input_img: image accepted by the CLIP processor (PIL or numpy, as
            supplied by the Gradio Image components).

    Returns:
        (continent_probs, country_probs): dicts mapping label -> probability.

    Raises:
        gr.Error: if no image was provided (e.g. Predict clicked on an empty
            upload box).
    """
    if input_img is None:
        raise gr.Error("Please provide an image first.")

    continent_probs = dict(zip(
        continents, _zero_shot_probs(continent_model, continents, input_img)))
    # max() returns the first maximal key, matching torch.argmax tie-breaking.
    best_continent = max(continent_probs, key=continent_probs.get)

    candidates = countries_per_continent[best_continent]
    country_probs = dict(zip(
        candidates, _zero_shot_probs(country_model, candidates, input_img)))
    return continent_probs, country_probs
def _add_country_marker(fig, country, label):
    """Overlay the text ``label`` on ``country`` (matched by plotly country name)."""
    fig.add_trace(go.Scattergeo(
        locations=[country],
        locationmode="country names",
        text=[label],
        mode="text",
        hoverinfo="location",
        showlegend=False,
    ))


def make_versus_map(human_country, model_country, versus_state):
    """Build a world map comparing the true location with both guesses.

    Args:
        human_country: country name guessed by the player.
        model_country: country name predicted by the model.
        versus_state: Versus Mode state dict; reads "lat", "lon", "country"
            and "continent" of the current image.

    Returns:
        A plotly ``go.Figure`` with a marker at the photo's coordinates and
        a text marker on each guessed country (one combined marker when the
        guesses agree).
    """
    # NOTE(review): the emoji literals below look mis-encoded (UTF-8 emoji
    # decoded with a Greek codec, e.g. a camera / person / robot) — confirm
    # against the deployed app before changing them.
    fig = go.Figure()
    fig.add_trace(go.Scattergeo(
        # Coordinates come from the file name as strings; convert explicitly.
        lon=[float(versus_state["lon"])],
        lat=[float(versus_state["lat"])],
        text=["π·"],
        mode="text+markers",
        hoverinfo="text",
        # One-line f-string: splitting a replacement field across lines is a
        # SyntaxError before Python 3.12.
        hovertext=f"Photo taken in {versus_state['country']}, {versus_state['continent']}",
        marker=dict(size=14, color="#00B945"),
        showlegend=False,
    ))
    if human_country == model_country:
        _add_country_marker(fig, human_country, "π§π€")
    else:
        _add_country_marker(fig, human_country, "π§")
        _add_country_marker(fig, model_country, "π€")
    fig.update_geos(
        visible=True, resolution=110,
        showcountries=True, countrycolor="grey", fitbounds="locations", projection_type="natural earth",
    )
    return fig
def _grade_guess(country, continent, versus_state):
    """Grade one (country, continent) guess against ``versus_state``'s truth.

    Returns ``(points, country_mark, continent_mark)``: 2 points for the right
    country plus 1 for the right continent; the marks are the strings shown
    next to each answer.
    """
    # NOTE(review): these mark strings look like UTF-8 emoji (probably a
    # check mark / cross) mis-decoded with a Greek codec — confirm before
    # changing them.
    points = 0
    if country == versus_state["country"]:
        country_mark = "β "
        points += 2
    else:
        country_mark = "β"
    if continent == versus_state["continent"]:
        continent_mark = "β "
        points += 1
    else:
        continent_mark = "β"
    return points, country_mark, continent_mark


def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
    """Score one Versus Mode round: the player's guess against the model's.

    Args:
        input_img: the current versus image.
        human_continent: continent selected by the player (may be None).
        human_country: country selected by the player (may be None).
        versus_state: Versus Mode state dict; its "score" entries are
            incremented in place.

    Returns:
        (markdown_summary, continent_probs, country_probs, map_figure,
        versus_state), matching the outputs wired up in the UI.
    """
    human_points, country_result, continent_result = _grade_guess(
        human_country, human_continent, versus_state)

    continent_probs, country_probs = predict(input_img)
    model_country = max(country_probs, key=country_probs.get)
    model_continent = max(continent_probs, key=continent_probs.get)
    model_points, model_country_result, model_continent_result = _grade_guess(
        model_country, model_continent, versus_state)

    # f-strings kept on single lines: a replacement field split across lines
    # is a SyntaxError before Python 3.12.
    human_result = (f"The photo is from **{versus_state['country']}** {country_result} "
                    f"in **{versus_state['continent']}** {continent_result}")
    human_score_update = (f"+{human_points} points" if human_points > 0
                          else "Wrong guess, try a new image.")
    model_score_update = (f"+{model_points} points" if model_points > 0
                          else "The model was wrong, seems the world is not yet doomed.")

    # Update both scores only after prediction succeeded, so a failing round
    # cannot leave the human's points applied without the AI's.
    versus_state["score"]["HUMAN"] += human_points
    versus_state["score"]["AI"] += model_points

    fig = make_versus_map(human_country, model_country, versus_state)
    return f"""
## {human_result}
### The AI π€ thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
π§ {human_score_update}
π€ {model_score_update}
### Score π§ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} π€
""", continent_probs, country_probs, fig, versus_state
def get_example_images(dir, extensions=(".jpg", ".jpeg", ".png")):
    """Return the paths of all image files under ``dir``, recursively.

    Args:
        dir: directory to search. A missing directory yields an empty list
            (os.walk reports nothing for it).
        extensions: file suffixes to accept; matched case-insensitively.

    Returns:
        Sorted list of file paths (each joined onto ``dir``). Sorting matters
        because os.walk/os.scandir yield entries in arbitrary order, and
        next_versus_image indexes into this list across calls.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        os.path.join(root, name)
        for root, _subdirs, files in os.walk(dir)
        for name in files
        if name.lower().endswith(suffixes)
    )
def next_versus_image(versus_state):
    """Advance Versus Mode to the next image in ``versus_images``.

    Ground truth is parsed from the file name
    "<continent>_<country>_<lat>_<lon>_...", and ``versus_state`` is updated
    in place.

    Returns:
        (image_path, versus_state, None, None). The two ``None`` values clear
        the continent and country selectors wired as outputs in the UI.

    Raises:
        gr.Error: if ``versus_images`` contains no images.
    """
    images = get_example_images("versus_images")
    if not images:
        raise gr.Error("No images found in 'versus_images'.")
    # Wrap around so idx always stays in [0, len(images)); a `> len(images)`
    # check would let idx == len(images) through and raise IndexError.
    versus_state["idx"] = (versus_state["idx"] + 1) % len(images)
    versus_image = images[versus_state["idx"]]
    # basename() also handles the "\" separators os.walk produces on Windows.
    continent, country, lat, lon = os.path.basename(versus_image).split("_")[:4]
    versus_state["continent"] = continent
    versus_state["country"] = country
    versus_state["lat"] = lat
    versus_state["lon"] = lon
    versus_state["image"] = versus_image
    return versus_image, versus_state, None, None
# Gradio UI: a free-form prediction tab and a human-vs-AI "Versus Mode" tab.
demo = gr.Blocks()
with demo:
    with gr.Tab("Image Geolocation Demo"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Image", type="pil",
                                 sources=["upload", "clipboard"])
                predict_btn = gr.Button("Predict")
                example_images = get_example_images("kerger-test-images")
                # Only render the gallery when the examples folder has images.
                if example_images:
                    gr.Examples(examples=example_images,
                                inputs=image, examples_per_page=24)
            with gr.Column():
                continents_label = gr.Label(label="Continents")
                country_label = gr.Label(
                    num_top_classes=5, label="Top countries")
        predict_btn.click(predict, inputs=image, outputs=[
            continents_label, country_label])

    with gr.Tab("Versus Mode"):
        versus_state = gr.State(value=INITAL_VERSUS_STATE)
        with gr.Row():
            with gr.Column():
                versus_image = gr.Image(
                    INITAL_VERSUS_STATE["image"], interactive=False)
                continent_selection = gr.Radio(
                    continents, label="Continents", info="Where was this image taken? (1 Point)")
                # No trailing comma: it used to wrap the component in a
                # tuple, which forced `country_selection[0]` everywhere.
                country_selection = gr.Dropdown(
                    countries, label="Countries", info="Can you guess the exact country? (2 Points)")
                with gr.Row():
                    next_img_btn = gr.Button("Try new image")
                    versus_btn = gr.Button("Submit guess")
            with gr.Column():
                versus_output = gr.Markdown()
                # Named versus_map (not `map`) to avoid shadowing the builtin
                # at module level.
                versus_map = gr.Plot(label="Locations")
                with gr.Accordion("Full Model Output", open=False):
                    continents_label = gr.Label(label="Continents")
                    country_label = gr.Label(
                        num_top_classes=5, label="Top countries")
        next_img_btn.click(
            next_versus_image,
            inputs=[versus_state],
            outputs=[versus_image, versus_state, continent_selection, country_selection],
        )
        versus_btn.click(
            versus_mode_inputs,
            inputs=[versus_image, continent_selection, country_selection, versus_state],
            outputs=[versus_output, continents_label, country_label, versus_map, versus_state],
        )

if __name__ == "__main__":
    demo.launch()