Spaces:
Sleeping
Sleeping
File size: 4,452 Bytes
151c0f1 c9c4318 151c0f1 50ba2ea 151c0f1 5290c3a 151c0f1 4f08151 151c0f1 ff38b8f 151c0f1 c5aa871 48860ae c5aa871 151c0f1 c5aa871 151c0f1 50ba2ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Import necessary libraries
import copy
import os

import pandas as pd
import pydeck as pdk
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel, T5EncoderModel
# Number of leading encoder blocks kept when the model is constructed;
# ``None`` leaves the encoder untouched.
keep_layer_count = 6

# The Hub access token is read from the HF_TOKEN environment variable
# (on Spaces: a repository secret) instead of being committed in source.
# Any token that was previously committed should be treated as leaked
# and revoked. When the variable is unset, ``token=None`` is passed.
byt5_tokenizer = AutoTokenizer.from_pretrained(
    "yachay/byt5-geotagging-es",
    token=os.environ.get("HF_TOKEN"),
)
class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5 text-geotagging classifier.

    Attributes stored on the instance:
        n_clusters: value supplied by the caller; the model in this file
            uses it as the output width of its linear layer.
        model_name_or_path: identifier the model hands to
            ``T5EncoderModel.from_pretrained``.
        class_to_location: the mapping supplied by the caller, or an empty
            dict when the argument is falsy.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location if class_to_location else {}

    def to_diff_dict(self):
        """Return the entries that differ from a bare ``PretrainedConfig``.

        A key is kept when the default configuration lacks it or holds a
        different value for it.
        """
        current = self.to_dict()
        baseline = PretrainedConfig().to_dict()
        changed = {}
        for key, value in current.items():
            if key not in baseline or value != baseline[key]:
                changed[key] = value
        return changed
def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of ``model`` keeping only its first encoder blocks.

    Args:
        model: module exposing ``model.encoder.block`` as an indexable
            sequence of layers (in this file, a T5 encoder model). It is
            not modified.
        num_layers_to_keep: how many leading blocks to retain. Zero or a
            negative value leaves the copy with no blocks.

    Returns:
        A copy of ``model`` whose ``encoder.block`` is a ``ModuleList``
        holding the copy's own first ``num_layers_to_keep`` blocks.

    Raises:
        IndexError: if ``num_layers_to_keep`` exceeds the number of blocks.
    """
    available = len(model.encoder.block)
    if num_layers_to_keep > available:
        # Fail before paying for the deep copy.
        raise IndexError(
            f"cannot keep {num_layers_to_keep} layers; encoder has only {available}"
        )

    # Copy first, then trim the copy's own blocks. Picking the layers from
    # the original and attaching them to the copy would leave both models
    # sharing the same layer objects (and therefore the same parameters).
    trimmed = copy.deepcopy(model)
    kept = [trimmed.encoder.block[i] for i in range(num_layers_to_keep)]
    trimmed.encoder.block = torch.nn.ModuleList(kept)
    return trimmed
class ByT5ForTextGeotagging(PreTrainedModel):
    """T5 encoder followed by a linear layer producing one logit per cluster."""

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        """Load the encoder named by ``config`` and attach the output layer.

        When the module-level ``keep_layer_count`` is not ``None`` the
        encoder is replaced by the result of ``deleteEncodingLayers`` with
        that count.
        """
        super().__init__(config)
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)
        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Compute cluster logits, optionally with the predicted coordinates.

        Args:
            input: rank-3 tensor; only index 0 along the second axis is
                passed to the encoder. The caller in this file supplies
                tokenizer ``input_ids`` with an extra axis inserted at
                position 0. (The name shadows the builtin but is kept
                because it is part of the call signature.)
            return_coordinates: when true, also look up the arg-max class in
                ``config.class_to_location``. This path calls ``.item()`` on
                the arg-max, so it only works for a batch of one.

        Returns:
            ``logits`` alone, or ``(logits, coordinates)`` where
            ``coordinates`` is ``None`` if the class index (as a string key)
            is absent from ``config.class_to_location``.
        """
        # No ``.squeeze(1)`` here: indexing already yields a 2-D tensor, and
        # squeezing would collapse the second axis whenever it has size 1
        # (e.g. a one-token sequence), handing the encoder a 1-D tensor.
        token_ids = input[:, 0, :]
        hidden_states = self.byt5(token_ids)['last_hidden_state']
        # Use the hidden state at the first sequence position as the feature.
        first_state = hidden_states[:, 0, :]
        logits = self.fc3(first_state)
        if not return_coordinates:
            return logits
        class_idx = torch.argmax(logits, dim=1).item()
        coordinates = self.config.class_to_location.get(str(class_idx))
        return logits, coordinates
def geolocate_text_byt5(text):
    """Predict a coordinate pair for ``text`` with the module-level model.

    The text is tokenized with ``byt5_tokenizer`` (truncated to 140 tokens)
    and run through ``model`` with ``return_coordinates=True``.

    Args:
        text: the string to geolocate.

    Returns:
        The two values stored in the model config for the predicted class,
        unpacked as ``(lat, lon)``.

    Raises:
        TypeError: if the predicted class has no entry in the config mapping,
            because the model then returns ``None`` in place of the pair.
    """
    input_ids = byt5_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=140
    )['input_ids']
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        _, (lat, lon) = model(input_ids.unsqueeze(0), return_coordinates=True)
    return lat, lon
# Load the geotagging model from the Hub. The access token is taken from the
# HF_TOKEN environment variable rather than a literal committed in source;
# a token that was previously committed should be revoked.
model = ByT5ForTextGeotagging.from_pretrained(
    "yachay/byt5-geotagging-es",
    token=os.environ.get("HF_TOKEN"),
)
# Sample inputs offered as one-click buttons in the UI.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "¡Barcelona es increíble! #VacacionesEnEspaña",
    "Me encantó el ceviche en Lima. ¡Qué delicioso! #Perú",
    "Bailando tango en las calles de Buenos Aires. #Argentina",
    "Admirando las hermosas playas de Cancún. #México"
]

# Streamlit interface
st.title('GeoTagging using ByT5')

# Buttons for example texts. A clicked button pre-fills the text box for this
# script run; the variable is initialised explicitly instead of probing
# ``locals()`` for it afterwards.
prefill = ''
for ex_text in example_texts:
    if st.button(f'Example: "{ex_text[:30]}..."'):
        prefill = ex_text

text_input = st.text_input('Enter your text:', value=prefill)

if text_input:
    location = geolocate_text_byt5(text_input)
    st.write('Predicted Location: ', location)
    latitude, longitude = location

    # Render map with pydeck: a single marker, with the view centred on it.
    map_data = pd.DataFrame(
        [[latitude, longitude]],
        columns=["lat", "lon"]
    )
    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=latitude,
            longitude=longitude,
            zoom=11,
            pitch=50,
        ),
        layers=[
            pdk.Layer(
                'ScatterplotLayer',
                data=map_data,
                get_position='[lon, lat]',
                get_color='[200, 30, 0, 160]',
                get_radius=200,
            ),
        ],
    ))