# Import necessary libraries
import copy
import os

import pandas as pd
import pydeck as pdk
import streamlit as st
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    T5EncoderModel,
)

# Hub repository holding the fine-tuned tokenizer and model.
MODEL_REPO = "yachay/byt5-geotagging-es"

# Hugging Face access token, read from the environment. A token used to be
# hard-coded in this file; a token committed to source control is leaked and
# must be revoked. When HF_TOKEN is unset this is None, and transformers falls
# back to locally cached `huggingface-cli login` credentials, if any.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Number of leading ByT5 encoder blocks kept by ByT5ForTextGeotagging
# (None keeps the full encoder).
# NOTE(review): presumably must match the depth the checkpoint was trained
# with -- confirm before changing.
keep_layer_count = 6


@st.cache_resource
def _load_tokenizer():
    """Load the ByT5 tokenizer once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the tokenizer would be reloaded on each rerun.
    """
    return AutoTokenizer.from_pretrained(MODEL_REPO, token=HF_TOKEN)


byt5_tokenizer = _load_tokenizer()


class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for :class:`ByT5ForTextGeotagging`.

    Args:
        n_clusters: Number of output classes (location clusters).
        model_name_or_path: Hub id or local path of the ByT5 encoder that is
            loaded as the model's backbone.
        class_to_location: Mapping from class index *as a string* to that
            class's location. The app unpacks each value as ``(lat, lon)``.
            Defaults to an empty dict.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location or {}

    def to_diff_dict(self):
        """Return the config entries that differ from a bare ``PretrainedConfig``.

        Keys absent from the base defaults are always included.
        NOTE(review): presumably overridden because the inherited version
        builds a default instance of this class, which needs the required
        ``n_clusters`` / ``model_name_or_path`` arguments -- confirm.
        """
        config_dict = self.to_dict()
        default_config_dict = PretrainedConfig().to_dict()
        return {
            k: v
            for k, v in config_dict.items()
            if k not in default_config_dict or v != default_config_dict[k]
        }


def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of *model* whose encoder keeps only its first blocks.

    Args:
        model: A model exposing ``model.encoder.block`` as an
            ``nn.ModuleList`` (e.g. a ``T5EncoderModel``).
        num_layers_to_keep: Number of leading encoder blocks to keep. A value
            larger than the number of blocks keeps all of them.

    Returns:
        An independent copy of *model*. The input model is not modified and
        shares no modules with the result.
    """
    # Copy first, then slice the copy's own blocks. This keeps the result
    # independent of *model*; slicing the original's blocks would make the
    # copy share modules with it.
    copy_of_model = copy.deepcopy(model)
    copy_of_model.encoder.block = copy_of_model.encoder.block[:num_layers_to_keep]
    return copy_of_model


class ByT5ForTextGeotagging(PreTrainedModel):
    """ByT5 encoder plus a linear head that classifies text into location clusters."""

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone encoder named in the config, optionally truncated to the
        # first `keep_layer_count` blocks.
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)

        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Classify token ids into location clusters.

        Args:
            input: Token-id tensor of shape ``(batch, k, seq_len)``. Only
                index 0 along dim 1 is used.
            return_coordinates: If True, also look up the location of the
                arg-max class. This requires ``batch == 1`` because of
                ``.item()``.

        Returns:
            ``logits`` of shape ``(batch, n_clusters)``. When
            *return_coordinates* is set, returns ``(logits, coordinates)``,
            where ``coordinates`` is ``None`` if the predicted class has no
            entry in ``config.class_to_location``.
        """
        # Slicing index 0 already drops dim 1. No squeeze is applied, because
        # squeeze(1) would also collapse a length-1 sequence axis.
        hidden = self.byt5(input[:, 0, :])["last_hidden_state"]
        # Use the hidden state at the first position as the sequence embedding.
        pooled = hidden[:, 0, :]
        logits = self.fc3(pooled)
        if not return_coordinates:
            return logits
        class_idx = torch.argmax(logits, dim=1).item()
        coordinates = self.config.class_to_location.get(str(class_idx))
        return logits, coordinates


def geolocate_text_byt5(text):
    """Predict the location of *text*.

    Returns:
        ``(lat, lon)`` of the most likely location cluster.

    Raises:
        ValueError: If the predicted class has no entry in the model
            config's ``class_to_location`` mapping.
    """
    input_ids = byt5_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=140
    )["input_ids"]
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        _, coordinates = model(input_ids.unsqueeze(0), return_coordinates=True)
    if coordinates is None:
        raise ValueError("Predicted cluster has no known location in the model config.")
    lat, lon = coordinates
    return lat, lon


@st.cache_resource
def _load_model():
    """Load the fine-tuned geotagging model once per server process."""
    return ByT5ForTextGeotagging.from_pretrained(MODEL_REPO, token=HF_TOKEN)


model = _load_model()

example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "La arquitectura de #Tokio es realmente algo fuera de este mundo 🌆🇯🇵",
    "Escuchando jazz en un café acogedor en el corazón de #NuevaOrleans 🎷🎶",
    "Los atardeceres en #CapeTown con la vista del Monte Table son inolvidables 🌅🇿🇦",
    "Nada se compara con caminar por las históricas calles de #Roma 🏛️🍕"
]

# Streamlit interface
st.title('GeoTagging using ByT5')

# Buttons for example texts. A click stores the example in the text box's
# session-state slot. The buttons are created before the text box, which is
# when Streamlit allows a widget's state to be assigned. The chosen text then
# stays in the box on later reruns instead of resetting to ''.
for ex_text in example_texts:
    if st.button(f'Example: "{ex_text[:30]}..."'):
        st.session_state["text_input"] = ex_text

text_input = st.text_input('Enter your text:', key="text_input")

if text_input:
    try:
        location = geolocate_text_byt5(text_input)
    except ValueError as err:
        st.error(str(err))
    else:
        st.write('Predicted Location: ', location)

        # Render map with pydeck
        map_data = pd.DataFrame(
            [[location[0], location[1]]],
            columns=["lat", "lon"]
        )

        st.pydeck_chart(pdk.Deck(
            map_style='mapbox://styles/mapbox/light-v9',
            initial_view_state=pdk.ViewState(
                latitude=location[0],
                longitude=location[1],
                zoom=11,
                pitch=50,
            ),
            layers=[
                pdk.Layer(
                    'ScatterplotLayer',
                    data=map_data,
                    get_position='[lon, lat]',
                    get_color='[200, 30, 0, 160]',
                    get_radius=200,
                ),
            ],
        ))