"""Streamlit demo app that geotags short texts with a truncated ByT5 encoder.

The ``yachay/byt5-geotagging-es`` checkpoint classifies a text into one of
``config.n_clusters`` classes; each class is mapped to coordinates through
``config.class_to_location`` and the candidates are drawn on a pydeck map.
"""

# Import necessary libraries
import base64
import copy

import numpy as np
import pandas as pd
import pydeck as pdk
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel, T5EncoderModel


def get_base64_encoded_image(image_path):
    """Return the bytes of the file at ``image_path`` as a base64 text string."""
    with open(image_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode('utf-8')


# Number of leading T5 encoder blocks kept by ``ByT5ForTextGeotagging``;
# ``None`` keeps every block. Presumably this must match the architecture the
# checkpoint was saved with — confirm before changing it.
keep_layer_count = 6


class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for :class:`ByT5ForTextGeotagging`.

    Args:
        n_clusters: Number of output classes of the linear head.
        model_name_or_path: Checkpoint loaded with ``T5EncoderModel.from_pretrained``.
        class_to_location: Mapping from a class index (as a string) to a
            ``(lat, lon)`` pair. Defaults to an empty dict.
        **kwargs: Forwarded to :class:`PretrainedConfig`.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location or {}

    def to_diff_dict(self):
        """Return the config entries that differ from a default ``PretrainedConfig``.

        Keys missing from the default config (the custom fields above) are
        always included. NOTE(review): presumably overridden because the base
        implementation builds a no-argument instance of this class, which
        cannot work since ``n_clusters`` is required — confirm.
        """
        config_dict = self.to_dict()
        default_config_dict = PretrainedConfig().to_dict()
        return {
            key: value
            for key, value in config_dict.items()
            if key not in default_config_dict or value != default_config_dict[key]
        }


def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of ``model`` truncated to its first encoder blocks.

    The copy's ``encoder.block`` is replaced with a ``ModuleList`` holding the
    first ``num_layers_to_keep`` block modules of ``model`` (the original block
    objects, not copies of them).

    Args:
        model: A T5 encoder model exposing ``model.encoder.block``.
        num_layers_to_keep: Number of leading blocks to keep; raises
            ``IndexError`` if larger than the number of blocks.
    """
    old_module_list = model.encoder.block
    # Keep only the leading blocks; the remaining ones are dropped.
    new_module_list = torch.nn.ModuleList(
        [old_module_list[i] for i in range(num_layers_to_keep)]
    )
    copy_of_model = copy.deepcopy(model)
    copy_of_model.encoder.block = new_module_list
    return copy_of_model


class ByT5ForTextGeotagging(PreTrainedModel):
    """ByT5 encoder (optionally truncated) followed by a linear classification head."""

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)
        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Classify token ids into location classes.

        Args:
            input: Token-id tensor indexed as ``input[:, 0, :]``, i.e. shaped
                ``(batch, 1, seq_len)``.
            return_coordinates: If true, also return the coordinates and the
                softmax probability of the top class. This uses ``.item()``,
                so it requires a batch of size 1.

        Returns:
            ``logits`` of shape ``(batch, n_clusters)``, or
            ``(logits, coordinates, confidence)`` when ``return_coordinates``
            is true. ``coordinates`` is ``None`` if the top class has no entry
            in ``config.class_to_location``.
        """
        # Plain indexing already drops the sliced axis; the former extra
        # ``.squeeze(1)`` would also collapse a length-1 sequence/hidden axis.
        token_ids = input[:, 0, :]
        hidden_states = self.byt5(token_ids)['last_hidden_state']
        # Use the hidden state at the first position as the text representation.
        pooled = hidden_states[:, 0, :]
        logits = self.fc3(pooled)
        if not return_coordinates:
            return logits
        class_idx = torch.argmax(logits, dim=1).item()
        coordinates = self.config.class_to_location.get(str(class_idx))
        confidence = torch.max(torch.nn.functional.softmax(logits, dim=1)).item()
        return logits, coordinates, confidence


@st.cache_resource(ttl=None)
def load_model_and_tokenizer():
    """Load (once per server process, via Streamlit's resource cache) the tokenizer and model."""
    byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
    model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
    return byt5_tokenizer, model


byt5_tokenizer, model = load_model_and_tokenizer()


def _tokenize(text):
    """Tokenize ``text`` (truncated to 140 tokens) and add the axis ``forward`` slices with ``[:, 0, :]``."""
    input_ids = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    return input_ids.unsqueeze(0)


def geolocate_text_byt5_multiclass(text):
    """Return candidate locations for ``text``, most probable first.

    Classes are visited in descending probability and kept while the running
    probability total stays at or below 0.5; the class that pushes the total
    past 0.5 is not kept. Classes without an entry in
    ``model.config.class_to_location`` are skipped. If nothing was kept, the
    single most probable class is returned instead (when it has coordinates).

    Returns:
        List of ``(class_idx, prob, coordinates)`` tuples. It is empty only when
        the most probable class has no coordinates.
    """
    with torch.no_grad():  # inference only: skip building autograd graphs
        logits = model(_tokenize(text))
        probas = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]

    class_to_location = model.config.class_to_location
    sorted_indices = np.argsort(-probas)

    results = []
    cumulative_prob = 0.0
    for class_idx in sorted_indices:
        prob = probas[class_idx]
        cumulative_prob += prob
        if cumulative_prob > 0.5:
            break
        coordinates = class_to_location.get(str(class_idx))
        if coordinates:
            results.append((class_idx, prob, coordinates))

    # Fall back to the single most probable class if nothing fit under 0.5.
    if not results:
        class_idx = sorted_indices[0]
        coordinates = class_to_location.get(str(class_idx))
        if coordinates:
            results.append((class_idx, probas[class_idx], coordinates))
    return results


