# app.py — Hugging Face Space by yachay
# Last commit: "Update app.py" (78afeef), 5.47 kB
# Import necessary libraries
import copy
import os

import pandas as pd
import pydeck as pdk
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel, T5EncoderModel
# How many leading ByT5 encoder blocks ByT5ForTextGeotagging keeps after
# loading its backbone; set to None to skip truncation entirely.
keep_layer_count = 6
class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5 text-geotagging classifier.

    Args:
        n_clusters: Number of location classes; used as the output width of
            the classification head.
        model_name_or_path: Hub id or local path of the ByT5 checkpoint whose
            encoder serves as the backbone.
        class_to_location: Mapping from a class index (stored as a *string*
            key) to its coordinates. Any falsy value is replaced by ``{}``.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location or {}

    def to_diff_dict(self):
        """Return the entries of ``to_dict()`` that differ from a bare ``PretrainedConfig``.

        The comparison baseline is the base-class defaults rather than a
        default instance of this class, presumably because this class has
        required constructor arguments and cannot be built with no arguments.

        Returns:
            dict: every key absent from the baseline, plus every key whose
            value differs from the baseline's.
        """
        current = self.to_dict()
        baseline = PretrainedConfig().to_dict()
        changed = {}
        for key, value in current.items():
            if key not in baseline or value != baseline[key]:
                changed[key] = value
        return changed
def deleteEncodingLayers(model, num_layers_to_keep):  # expects a full T5-style encoder model
    """Return a deep copy of ``model`` whose encoder keeps only its first blocks.

    Args:
        model: Module exposing ``model.encoder.block`` as an ``nn.ModuleList``
            (e.g. ``T5EncoderModel``). It is not modified.
        num_layers_to_keep: Number of leading blocks to keep; must satisfy
            ``0 <= num_layers_to_keep <= len(model.encoder.block)``.

    Returns:
        An independent deep copy of ``model`` with ``encoder.block`` truncated.

    Raises:
        ValueError: If ``num_layers_to_keep`` is negative or exceeds the number
            of available blocks.
    """
    available = len(model.encoder.block)
    if not 0 <= num_layers_to_keep <= available:
        raise ValueError(
            f"num_layers_to_keep must be between 0 and {available}, got {num_layers_to_keep}"
        )
    # Copy first, then truncate the copy's own blocks. The result therefore owns
    # its layers instead of aliasing the input's modules. The old version
    # discarded the freshly copied blocks and reused the original ones.
    trimmed = copy.deepcopy(model)
    trimmed.encoder.block = nn.ModuleList(trimmed.encoder.block[:num_layers_to_keep])
    return trimmed
