# Import necessary libraries
import copy
import os

import streamlit as st
import pandas as pd
import pydeck as pdk
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer

# Keep only the first N ByT5 encoder blocks to shrink the model (None keeps all).
keep_layer_count = 6


class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters=None, model_name_or_path=None, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        # Maps a predicted cluster index (as a string key) to a (lat, lon) pair.
        self.class_to_location = class_to_location or {}

    def to_diff_dict(self):
        # Serialize only the keys that differ from the PretrainedConfig defaults.
        config_dict = self.to_dict()
        default_config_dict = PretrainedConfig().to_dict()
        return {
            k: v
            for k, v in config_dict.items()
            if k not in default_config_dict or v != default_config_dict[k]
        }


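# A usage sketch with hypothetical values: constructing the config directly,
# e.g. when training from scratch instead of loading it from the Hub:
#
#   config = ByT5ForTextGeotaggingConfig(
#       n_clusters=3000,                              # number of geographic clusters
#       model_name_or_path="google/byt5-small",
#       class_to_location={"0": (40.4168, -3.7038)},  # cluster 0 -> Madrid
#   )
#   model = ByT5ForTextGeotagging(config)
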
def deleteEncodingLayers(model, num_layers_to_keep):  # expects the full T5 encoder model
    oldModuleList = model.encoder.block
    newModuleList = torch.nn.ModuleList()

    # Iterate over the encoder blocks, keeping only the first num_layers_to_keep.
    for i in range(0, num_layers_to_keep):
        newModuleList.append(oldModuleList[i])

    # Create a copy of the model, swap in the truncated block list, and return it.
    copyOfModel = copy.deepcopy(model)
    copyOfModel.encoder.block = newModuleList

    return copyOfModel

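# A minimal sketch (assumes google/byt5-small is reachable): ByT5-small ships
# 12 encoder blocks, so keeping 6 roughly halves encoder compute while leaving
# the byte embedding layer intact:
#
#   enc = T5EncoderModel.from_pretrained("google/byt5-small")
#   enc6 = deleteEncodingLayers(enc, 6)
#   assert len(enc6.encoder.block) == 6
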
class ByT5ForTextGeotagging(PreTrainedModel):
    config_class = ByT5ForTextGeotaggingConfig
    
    def __init__(self, config):
        super().__init__(config)

        # Load the pretrained ByT5 encoder and optionally truncate it.
        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)

        # Classification head: encoder hidden size -> one logit per geographic cluster.
        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input_ids, return_coordinates=False):
        # input_ids: (batch, 1, seq_len). Drop the extra dim, encode, and use the
        # first token's hidden state as the sentence representation.
        hidden = self.byt5(input_ids[:, 0, :])['last_hidden_state']
        pooled = hidden[:, 0, :]
        logits = self.fc3(pooled)

        if return_coordinates:
            # Map the argmax cluster back to its (lat, lon) location.
            class_idx = torch.argmax(logits, dim=1).item()
            coordinates = self.config.class_to_location.get(str(class_idx))
            return logits, coordinates
        else:
            return logits


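# Rough shape walkthrough for a single input (seq_len = L):
#
#   input_ids             : (1, 1, L)   # extra dim added by unsqueeze(0) below
#   encoder hidden states : (1, L, d_model)
#   pooled first token    : (1, d_model)
#   logits                : (1, n_clusters)
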
@st.cache_resource
def load_model_and_tokenizer():
    # st.cache_resource suits unserializable objects like models and tokenizers.
    # Read the Hugging Face access token from the environment rather than
    # hard-coding a secret in the source.
    hf_token = os.getenv("HF_TOKEN")
    byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token=hf_token)
    model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token=hf_token)
    return byt5_tokenizer, model

byt5_tokenizer, model = load_model_and_tokenizer()


def geolocate_text_byt5(text):
    # Byte-level tokenization; max_length=140 matches the classic tweet length.
    input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    # Add a batch dimension; the model maps the argmax cluster to (lat, lon).
    logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
    return lat, lon

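# Example usage (a sketch; assumes the model above downloaded successfully):
#
#   lat, lon = geolocate_text_byt5("Tomando mate en la rambla de Montevideo")
#   print(lat, lon)
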
if 'text_input' not in st.session_state:
    st.session_state.text_input = ""

if 'text_modified' not in st.session_state:
    st.session_state.text_modified = ""

# When an example button is clicked, update the session state so the text box
# below is pre-filled with the example.
def set_example_text(example_text):
    st.session_state.text_input = example_text

# Spanish example tweets (the "-es" model targets Spanish text); English glosses in comments.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",  # Enjoying a delicious paella on the beaches of #Valencia
    "Una semana de conciertos, fuegos artificiales y txosnas en Aste Nagusia. ¡Esta ciudad sabe cómo celebrar! #Fiestas #AsteNagusia",  # A week of concerts, fireworks, and txosnas at Aste Nagusia. This city knows how to celebrate!
    "Me encantó el ceviche en Lima. #Perú",  # I loved the ceviche in Lima. #Peru
    "Bailando tango en las calles de Buenos Aires",  # Dancing tango in the streets of Buenos Aires
    "Admirando las hermosas playas de Cancún. #México",  # Admiring the beautiful beaches of Cancún. #Mexico
]


# Streamlit interface
st.title('GeoTagging using ByT5')

st.write('Examples:')
for example in example_texts:
    if st.button(example):
        set_example_text(example)
        st.session_state.text_modified = example

# Get the text input and keep the session state in sync when it is edited.
st.session_state.text_modified = st.text_input('Enter your text:', value=st.session_state.text_input)

if st.session_state.text_modified:
    location = geolocate_text_byt5(st.session_state.text_modified)
    st.write('Predicted Location: ', location)

    # Render the predicted point on a map with pydeck.
    map_data = pd.DataFrame(
        [[location[0], location[1]]],
        columns=["lat", "lon"]
    )

    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=location[0],
            longitude=location[1],
            zoom=11,
            pitch=50,
        ),
        layers=[
            pdk.Layer(
                'ScatterplotLayer',
                data=map_data,
                get_position='[lon, lat]',
                get_color='[200, 30, 0, 160]',
                get_radius=200,
            ),
        ],
    ))
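
# To run the app locally (assuming this file is saved as app.py and the
# imports above are installed):
#
#   streamlit run app.py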