File size: 4,452 Bytes
151c0f1
 
c9c4318
151c0f1
 
 
 
50ba2ea
151c0f1
 
5290c3a
 
151c0f1
 
4f08151
151c0f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff38b8f
151c0f1
c5aa871
 
 
48860ae
 
 
 
c5aa871
 
151c0f1
 
 
c5aa871
 
 
 
 
 
 
 
 
151c0f1
 
50ba2ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Import necessary libraries
import copy
import os

import pandas as pd
import pydeck as pdk
import streamlit as st
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer

# Number of leading encoder blocks to retain. ByT5ForTextGeotagging.__init__
# reads this module-level value and skips truncation when it is None.
keep_layer_count = 6

# The Hugging Face access token is taken from the HF_TOKEN environment
# variable instead of being embedded in source control. When the variable is
# unset, None is passed as the token.
# NOTE(review): a token was previously committed in this file; it should be
# treated as leaked and revoked.
byt5_tokenizer = AutoTokenizer.from_pretrained(
    "yachay/byt5-geotagging-es",
    token=os.environ.get("HF_TOKEN"),
)


class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5-based text geotagging classifier.

    Args:
        n_clusters: Number of output classes of the classification head.
        model_name_or_path: Identifier or path of the underlying encoder
            checkpoint.
        class_to_location: Optional mapping from class index to a location;
            a falsy value is stored as an empty dict.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location if class_to_location else {}

    def to_diff_dict(self):
        """Return only the entries that differ from a bare ``PretrainedConfig()``.

        A key is kept when the default configuration does not have it, or
        when its value is different from the default one.
        """
        baseline = PretrainedConfig().to_dict()

        changed = {}
        for key, value in self.to_dict().items():
            if key not in baseline or value != baseline[key]:
                changed[key] = value
        return changed
    

def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of *model* whose encoder keeps only its first blocks.

    Args:
        model: A module exposing its encoder layers as ``model.encoder.block``.
        num_layers_to_keep: How many leading blocks to retain.

    Returns:
        A deep copy of ``model`` with ``encoder.block`` replaced by a new
        ``ModuleList`` of the first ``num_layers_to_keep`` blocks. The kept
        blocks are the original model's own layer objects (they are not
        copied); ``model`` itself is left unmodified.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of blocks.
    """
    source_blocks = model.encoder.block
    kept_blocks = torch.nn.ModuleList(
        [source_blocks[idx] for idx in range(num_layers_to_keep)]
    )

    # Clone the whole model, then swap in the shortened block list.
    truncated = copy.deepcopy(model)
    truncated.encoder.block = kept_blocks
    return truncated

class ByT5ForTextGeotagging(PreTrainedModel):
    """T5 encoder followed by a linear head that scores location clusters.

    The encoder is loaded from ``config.model_name_or_path`` and, when the
    module-level ``keep_layer_count`` is not None, truncated to that many
    leading blocks. The head ``fc3`` maps the encoder's ``d_model`` features
    to ``config.n_clusters`` logits.
    """

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)

        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)

        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Compute cluster logits and, optionally, the predicted location.

        Args:
            input: Token-id tensor indexed as ``input[:, 0, :]``, i.e. at
                least 3-D with only index 0 of dim 1 being used (presumably
                ``(batch, 1, seq_len)`` -- confirm against callers). The name
                shadows the builtin but is kept for caller compatibility.
            return_coordinates: When True, also look up the argmax class in
                ``config.class_to_location``. This path calls ``.item()`` on
                the argmax, so it only works for a single example.

        Returns:
            ``logits`` alone, or ``(logits, coordinates)`` when
            ``return_coordinates`` is True. ``coordinates`` is None if the
            predicted class index has no entry in ``class_to_location``.
        """
        # Integer indexing already removes dim 1. The former extra
        # ``.squeeze(1)`` calls could therefore only ever strip the sequence
        # axis (when seq_len == 1) or the hidden axis (when d_model == 1),
        # corrupting the shape in exactly those cases, so they are dropped.
        token_ids = input[:, 0, :]
        hidden_states = self.byt5(token_ids)['last_hidden_state']

        # The hidden state at the first sequence position feeds the head.
        pooled = hidden_states[:, 0, :]
        logits = self.fc3(pooled)

        if not return_coordinates:
            return logits

        class_idx = torch.argmax(logits, dim=1).item()
        # Mapping keys are looked up as strings.
        coordinates = self.config.class_to_location.get(str(class_idx))
        return logits, coordinates

def geolocate_text_byt5(text):
    """Predict a ``(lat, lon)`` pair for *text* with the module-level model.

    The text is tokenized with ``byt5_tokenizer`` (truncated to 140 tokens),
    given an extra leading dimension, and passed to ``model`` with
    ``return_coordinates=True``.

    Args:
        text: The text to geolocate.

    Returns:
        A ``(lat, lon)`` tuple unpacked from the coordinates the model
        returns for its predicted class.

    Raises:
        TypeError: If the model returns no coordinates (None) for the
            predicted class, since None cannot be unpacked.
    """
    input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        _, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
    return lat, lon

# Load the geotagging model used by geolocate_text_byt5. The Hugging Face
# access token comes from the HF_TOKEN environment variable rather than being
# hard-coded; when the variable is unset, None is passed as the token.
model = ByT5ForTextGeotagging.from_pretrained(
    "yachay/byt5-geotagging-es",
    token=os.environ.get("HF_TOKEN"),
)


# Sample posts offered as one-click inputs in the interface below.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "¡Barcelona es increíble! #VacacionesEnEspaña",
    "Me encantó el ceviche en Lima. ¡Qué delicioso! #Perú",
    "Bailando tango en las calles de Buenos Aires. #Argentina",
    "Admirando las hermosas playas de Cancún. #México",
]


# Streamlit interface
st.title('GeoTagging using ByT5')

# One button per sample text; a clicked sample pre-fills the text box.
prefill = ''
for sample in example_texts:
    clicked = st.button(f'Example: "{sample[:30]}..."')
    if clicked:
        prefill = sample

text_input = st.text_input('Enter your text:', value=prefill)


if text_input:
    location = geolocate_text_byt5(text_input)
    st.write('Predicted Location: ', location)

    # Single-row frame holding the predicted point for the map layer.
    point_row = [location[0], location[1]]
    map_data = pd.DataFrame([point_row], columns=["lat", "lon"])

    # Render the map with pydeck, centred on the prediction.
    view_state = pdk.ViewState(
        latitude=location[0],
        longitude=location[1],
        zoom=11,
        pitch=50,
    )
    scatter_layer = pdk.Layer(
        'ScatterplotLayer',
        data=map_data,
        get_position='[lon, lat]',
        get_color='[200, 30, 0, 160]',
        get_radius=200,
    )
    deck = pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=view_state,
        layers=[scatter_layer],
    )
    st.pydeck_chart(deck)