class ByT5ForTextGeotagging(PreTrainedModel):
    """ByT5 encoder plus a linear head that classifies text into location clusters.

    The backbone is loaded from ``config.model_name_or_path``. When the
    module-level ``keep_layer_count`` is not None, it is truncated to that many
    leading encoder blocks. The first token's final hidden state feeds a linear
    layer that produces ``config.n_clusters`` logits.
    """

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)
        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Score token sequences against every location cluster.

        Args:
            input: Token ids shaped ``(batch, 1, seq_len)``. The caller in this
                file builds it by unsqueezing the tokenizer's ``input_ids``.
                The name shadows the builtin but is kept for compatibility.
            return_coordinates: If True, also decode the top class. This
                requires a batch of exactly one example.

        Returns:
            ``logits`` of shape ``(batch, n_clusters)``. With
            ``return_coordinates=True``, returns ``(logits, coordinates, confidence)``:
            - ``coordinates`` is the ``class_to_location`` entry for the argmax
              class, or None if that class has no entry.
            - ``confidence`` is the softmax probability of that class, in [0, 1].

        Raises:
            ValueError: If ``return_coordinates`` is True and the batch size is not 1.
        """
        # (batch, 1, seq_len) -> (batch, seq_len). The old trailing .squeeze(1)
        # was a no-op, except at seq_len == 1, where it wrongly dropped the sequence axis.
        input_ids = input[:, 0, :]
        hidden = self.byt5(input_ids)['last_hidden_state']
        # Pool by taking the first token's hidden state: (batch, d_model).
        pooled = hidden[:, 0, :]
        logits = self.fc3(pooled)
        if not return_coordinates:
            return logits
        if logits.shape[0] != 1:
            raise ValueError(
                f"return_coordinates=True requires a batch of 1, got {logits.shape[0]}"
            )
        class_idx = torch.argmax(logits, dim=1).item()
        # Report a probability, not the raw (unbounded) max logit, so that a
        # fixed threshold such as the UI's 0.2 is meaningful.
        confidence = torch.softmax(logits, dim=1)[0, class_idx].item()
        # class_to_location keys are strings (e.g. after a JSON round-trip).
        coordinates = self.config.class_to_location.get(str(class_idx))
        return logits, coordinates, confidence
@st.cache_resource
def load_model_and_tokenizer():
    """Load the ByT5 geotagging tokenizer and model once per server process.

    ``st.cache_resource`` is Streamlit's cache for global resources such as ML
    models. Every rerun gets the same in-memory objects. The previous
    ``st.cache_data(persist=True)`` instead pickled the whole model, copied it
    on every call and wrote it to disk.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable (Hugging Face Spaces exposes Space secrets as environment
    variables). Never commit a token to source. The one previously hard-coded
    here is public and should be revoked.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``yachay/byt5-geotagging-es`` checkpoint.
    """
    repo_id = "yachay/byt5-geotagging-es"
    # None defers to huggingface_hub's own token resolution.
    token = os.environ.get("HF_TOKEN")
    byt5_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
    model = ByT5ForTextGeotagging.from_pretrained(repo_id, token=token)
    return byt5_tokenizer, model


# Load at import time; the resource cache makes later reruns reuse these objects.
byt5_tokenizer, model = load_model_and_tokenizer()
def geolocate_text_byt5(text):
    """Predict the location a piece of text refers to.

    Args:
        text: Input text. It is tokenized with truncation to at most 140 tokens.

    Returns:
        ``(lat, lon, confidence)`` as decoded by the module-level ``model`` with
        ``return_coordinates=True``.

    Raises:
        ValueError: If the predicted class has no entry in the model's
            ``class_to_location`` mapping.
    """
    input_ids = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        # The model expects (batch, 1, seq_len), hence the extra leading axis.
        _, coordinates, confidence = model(input_ids.unsqueeze(0), return_coordinates=True)
    if coordinates is None:
        raise ValueError("predicted class has no coordinates in config.class_to_location")
    lat, lon = coordinates
    return lat, lon, confidence
# Session state persisted across Streamlit reruns:
#   text_input    - text last submitted (or picked from an example); drives prediction
#   text_modified - current contents of the text box
if 'text_input' not in st.session_state:
    st.session_state.text_input = ""
if 'text_modified' not in st.session_state:
    st.session_state.text_modified = ""


def set_example_text(example_text):
    """Load ``example_text`` into the text box and mark it as submitted."""
    # The old version wrote an unused 'text_input_state' key and had no effect.
    st.session_state.text_input = example_text
    st.session_state.text_modified = example_text


example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "Una semana de conciertos, fuegos artificiales y txosnas en Aste Nagusia. ¡Esta ciudad sabe cómo celebrar! #Fiestas #AsteNagusia",
    "Viendo las destrezas de los gauchos en la Semana Criolla. Los asados también son para morirse. #SemanaCriolla #Tradición",
    "Bailando tango en las calles de Buenos Aires",
    "Admirando las hermosas playas de Cancún. #México"
]

# Streamlit interface
st.title('GeoTagging using ByT5')
st.write('Examples:')
for example in example_texts:
    if st.button(f"{example}"):
        set_example_text(example)

# The widget's default follows the last submitted text; whatever the user
# types is captured in text_modified and only used once Submit is pressed.
st.session_state.text_modified = st.text_input('Enter your text:', value=st.session_state.text_input)
if st.button('Submit'):
    st.session_state.text_input = st.session_state.text_modified

if st.session_state.text_input:
    # geolocate_text_byt5 returns three values. Unpacking into two names used to
    # raise ValueError on every prediction.
    lat, lon, confidence = geolocate_text_byt5(st.session_state.text_input)
    location = (lat, lon)
    # NOTE(review): the 0.2 cut-off assumes confidence is a probability-like score.
    st.write('Predicted Location: ', location, 'confidence: ', 'High' if confidence > 0.2 else 'Low')

    # Render map with pydeck
    map_data = pd.DataFrame([[lat, lon]], columns=["lat", "lon"])
    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=lat,
            longitude=lon,
            zoom=6,
            pitch=50,
        ),
        layers=[
            pdk.Layer(
                'ScatterplotLayer',
                data=map_data,
                get_position='[lon, lat]',
                get_color='[200, 30, 0, 160]',
                get_radius=200,
                # A 200 m radius renders well under one pixel at zoom 6;
                # enforce a minimum on-screen size so the marker is visible.
                radius_min_pixels=6,
            ),
        ],
    ))