File size: 8,351 Bytes
151c0f1
 
c9c4318
151c0f1
 
 
 
50ba2ea
52a5d0f
ac09070
151c0f1
ac09070
 
 
 
151c0f1
5290c3a
151c0f1
 
4f08151
151c0f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8dff62f
7341b3b
151c0f1
 
 
ef8e2fe
ac09070
ef8e2fe
6f8ed7d
 
ef8e2fe
 
 
 
66ae132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac09070
 
 
 
 
 
 
 
 
 
 
66ae132
 
ac09070
66ae132
ef8e2fe
151c0f1
 
7341b3b
 
151c0f1
dd48c9f
 
 
 
 
151c0f1
ef8e2fe
 
7cde107
c5aa871
 
 
232729a
77415bb
b99d24c
48860ae
c5aa871
 
151c0f1
 
 
c5aa871
11cc467
ef8e2fe
11cc467
dd48c9f
 
c5aa871
ef8e2fe
dd48c9f
ebc7ce1
 
 
 
5de90a6
dd48c9f
66ae132
ac09070
66ae132
ac09070
 
 
 
 
 
 
 
 
 
 
ef8e2fe
50ba2ea
 
626813f
50ba2ea
 
bc26907
ac09070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc26907
50ba2ea
 
 
626813f
 
78afeef
50ba2ea
 
bc26907
50ba2ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
# Import necessary libraries
import streamlit as st
import pandas as pd
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer
import torch
import torch.nn as nn
import copy
import pydeck as pdk
import numpy as np
import base64

def get_base64_encoded_image(image_path):
    """Return the contents of the file at *image_path* as a base64 text string.

    The file is read in binary mode and the encoded bytes are decoded to
    ``str`` so the result can be embedded directly in a ``data:`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
        
# Number of leading ByT5 encoder blocks that ByT5ForTextGeotagging keeps after
# loading the pretrained encoder; None keeps the encoder untruncated.
keep_layer_count: int = 6


class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5 text-geotagging classifier.

    Args:
        n_clusters: Number of output classes (location clusters) produced by
            the classification head.
        model_name_or_path: Identifier handed to
            ``T5EncoderModel.from_pretrained`` when the encoder is built.
        class_to_location: Mapping from a class index, as a string key, to the
            coordinates of that class. Falsy values become an empty dict.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        # None (or any other falsy mapping) is normalised to a fresh empty dict.
        self.class_to_location = class_to_location if class_to_location else {}

    def to_diff_dict(self):
        """Return the entries of this config that differ from a default ``PretrainedConfig``.

        Keys that the base ``PretrainedConfig`` does not define at all (such as
        ``n_clusters``) are always included.

        NOTE(review): presumably overridden because the base implementation
        instantiates the config class with no arguments, which this class's
        required ``__init__`` parameters do not allow. Confirm.
        """
        own_entries = self.to_dict()
        baseline = PretrainedConfig().to_dict()
        differences = {}
        for key, value in own_entries.items():
            if key in baseline and not (value != baseline[key]):
                # Identical to the library default: omit it.
                continue
            differences[key] = value
        return differences
    

def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full T5 encoder model
    """Return a deep copy of *model* whose encoder keeps only its first blocks.

    Args:
        model: A module exposing ``model.encoder.block`` as an
            ``nn.ModuleList`` (e.g. a ``T5EncoderModel``). It is temporarily
            modified during the copy and restored before returning.
        num_layers_to_keep: Number of leading encoder blocks to keep. Must lie
            in ``[0, len(model.encoder.block)]``.

    Returns:
        An independent deep copy of *model* whose ``encoder.block`` holds only
        the first *num_layers_to_keep* blocks.

    Raises:
        ValueError: If *num_layers_to_keep* is outside the valid range.
    """
    encoder = model.encoder
    all_blocks = encoder.block
    if not 0 <= num_layers_to_keep <= len(all_blocks):
        raise ValueError(
            f"num_layers_to_keep must be in [0, {len(all_blocks)}], got {num_layers_to_keep}"
        )

    # Swap in the truncated block list *before* copying so that deepcopy never
    # duplicates the blocks that are about to be discarded. That avoids
    # temporarily doubling the memory of the whole encoder. The original list
    # is restored afterwards so the caller's model ends up unchanged.
    encoder.block = all_blocks[:num_layers_to_keep]
    try:
        return copy.deepcopy(model)
    finally:
        encoder.block = all_blocks

class ByT5ForTextGeotagging(PreTrainedModel):
    """ByT5 encoder followed by a linear head that scores location clusters.

    The encoder is loaded from ``config.model_name_or_path``. When the
    module-level ``keep_layer_count`` is not None, it is truncated to its first
    ``keep_layer_count`` blocks. The first position's final hidden state feeds
    a linear layer with ``config.n_clusters`` outputs.
    """

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)

        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)

        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Score location clusters for a batch of token ids.

        Args:
            input: Rank-3 tensor of token ids. Only index 0 along dimension 1
                is used, so ``(batch, 1, seq_len)`` is the expected shape.
                Callers in this file build it with ``input_ids.unsqueeze(0)``.
            return_coordinates: If true, also return the top class's
                coordinates and probability. This path calls ``.item()``, so
                it requires a batch of size 1.

        Returns:
            ``logits`` of shape ``(batch, n_clusters)``. When
            *return_coordinates* is true, returns ``(logits, coordinates,
            confidence)`` instead:

            - ``coordinates`` is ``config.class_to_location[str(argmax)]``, or
              None if that key is missing.
            - ``confidence`` is the largest softmax probability.
        """
        # Plain indexing (no squeeze): squeeze(1) would also drop the sequence
        # axis whenever seq_len == 1, collapsing a batch into one sequence.
        token_ids = input[:, 0, :]
        hidden_states = self.byt5(token_ids)['last_hidden_state']
        # The first position's hidden state is the sequence representation.
        pooled = hidden_states[:, 0, :]
        logits = self.fc3(pooled)

        if not return_coordinates:
            return logits

        class_idx = torch.argmax(logits, dim=1).item()
        coordinates = self.config.class_to_location.get(str(class_idx))
        # Explicit dim: calling softmax without it is deprecated and warns.
        confidence = torch.nn.functional.softmax(logits, dim=1).max().item()
        return logits, coordinates, confidence


