Spaces:
Sleeping
Sleeping
File size: 4,538 Bytes
151c0f1 50ba2ea 151c0f1 5290c3a 151c0f1 4f08151 151c0f1 ff38b8f 151c0f1 c5aa871 151c0f1 c5aa871 151c0f1 50ba2ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Import necessary libraries
import copy
import os

import pandas as pd
import pydeck as pdk
import streamlit as st
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer
# Number of leading encoder layers kept in the ByT5 backbone. ``None`` leaves
# the encoder untouched (checked in ByT5ForTextGeotagging.__init__).
keep_layer_count = 6

# SECURITY: an access token used to be written literally in this call. A token
# committed to source must be treated as leaked and revoked. The token is now
# read from the HF_TOKEN environment variable; when it is unset, ``None`` is
# passed and the library falls back to its default credential lookup.
byt5_tokenizer = AutoTokenizer.from_pretrained(
    "yachay/byt5-geotagging-es",
    token=os.environ.get("HF_TOKEN"),
)
class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5 text-geotagging classifier.

    Attributes set here:
        n_clusters: Output size of the classification head.
        model_name_or_path: Identifier handed to ``T5EncoderModel.from_pretrained``
            when the model is built.
        class_to_location: Mapping looked up with ``str(class_index)`` to obtain
            the coordinates for a predicted class. Defaults to an empty dict.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        """Store the geotagging settings; remaining kwargs go to the base config."""
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location or {}

    def to_diff_dict(self):
        """Return only the entries that differ from a default ``PretrainedConfig``.

        An entry is kept when its key is absent from the default configuration
        or when its value is not equal to the default value.
        """
        current = self.to_dict()
        # Compared against the plain base-class defaults, so this class never
        # has to be instantiated without its required arguments.
        baseline = PretrainedConfig().to_dict()

        changed = {}
        for key, value in current.items():
            if key not in baseline or value != baseline[key]:
                changed[key] = value
        return changed
def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a copy of ``model`` whose encoder keeps only its first layers.

    The input model is left unchanged and the returned model shares no
    modules with it: the kept layers are taken from the deep copy, not from
    ``model`` itself.

    Args:
        model: A module exposing ``model.encoder.block`` as an indexable
            sequence of layers (here, the full encoder model).
        num_layers_to_keep: How many leading layers of ``encoder.block`` to
            retain. Zero or a negative value yields an empty layer list.

    Returns:
        A deep copy of ``model`` with ``encoder.block`` replaced by a
        ``torch.nn.ModuleList`` holding the first ``num_layers_to_keep`` layers.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of layers.
    """
    available = len(model.encoder.block)
    # Fail before paying for the deep copy.
    if num_layers_to_keep > available:
        raise IndexError(
            f"cannot keep {num_layers_to_keep} layers: encoder has only {available}"
        )

    copyOfModel = copy.deepcopy(model)
    # Select layers from the copy so the result does not alias the parameters
    # of the original model.
    kept_layers = [copyOfModel.encoder.block[i] for i in range(num_layers_to_keep)]
    copyOfModel.encoder.block = torch.nn.ModuleList(kept_layers)
    return copyOfModel
class ByT5ForTextGeotagging(PreTrainedModel):
    """ByT5 encoder followed by a linear head that scores location clusters."""

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        """Build the encoder backbone and the classification head.

        Args:
            config: A ``ByT5ForTextGeotaggingConfig``; ``model_name_or_path``
                selects the encoder checkpoint and ``n_clusters`` sizes the head.
        """
        super().__init__(config)
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        # Optionally truncate the encoder to its first ``keep_layer_count``
        # layers (module-level setting).
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)
        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Score every location cluster for the given token ids.

        Args:
            input: A 3-D tensor of token ids; only index 0 along the second
                axis is used. Presumably shaped (batch, 1, seq_len), which is
                what the caller in this file passes -- TODO confirm for other
                callers.
            return_coordinates: When true, also look up the coordinates of the
                highest-scoring class. This calls ``.item()`` on the argmax,
                so it only works for a batch of exactly one example.

        Returns:
            ``logits`` alone, or ``(logits, coordinates)`` when
            ``return_coordinates`` is true. ``coordinates`` is whatever
            ``config.class_to_location`` stores under ``str(class_idx)``, or
            ``None`` if that class has no entry.
        """
        # Plain indexing already drops the second axis. The former
        # ``.squeeze(1)`` calls were no-ops except when the following axis had
        # size 1 (e.g. a one-token sequence), where they wrongly removed it.
        token_ids = input[:, 0, :]
        hidden_states = self.byt5(token_ids)['last_hidden_state']
        # Use the hidden state at sequence position 0 as the text representation.
        pooled = hidden_states[:, 0, :]
        logits = self.fc3(pooled)

        if not return_coordinates:
            return logits

        class_idx = torch.argmax(logits, dim=1).item()
        coordinates = self.config.class_to_location.get(str(class_idx))
        return logits, coordinates
