Spaces:
Sleeping
Sleeping
File size: 8,351 Bytes
151c0f1 c9c4318 151c0f1 50ba2ea 52a5d0f ac09070 151c0f1 ac09070 151c0f1 5290c3a 151c0f1 4f08151 151c0f1 8dff62f 7341b3b 151c0f1 ef8e2fe ac09070 ef8e2fe 6f8ed7d ef8e2fe 66ae132 ac09070 66ae132 ac09070 66ae132 ef8e2fe 151c0f1 7341b3b 151c0f1 dd48c9f 151c0f1 ef8e2fe 7cde107 c5aa871 232729a 77415bb b99d24c 48860ae c5aa871 151c0f1 c5aa871 11cc467 ef8e2fe 11cc467 dd48c9f c5aa871 ef8e2fe dd48c9f ebc7ce1 5de90a6 dd48c9f 66ae132 ac09070 66ae132 ac09070 ef8e2fe 50ba2ea 626813f 50ba2ea bc26907 ac09070 bc26907 50ba2ea 626813f 78afeef 50ba2ea bc26907 50ba2ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
# Import necessary libraries
import streamlit as st
import pandas as pd
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer
import torch
import torch.nn as nn
import copy
import pydeck as pdk
import numpy as np
import base64
def get_base64_encoded_image(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is read in binary mode and the base64 bytes are decoded as
    UTF-8, so the result is a ``str``.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
# Number of leading T5 encoder blocks kept when ByT5ForTextGeotagging builds
# its encoder; the truncation step is skipped when this is None.
keep_layer_count=6
class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5-based text geotagging classifier.

    Attributes:
        n_clusters: Size of the classification head's output.
        model_name_or_path: Identifier handed to
            ``T5EncoderModel.from_pretrained`` when the model is built.
        class_to_location: Mapping consulted with ``str(class_index)`` keys
            to obtain the coordinates of a predicted class. An empty dict is
            stored when the argument is omitted or falsy.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location or {}

    def to_diff_dict(self):
        """Return only the entries that differ from a default ``PretrainedConfig``.

        The baseline is a bare ``PretrainedConfig()`` rather than an instance
        of this class, since this class cannot be constructed without
        arguments.
        """
        own_values = self.to_dict()
        baseline = PretrainedConfig().to_dict()
        changed = {}
        for key, value in own_values.items():
            # Keep keys the baseline lacks, and keys whose value was changed.
            if key not in baseline or value != baseline[key]:
                changed[key] = value
        return changed
def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of ``model`` whose encoder keeps only its first blocks.

    Args:
        model: A module exposing ``model.encoder.block`` as an indexable
            module list (here, the T5 encoder model). It is not modified.
        num_layers_to_keep: How many leading blocks of ``encoder.block`` to
            retain. A value of zero or less yields an empty block list.

    Returns:
        An independent copy of ``model`` with ``encoder.block`` replaced by a
        ``ModuleList`` of the first ``num_layers_to_keep`` blocks.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of blocks.
    """
    # Copy first and select blocks from the copy: picking them from the
    # source model would leave the result sharing parameters with it.
    truncated_model = copy.deepcopy(model)
    copied_blocks = truncated_model.encoder.block
    truncated_model.encoder.block = torch.nn.ModuleList(
        [copied_blocks[i] for i in range(num_layers_to_keep)]
    )
    return truncated_model
class ByT5ForTextGeotagging(PreTrainedModel):
    """T5 encoder with a linear classification head over location classes.

    The encoder is loaded from ``config.model_name_or_path`` and, when the
    module-level ``keep_layer_count`` is not None, truncated to that many
    leading blocks. The head maps the encoder's hidden size to
    ``config.n_clusters`` logits.
    """

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)
        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Classify token ids and optionally resolve the top class to coordinates.

        Args:
            input: Tensor indexed as ``input[:, 0, :]`` before encoding, so it
                must have three dimensions. Callers in this file pass the
                tokenizer's ``input_ids`` with one extra leading dimension
                (presumably ``(1, 1, seq_len)`` -- confirm for other callers).
            return_coordinates: When True, also return the coordinates and
                the softmax probability of the top class. This path calls
                ``.item()`` on the argmax, so it only supports a single row.

        Returns:
            ``logits`` alone, or ``(logits, coordinates, confidence)`` when
            ``return_coordinates`` is True. ``coordinates`` is None when
            ``config.class_to_location`` has no entry for the top class.
        """
        # Integer indexing already drops the middle axis. The former
        # ``.squeeze(1)`` calls only acted when the remaining axis had
        # length 1, which broke one-token inputs.
        token_ids = input[:, 0, :]
        hidden_states = self.byt5(token_ids)['last_hidden_state']
        first_position = hidden_states[:, 0, :]
        logits = self.fc3(first_position)
        if not return_coordinates:
            return logits
        class_idx = torch.argmax(logits, dim=1).item()
        coordinates = self.config.class_to_location.get(str(class_idx))
        # Explicit ``dim``: relying on the implicit softmax dimension is
        # deprecated; for 2-D logits the implicit choice was also 1.
        confidence = torch.max(torch.nn.functional.softmax(logits, dim=1)).item()
        return logits, coordinates, confidence
@st.cache_resource(ttl=None)
def load_model_and_tokenizer():
    """Load the geotagging tokenizer and model from the Hugging Face Hub.

    Both are fetched from the same repository using the access token stored
    under ``st.secrets['token']``. The result is cached by Streamlit with no
    expiry.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    repo_id = "yachay/byt5-geotagging-es"
    hub_token = st.secrets['token']
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hub_token)
    geotagger = ByT5ForTextGeotagging.from_pretrained(repo_id, token=hub_token)
    return tokenizer, geotagger
# Module-level tokenizer and model used by the geolocate_* helpers below;
# the loader is wrapped in st.cache_resource.
byt5_tokenizer, model = load_model_and_tokenizer()
def geolocate_text_byt5_multiclass(text):
    """Return the most probable candidate locations for ``text``.

    Classes are visited in descending probability and collected until their
    cumulative probability exceeds 0.5. The class that crosses the threshold
    is included, so a single result means the top class alone carries more
    than half of the probability mass.

    Args:
        text: Input text; it is tokenized with truncation to 140 tokens.

    Returns:
        A list of ``(class_idx, prob, coordinates)`` tuples, most probable
        first. Classes without an entry in ``model.config.class_to_location``
        are skipped (their probability still counts toward the threshold), so
        the list can be empty if none of the selected classes is mapped.
    """
    input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    # Inference only, so skip autograd bookkeeping. Only logits are needed;
    # asking the model for coordinates would just add a lookup that is
    # discarded here.
    with torch.no_grad():
        logits = model(input_tensor.unsqueeze(0))
    probas = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
    class_to_location = model.config.class_to_location

    results = []
    cumulative_prob = 0.0
    # argsort on the negated array yields indices in descending probability.
    for class_idx in np.argsort(-probas):
        prob = probas[class_idx]
        coordinates = class_to_location.get(str(class_idx))
        if coordinates:
            results.append((class_idx, prob, coordinates))
        # Test the threshold after recording the class so the one that
        # crosses it is part of the result.
        cumulative_prob += prob
        if cumulative_prob > 0.5:
            break
    return results
def geolocate_text_byt5(text):
    """Predict a single location for ``text``.

    Args:
        text: Input text; it is tokenized with truncation to 140 tokens.

    Returns:
        ``(lat, lon, confidence)`` where ``confidence`` is the softmax
        probability of the predicted class.

    Raises:
        TypeError: If the model returns no coordinates for the predicted
            class, because ``None`` cannot be unpacked into ``(lat, lon)``.
    """
    input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        _, (lat, lon), confidence = model(input_tensor.unsqueeze(0), return_coordinates=True)
    return lat, lon, confidence
# Seed the two text slots with empty strings the first time a session runs,
# so later reads of st.session_state never hit a missing key.
if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""
if 'text_modified' not in st.session_state:
    st.session_state['text_modified'] = ""
# When an example button is clicked, update the session state
def set_example_text(example_text):
    """Store ``example_text`` as the active input text in the session state.

    Writes ``st.session_state.text_input``, the key that the text box and
    the prediction step in this file read. The previous key,
    ``text_input_state``, was not read anywhere in this file.
    """
    st.session_state.text_input = example_text
# Sample inputs; the interface below renders one button per entry.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "Una semana de conciertos, fuegos artificiales y txosnas en Aste Nagusia. ¡Esta ciudad sabe cómo celebrar! #Fiestas #AsteNagusia",
    "Viendo las destrezas de los gauchos en la Semana Criolla. Los asados también son para morirse. #SemanaCriolla #Tradición",
    "Bailando tango en las calles de Buenos Aires",
    "Admirando las hermosas playas de Cancún. #México"
]
# Streamlit interface
st.title('GeoTagging using ByT5')
st.write('Examples:')
for example in example_texts:
    # Clicking an example makes it the active text for this run.
    if st.button(example):
        st.session_state.text_input = example
        st.session_state.text_modified = st.session_state.text_input

# Get text input and update session state when it's modified
st.session_state.text_modified = st.text_input('Enter your text:', value=st.session_state.text_input)
if st.button('Submit'):
    st.session_state.text_input = st.session_state.text_modified

if st.session_state.text_input:
    results = geolocate_text_byt5_multiclass(st.session_state.text_input)
    if not results:
        # Nothing to unpack or plot; say so instead of raising IndexError.
        st.write('No location could be determined for this text.')
    else:
        # The first result is the most probable class; it is the reported
        # prediction and the map centre.
        _, _, (lat, lon) = results[0]

        # The number of candidates returned is used as the confidence signal.
        if len(results) == 1:
            confidence_def = 'High'
        elif len(results) < 50:
            confidence_def = 'Low'
        else:
            confidence_def = 'Very low'
        st.write('Predicted Location: (', lat, lon, '). Confidence: ', confidence_def)
        if confidence_def == 'Low':
            st.write('Multiple possible locations were identified as confidence is low')
        elif confidence_def == 'Very low':
            st.write('There are too many possible locations as confidence is very low')

        # Render map with pydeck. The pin image is embedded as a data URL.
        encoded_image = get_base64_encoded_image("icons8-map-pin-48.png")
        icon_data = {
            "url": f"data:image/png;base64,{encoded_image}",  # URL of the icon image
            "width": 128,  # Width of the icon in pixels
            "height": 128,  # Height of the icon in pixels
            "anchorY": 128  # Anchor point of the icon in pixels (bottom center)
        }

        layers = []
        if confidence_def != 'Very low':
            # One IconLayer with a row per candidate. Marker coordinates get
            # their own names so `lat`/`lon` keep pointing at the top
            # prediction for the view state below.
            marker_lats = []
            marker_lons = []
            for _, _, (marker_lat, marker_lon) in results:
                marker_lats.append(marker_lat)
                marker_lons.append(marker_lon)
            layers.append(pdk.Layer(
                type='IconLayer',
                data=pd.DataFrame({
                    'lat': marker_lats,
                    'lon': marker_lons,
                    'icon_data': [icon_data] * len(results),
                }),
                get_icon='icon_data',
                get_size=4,
                size_scale=15,
                get_position=['lon', 'lat'],
                pickable=True
            ))

        st.pydeck_chart(pdk.Deck(
            map_style='mapbox://styles/mapbox/light-v9',
            initial_view_state=pdk.ViewState(
                latitude=lat,
                longitude=lon,
                zoom=6,
                pitch=50,
            ),
            layers=layers,
        ))