@st.cache_resource(ttl=None)
def load_model_and_tokenizer():
    """Load the geotagging tokenizer and model from ``yachay/byt5-geotagging-es``.

    Both downloads authenticate with the token stored in
    ``st.secrets['token']``. The ``st.cache_resource`` decorator lets Streamlit
    reuse the loaded objects instead of reloading them on every rerun.

    Returns:
        ``(tokenizer, model)``
    """
    repo_id = "yachay/byt5-geotagging-es"
    hub_token = st.secrets['token']
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hub_token)
    geotagger = ByT5ForTextGeotagging.from_pretrained(repo_id, token=hub_token)
    return tokenizer, geotagger

# Load the tokenizer and model when the script runs. Repeat runs are served
# from st.cache_resource. The geolocation helpers below read these globals.
byt5_tokenizer, model = load_model_and_tokenizer()

def _candidate_locations(probs, class_to_location, mass_threshold):
    """Pick the most probable classes (with coordinates) covering *mass_threshold* of the mass."""
    order = np.argsort(-probs)
    results = []
    cumulative_prob = 0.0
    for class_idx in order:
        prob = probs[class_idx]
        coordinates = class_to_location.get(str(class_idx))
        # Append *before* testing the threshold so the class that pushes the
        # cumulative mass past it is part of the set (top-p style).
        if coordinates:
            results.append((class_idx, prob, coordinates))
        # Classes without coordinates are skipped but still count toward the mass.
        cumulative_prob += prob
        # If the mass is covered but no class had coordinates yet, keep
        # scanning. The first class found then is the most probable one with
        # coordinates.
        if cumulative_prob > mass_threshold and results:
            break
    return results


def geolocate_text_byt5_multiclass(text, *, mass_threshold=0.5):
    """Return the candidate locations for *text*, most probable first.

    Classes are taken in descending probability order until their cumulative
    probability exceeds *mass_threshold*. The class that crosses the threshold
    is included. Classes without an entry in
    ``model.config.class_to_location`` are skipped. If none of the classes in
    that mass has coordinates, the single most probable class that does is
    returned.

    Args:
        text: Input text. It is tokenized with truncation to 140 tokens.
        mass_threshold: Probability mass the candidate set should cover.

    Returns:
        A list of ``(class_idx, probability, coordinates)`` tuples. It is empty
        only if no class has coordinates.
    """
    input_ids = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids.unsqueeze(0))
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
    return _candidate_locations(probs, model.config.class_to_location, mass_threshold)

   