def geolocate_text_byt5(text):
    """Predict a (lat, lon) pair for ``text`` with the module-level model.

    Args:
        text: The text to geolocate. It is tokenized with truncation at 140
            tokens.

    Returns:
        A ``(lat, lon)`` tuple unpacked from the coordinates the model reports
        for its highest-scoring class.

    Raises:
        ValueError: If the model has no coordinates for the predicted class.
    """
    input_tensor = byt5_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=140
    )['input_ids']

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # The model indexes its input as [:, 0, :], hence the extra leading axis.
        _, coordinates = model(input_tensor.unsqueeze(0), return_coordinates=True)

    if coordinates is None:
        # Previously this surfaced as an opaque "cannot unpack NoneType" error.
        raise ValueError("No coordinates are configured for the predicted class.")

    lat, lon = coordinates
    return lat, lon
# SECURITY: an access token used to be written literally in this call. A token
# committed to source must be treated as leaked and revoked. The token is now
# read from the HF_TOKEN environment variable; when it is unset, ``None`` is
# passed and the library falls back to its default credential lookup.
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# model is reloaded each time -- consider wrapping the load in a cached loader.
model = ByT5ForTextGeotagging.from_pretrained(
    "yachay/byt5-geotagging-es",
    token=os.environ.get("HF_TOKEN"),
)
# Sample posts offered as one-click inputs in the UI.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "La arquitectura de #Tokio es realmente algo fuera de este mundo 🌆🇯🇵",
    "Escuchando jazz en un café acogedor en el corazón de #NuevaOrleans 🎷🎶",
    "Los atardeceres en #CapeTown con la vista del Monte Table son inolvidables 🌅🇿🇦",
    "Nada se compara con caminar por las históricas calles de #Roma 🏛️🍕",
]

# Streamlit interface
st.title('GeoTagging using ByT5')

# Buttons for example texts. A clicked button pre-fills the text box for this
# run; the explicit default replaces the former ``'text_input' in locals()``
# probe.
selected_example = ''
for ex_text in example_texts:
    if st.button(f'Example: "{ex_text[:30]}..."'):
        selected_example = ex_text

text_input = st.text_input('Enter your text:', value=selected_example)

if text_input:
    location = geolocate_text_byt5(text_input)
    st.write('Predicted Location: ', location)

    # Render map with pydeck: a single point centred on the prediction.
    lat, lon = location
    map_data = pd.DataFrame([[lat, lon]], columns=["lat", "lon"])
    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=lat,
            longitude=lon,
            zoom=11,
            pitch=50,
        ),
        layers=[
            pdk.Layer(
                'ScatterplotLayer',
                data=map_data,
                get_position='[lon, lat]',
                get_color='[200, 30, 0, 160]',
                get_radius=200,
            ),
        ],
    ))