def geolocate_text_byt5(text):
    """Return ``(lat, lon, confidence)`` for the most probable class of ``text``.

    ``lat`` and ``lon`` are ``None`` when that class has no coordinates in
    ``model.config.class_to_location``.
    """
    with torch.no_grad():
        _, coordinates, confidence = model(_tokenize(text), return_coordinates=True)
    if coordinates is None:
        return None, None, confidence
    lat, lon = coordinates
    return lat, lon, confidence


# Persist text across Streamlit reruns: ``text_input`` is the text to geotag,
# ``text_modified`` holds the current contents of the text box.
if 'text_input' not in st.session_state:
    st.session_state.text_input = ""
if 'text_modified' not in st.session_state:
    st.session_state.text_modified = ""


def set_example_text(example_text):
    """Make ``example_text`` the text to geotag (stored in ``st.session_state.text_input``)."""
    st.session_state.text_input = example_text


example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "Una semana de conciertos, fuegos artificiales y txosnas en Aste Nagusia. ¡Esta ciudad sabe cómo celebrar! #Fiestas #AsteNagusia",
    "Viendo las destrezas de los gauchos en la Semana Criolla. Los asados también son para morirse. #SemanaCriolla #Tradición",
    "Bailando tango en las calles de Buenos Aires",
    "Admirando las hermosas playas de Cancún. #México",
]


def _describe_confidence(num_results):
    """Map the number of candidate locations to a confidence label."""
    if num_results == 1:
        return 'High'
    if num_results < 50:
        return 'Low'
    return 'Very low'


def _render_map(results, show_markers):
    """Draw a pydeck map centred on the top result, with a pin per result if ``show_markers``."""
    top_lat, top_lon = results[0][2]

    encoded_image = get_base64_encoded_image("icons8-map-pin-48.png")
    icon_data = {
        "url": f"data:image/png;base64,{encoded_image}",
        "width": 128,  # icon width in pixels
        "height": 128,  # icon height in pixels
        "anchorY": 128,  # anchor at the bottom centre so the pin tip marks the point
    }

    layers = []
    if show_markers:
        coords = [coordinates for _, _, coordinates in results]
        markers = pd.DataFrame({
            'lat': [lat for lat, _ in coords],
            'lon': [lon for _, lon in coords],
            'icon_data': [icon_data] * len(coords),
        })
        layers.append(pdk.Layer(
            type='IconLayer',
            data=markers,
            get_icon='icon_data',
            get_size=4,
            size_scale=15,
            get_position=['lon', 'lat'],
            pickable=True,
        ))

    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=top_lat,
            longitude=top_lon,
            zoom=6,
            pitch=50,
        ),
        layers=layers,
    ))


# Streamlit interface
st.title('GeoTagging using ByT5')
st.write('Examples:')
for example in example_texts:
    if st.button(example):
        set_example_text(example)
        st.session_state.text_modified = st.session_state.text_input

# Get text input and update session state when it's modified
st.session_state.text_modified = st.text_input('Enter your text:', value=st.session_state.text_input)

if st.button('Submit'):
    st.session_state.text_input = st.session_state.text_modified

if st.session_state.text_input:
    results = geolocate_text_byt5_multiclass(st.session_state.text_input)
    if not results:
        # The top class has no coordinates; nothing to show on the map.
        st.write('No location could be predicted for this text.')
    else:
        _, _, (lat, lon) = results[0]
        confidence_def = _describe_confidence(len(results))
        st.write('Predicted Location: (', lat, lon, '). Confidence: ', confidence_def)
        if confidence_def == 'Low':
            st.write('Multiple possible locations were identified as confidence is low')
        elif confidence_def == 'Very low':
            st.write('There are too many possible locations as confidence is very low')
        # Always centre on the predicted location; pins are hidden when there
        # are too many candidates.
        _render_map(results, show_markers=confidence_def != 'Very low')