def geolocate_text_byt5(text):
    """Return ``(lat, lon, confidence)`` for the single most probable location of *text*.

    ``confidence`` is the top softmax probability reported by the model. If the
    predicted class has no entry in ``model.config.class_to_location``, the
    model returns None for the coordinates and unpacking them raises TypeError.
    """
    input_ids = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    # Inference only: no autograd graph is needed for prediction.
    with torch.no_grad():
        _, (lat, lon), confidence = model(input_ids.unsqueeze(0), return_coordinates=True)
    return lat, lon, confidence

# Ensure the session-state keys read by the UI exist. Values that survive from
# a previous rerun are kept; missing keys start as an empty string.
for _state_key in ('text_input', 'text_modified'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""

# When an example button is clicked, update the session state
def set_example_text(example_text):
    """Make *example_text* the current input text.

    Writes the same session-state keys that the inline example-button handler
    writes and the UI reads: ``text_input`` and ``text_modified``.
    """
    st.session_state.text_input = example_text
    st.session_state.text_modified = example_text

# Sample Spanish-language inputs, each rendered as a one-click example button
# that fills the text box.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "Una semana de conciertos, fuegos artificiales y txosnas en Aste Nagusia. ¡Esta ciudad sabe cómo celebrar! #Fiestas #AsteNagusia",
    "Viendo las destrezas de los gauchos en la Semana Criolla. Los asados también son para morirse. #SemanaCriolla #Tradición",
    "Bailando tango en las calles de Buenos Aires",
    "Admirando las hermosas playas de Cancún. #México"
]


def _confidence_label(n_candidates):
    """Map the number of candidate locations to the confidence label shown to the user."""
    if n_candidates == 1:
        return 'High'
    if n_candidates < 50:
        return 'Low'
    return 'Very low'


def _render_prediction(results):
    """Display the top prediction, its confidence label and a pin map.

    *results* is a non-empty list of ``(class_idx, prob, (lat, lon))`` tuples,
    most probable first, as returned by ``geolocate_text_byt5_multiclass``.
    """
    _, _, (top_lat, top_lon) = results[0]
    confidence_def = _confidence_label(len(results))
    st.write('Predicted Location: (', top_lat, top_lon, '). Confidence: ', confidence_def)
    if confidence_def == 'Low':
        st.write('Multiple possible locations were identified as confidence is low')
    elif confidence_def == 'Very low':
        st.write('There are too many possible locations as confidence is very low')

    # The pin image is inlined as a data URL.
    encoded_image = get_base64_encoded_image("icons8-map-pin-48.png")
    # NOTE(review): the file name suggests a 48px image while 128 is declared
    # below. Confirm the intended icon size.
    icon_data = {
        "url": f"data:image/png;base64,{encoded_image}",  # URL of the icon image
        "width": 128,  # Width of the icon in pixels
        "height": 128,  # Height of the icon in pixels
        "anchorY": 128,  # Anchor point of the icon in pixels (bottom center)
    }

    layers = []
    if confidence_def != 'Very low':
        # A single IconLayer with one pin per candidate location.
        markers = pd.DataFrame({
            'lat': [lat for _, _, (lat, _) in results],
            'lon': [lon for _, _, (_, lon) in results],
            'icon_data': [icon_data] * len(results),
        })
        layers.append(pdk.Layer(
            type='IconLayer',
            data=markers,
            get_icon='icon_data',
            get_size=4,
            size_scale=15,
            get_position=['lon', 'lat'],
            pickable=True,
        ))

    # Centre the view on the most probable location, not on the last candidate.
    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=top_lat,
            longitude=top_lon,
            zoom=6,
            pitch=50,
        ),
        layers=layers,
    ))


# Streamlit interface
st.title('GeoTagging using ByT5')

st.write('Examples:')
for example in example_texts:
    if st.button(f"{example}"):
        st.session_state.text_input = example
        st.session_state.text_modified = st.session_state.text_input

# Get text input and update session state when it's modified
st.session_state.text_modified = st.text_input('Enter your text:', value=st.session_state.text_input)

if st.button('Submit'):
    st.session_state.text_input = st.session_state.text_modified


if st.session_state.text_input:
    results = geolocate_text_byt5_multiclass(st.session_state.text_input)
    if results:
        _render_prediction(results)
    else:
        # No class in the prediction has coordinates: nothing to show on a map.
        st.write('No location could be predicted for this text.')