Spaces:
Sleeping
Sleeping
# Import necessary libraries | |
import streamlit as st | |
import pandas as pd | |
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer | |
import torch | |
import torch.nn as nn | |
import copy | |
import pydeck as pdk | |
# Number of leading ByT5 encoder blocks kept when the geotagging model is
# built; ``None`` leaves the encoder at its full depth.
keep_layer_count = 6
class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5-based text geotagging classifier.

    Attributes:
        n_clusters: Number of location classes; used as the output width of
            the model's classification layer.
        model_name_or_path: Identifier or path of the T5 encoder checkpoint
            that the model loads.
        class_to_location: Mapping that the model looks up with the
            stringified class index to obtain coordinates. Stored as an
            empty dict when not supplied.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        # Any falsy value (None, empty mapping) is replaced by a fresh dict.
        self.class_to_location = class_to_location if class_to_location else {}

    def to_diff_dict(self):
        """Return only the entries that differ from a bare ``PretrainedConfig``.

        An entry is kept when its key does not exist in a default-constructed
        ``PretrainedConfig`` or when its value is different there.
        """
        current_values = self.to_dict()
        baseline_values = PretrainedConfig().to_dict()

        changed = {}
        for key, value in current_values.items():
            if key not in baseline_values or value != baseline_values[key]:
                changed[key] = value
        return changed
def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of ``model`` whose encoder keeps only its first blocks.

    Args:
        model: Object exposing ``model.encoder.block`` as an indexable
            sequence of modules (in this file, a ``T5EncoderModel``). It is
            not modified.
        num_layers_to_keep: Number of leading encoder blocks to retain.
            Zero or a negative value yields an empty block list.

    Returns:
        A deep copy of ``model`` whose ``encoder.block`` is a new
        ``torch.nn.ModuleList`` holding the copy's own first
        ``num_layers_to_keep`` blocks.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of blocks.
    """
    truncated_model = copy.deepcopy(model)
    copied_blocks = truncated_model.encoder.block

    # Take the retained blocks from the copy rather than from ``model``:
    # otherwise the returned model would share its block parameters with
    # the original instead of being an independent copy.
    truncated_model.encoder.block = torch.nn.ModuleList(
        [copied_blocks[i] for i in range(num_layers_to_keep)]
    )

    # NOTE(review): the copied model's config is left untouched, so any
    # layer-count field it carries still describes the full-depth encoder.
    return truncated_model
class ByT5ForTextGeotagging(PreTrainedModel):
    """T5 encoder followed by a linear layer that scores location classes.

    The encoder is loaded from ``config.model_name_or_path``. When the
    module-level ``keep_layer_count`` is not ``None`` the encoder is cut
    down to that many leading blocks via ``deleteEncodingLayers``.
    """

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)
        hidden_size = self.byt5.config.d_model
        # Classification layer: encoder width -> one logit per location class.
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Score the location classes for ``input``.

        Args:
            input: Token-id tensor indexed as ``input[:, 0, :]``; the caller
                in this file passes a tensor of shape ``(1, 1, seq_len)``.
            return_coordinates: When true, also look up the coordinates of
                the highest-scoring class. This path calls ``.item()`` on
                the argmax, so it only works for a batch of exactly one.

        Returns:
            ``logits`` alone, or ``(logits, coordinates)`` when
            ``return_coordinates`` is true. ``coordinates`` is whatever
            ``config.class_to_location`` stores under the stringified class
            index, or ``None`` when that key is absent.
        """
        # Plain indexing already removes axis 1. The former extra
        # ``.squeeze(1)`` would additionally drop the sequence axis for
        # one-token inputs, handing the encoder a 1-D tensor.
        token_ids = input[:, 0, :]
        hidden_states = self.byt5(token_ids)['last_hidden_state']

        # Use the hidden state at the first sequence position as the
        # representation of the whole text.
        first_token_state = hidden_states[:, 0, :]
        logits = self.fc3(first_token_state)

        if not return_coordinates:
            return logits

        class_idx = torch.argmax(logits, dim=1).item()
        coordinates = self.config.class_to_location.get(str(class_idx))
        return logits, coordinates
@st.cache_resource
def load_model_and_tokenizer():
    """Load the geotagging tokenizer and model from the Hugging Face Hub.

    The access token is read from the ``HF_TOKEN`` environment variable
    instead of being embedded in the source. When the variable is unset or
    empty, ``token=None`` is passed to both ``from_pretrained`` calls.

    SECURITY: an access token used to be hard-coded here. It has been
    exposed and should be revoked; supply a new one through the environment
    (for a Space, as a repository secret).

    The function is wrapped in ``st.cache_resource`` because Streamlit
    re-executes the script on every interaction, and reloading the model
    on each run would be wasteful.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    import os

    repo_id = "yachay/byt5-geotagging-es"
    hf_token = os.environ.get("HF_TOKEN") or None

    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
    geotagger = ByT5ForTextGeotagging.from_pretrained(repo_id, token=hf_token)
    return tokenizer, geotagger


# Module-level handles used by ``geolocate_text_byt5``.
byt5_tokenizer, model = load_model_and_tokenizer()
def geolocate_text_byt5(text):
    """Predict coordinates for ``text`` with the module-level model.

    Args:
        text: Text to geolocate. It is tokenized with truncation to at most
            140 tokens.

    Returns:
        The ``(lat, lon)`` pair unpacked from the coordinates that the model
        returns for its highest-scoring class.

    Raises:
        TypeError: If the model returns no coordinates for the predicted
            class, since ``None`` cannot be unpacked.
    """
    input_ids = byt5_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=140
    )['input_ids']

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # The model indexes its input as [:, 0, :], hence the extra
        # leading axis.
        _, (lat, lon) = model(input_ids.unsqueeze(0), return_coordinates=True)
    return lat, lon
# Seed the text shown in the input box the first time a session runs.
if "text_input_state" not in st.session_state:
    st.session_state["text_input_state"] = ""
def set_example_text(example_text):
    """Store ``example_text`` as the current text-input value in session state."""
    st.session_state["text_input_state"] = example_text
# Sample posts offered as one-click inputs.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "¡Barcelona es increíble! #VacacionesEnEspaña",
    "Me encantó el ceviche en Lima. #Perú",
    "Bailando tango en las calles de Buenos Aires",
    "Admirando las hermosas playas de Cancún. #México",
]

# --- Page header and example buttons ---
st.title('GeoTagging using ByT5')
st.write('Examples:')
for example in example_texts:
    clicked = st.button(f"{example}")
    if clicked:
        set_example_text(example)

# --- Text input, kept in sync with session state ---
text_input = st.text_input('Enter your text:', value=st.session_state.text_input_state)
if st.session_state.text_input_state != text_input:
    st.session_state.text_input_state = text_input

# --- Prediction and map, only when there is text to geolocate ---
if text_input:
    location = geolocate_text_byt5(text_input)
    st.write('Predicted Location: ', location)

    latitude, longitude = location

    # Single-row frame consumed by the scatterplot layer below.
    map_data = pd.DataFrame([[latitude, longitude]], columns=["lat", "lon"])

    # Centre the camera on the predicted point.
    view_state = pdk.ViewState(
        latitude=latitude,
        longitude=longitude,
        zoom=11,
        pitch=50,
    )
    marker_layer = pdk.Layer(
        'ScatterplotLayer',
        data=map_data,
        get_position='[lon, lat]',
        get_color='[200, 30, 0, 160]',
        get_radius=200,
    )
    deck = pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=view_state,
        layers=[marker_layer],
    )
    st.pydeck_chart(deck)