# Import necessary libraries
import streamlit as st
import pandas as pd
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer
import torch
import torch.nn as nn
import copy
import pydeck as pdk
import numpy as np
import base64


def get_base64_encoded_image(image_path):
    # Read a local image and return its base64 string, used later to build a data URL for the map icon
    with open(image_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode('utf-8')

keep_layer_count = 6  # number of ByT5 encoder blocks to keep when trimming the encoder


class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location or {}

    def to_diff_dict(self):
        # Convert the configuration to a dictionary
        config_dict = self.to_dict()
        # Get the default configuration for comparison
        default_config_dict = PretrainedConfig().to_dict()
        # Return only the keys that differ from the defaults
        diff_dict = {k: v for k, v in config_dict.items() if k not in default_config_dict or v != default_config_dict[k]}
        return diff_dict

def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full T5 encoder model
    oldModuleList = model.encoder.block
    newModuleList = torch.nn.ModuleList()

    # Iterate over all layers, keeping only the first num_layers_to_keep
    for i in range(0, num_layers_to_keep):
        newModuleList.append(oldModuleList[i])

    # Create a copy of the model, swap in the shortened block list, and return it
    copyOfModel = copy.deepcopy(model)
    copyOfModel.encoder.block = newModuleList

    return copyOfModel

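# Illustrative sketch (an assumption, not part of the original app; kept as a comment so it
# does not execute inside the Streamlit script): trimming a standalone ByT5 encoder to 6 blocks.
#
#   encoder = T5EncoderModel.from_pretrained("google/byt5-small")
#   trimmed = deleteEncodingLayers(encoder, num_layers_to_keep=6)
#   assert len(trimmed.encoder.block) == 6
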
class ByT5ForTextGeotagging(PreTrainedModel):
    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)

        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)

        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        # Encode the byte-level tokens and classify from the first token's hidden state
        input = self.byt5(input[:, 0, :].squeeze(1))['last_hidden_state']
        input = input[:, 0, :].squeeze(1)
        logits = self.fc3(input)

        if return_coordinates:
            class_idx = torch.argmax(logits, dim=1).item()
            coordinates = self.config.class_to_location.get(str(class_idx))
            confidence = torch.max(torch.nn.functional.softmax(logits, dim=1)).item()
            return logits, coordinates, confidence
        else:
            return logits

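# Sketch of building the model from a fresh config (hypothetical values, for illustration only;
# the app below instead loads the published yachay/byt5-geotagging-es checkpoint):
#
#   config = ByT5ForTextGeotaggingConfig(
#       n_clusters=3000,                              # hypothetical number of location clusters
#       model_name_or_path="google/byt5-small",
#       class_to_location={"0": (40.4168, -3.7038)},  # class index (as string) -> (lat, lon)
#   )
#   model = ByT5ForTextGeotagging(config)
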
def load_model_and_tokenizer():
    byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
    model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token=st.secrets['token'])
    return byt5_tokenizer, model


byt5_tokenizer, model = load_model_and_tokenizer()

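# Note (a suggestion, not in the original app): decorating the loader with Streamlit's
# @st.cache_resource would avoid re-downloading and re-instantiating the checkpoint on
# every rerun of the script, e.g.:
#
#   @st.cache_resource
#   def load_model_and_tokenizer():
#       ...
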
def geolocate_text_byt5_multiclass(text):
    input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    logits, (lat, lon), confidence = model(input_tensor.unsqueeze(0), return_coordinates=True)

    probas = torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()

    # Sort class probabilities in descending order and get their indices
    sorted_indices = np.argsort(-probas[0])

    results = []
    cumulative_prob = 0.0
    for class_idx in sorted_indices:
        prob = probas[0][class_idx]
        cumulative_prob += prob
        # Stop once the running total would exceed 0.5; fewer candidates means higher confidence
        if cumulative_prob > 0.5:
            break
        coordinates = model.config.class_to_location.get(str(class_idx))
        if coordinates:
            results.append((class_idx, prob, coordinates))

    # Ensure at least one result: fall back to the highest-probability class
    if not results:
        class_idx = sorted_indices[0]
        prob = probas[0][class_idx]
        coordinates = model.config.class_to_location.get(str(class_idx))
        if coordinates:
            results.append((class_idx, prob, coordinates))

    return results

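# Illustrative only (commented out so it does not run on import):
#   candidates = geolocate_text_byt5_multiclass("Bailando tango en las calles de Buenos Aires")
# Each entry is a (class_idx, probability, (lat, lon)) tuple, ordered from most to least likely;
# a short list means the probability mass is concentrated on a few location classes.
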
def geolocate_text_byt5(text):
    input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    logits, (lat, lon), confidence = model(input_tensor.unsqueeze(0), return_coordinates=True)
    return lat, lon, confidence

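# Illustrative only (commented out):
#   lat, lon, confidence = geolocate_text_byt5("Admirando las hermosas playas de Cancún. #México")
# where `confidence` is the softmax probability of the single most likely class.
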
if 'text_input' not in st.session_state:
    st.session_state.text_input = ""

if 'text_modified' not in st.session_state:
    st.session_state.text_modified = ""


# When an example button is clicked, update the session state
def set_example_text(example_text):
    st.session_state.text_input = example_text


example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "Una semana de conciertos, fuegos artificiales y txosnas en Aste Nagusia. ¡Esta ciudad sabe cómo celebrar! #Fiestas #AsteNagusia",
    "Viendo las destrezas de los gauchos en la Semana Criolla. Los asados también son para morirse. #SemanaCriolla #Tradición",
    "Bailando tango en las calles de Buenos Aires",
    "Admirando las hermosas playas de Cancún. #México"
]

# Streamlit interface
st.title('GeoTagging using ByT5')

st.write('Examples:')
for example in example_texts:
    if st.button(example):
        st.session_state.text_input = example
        st.session_state.text_modified = st.session_state.text_input

# Get text input and update session state when it's modified
st.session_state.text_modified = st.text_input('Enter your text:', value=st.session_state.text_input)

if st.button('Submit'):
    st.session_state.text_input = st.session_state.text_modified

if st.session_state.text_input:
    results = geolocate_text_byt5_multiclass(st.session_state.text_input)
    #st.write(results)
    _, confidence, (lat, lon) = results[0]

    if len(results) == 1:
        confidence_def = 'High'
    elif len(results) < 50:
        confidence_def = 'Low'
    else:
        confidence_def = 'Very low'

    st.write(f'Predicted Location: ({lat}, {lon}). Confidence: {confidence_def}')
    if confidence_def == 'Low':
        st.write('Multiple possible locations were identified because confidence is low')
    elif confidence_def == 'Very low':
        st.write('There are too many possible locations because confidence is very low')

    # Render map with pydeck
    map_data = pd.DataFrame(
        [[lat, lon]],
        columns=["lat", "lon"]
    )

    encoded_image = get_base64_encoded_image("icons8-map-pin-48.png")
    icon_url = f"data:image/png;base64,{encoded_image}"

    # Example icon data
    icon_data = {
        "url": icon_url,   # URL of the icon image
        "width": 128,      # Width of the icon in pixels
        "height": 128,     # Height of the icon in pixels
        "anchorY": 128     # Anchor point of the icon in pixels (bottom center)
    }

    # Example location data
    locations = pd.DataFrame({
        'lat': [lat],             # Latitude values
        'lon': [lon],             # Longitude values
        'icon_data': [icon_data]  # Repeating the icon data for each location
    })

    layers = []
    if confidence_def != 'Very low':
        # Add an icon layer for each candidate location
        for item in results:
            _, confidence, (lat, lon) = item
            layer = pdk.Layer(
                type='IconLayer',
                data=pd.DataFrame({
                    'lat': [lat],             # Latitude values
                    'lon': [lon],             # Longitude values
                    'icon_data': [icon_data]  # Repeating the icon data for each location
                }),
                get_icon='icon_data',
                get_size=4,
                size_scale=15,
                get_position=['lon', 'lat'],
                pickable=True
            )
            layers.append(layer)

    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=lat,
            longitude=lon,
            zoom=6,
            pitch=50,
        ),
        layers=layers,
    ))