File size: 4,538 Bytes
151c0f1
 
 
 
 
 
50ba2ea
151c0f1
 
5290c3a
 
151c0f1
 
4f08151
151c0f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff38b8f
151c0f1
c5aa871
 
 
 
 
 
 
 
 
151c0f1
 
 
c5aa871
 
 
 
 
 
 
 
 
151c0f1
 
50ba2ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Import necessary libraries
import copy
import os

import pandas as pd
import pydeck as pdk
import streamlit as st
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer

# Number of leading ByT5 encoder blocks that ByT5ForTextGeotagging keeps
# (None keeps the encoder untouched).
keep_layer_count = 6


@st.cache_resource
def _load_byt5_tokenizer():
    """Load the ByT5 tokenizer once; st.cache_resource reuses it across reruns."""
    # The access token is read from the environment (e.g. a deployment secret)
    # instead of being committed to source.
    # NOTE(review): the token previously hard-coded here is exposed in the
    # repository history and should be revoked.
    return AutoTokenizer.from_pretrained(
        "yachay/byt5-geotagging-es", token=os.environ.get("HF_TOKEN")
    )


byt5_tokenizer = _load_byt5_tokenizer()


class ByT5ForTextGeotaggingConfig(PretrainedConfig):
    """Configuration for the ByT5 text-geotagging classifier.

    Args:
        n_clusters: Number of output classes of the classification head.
        model_name_or_path: Name or path of the pretrained T5 encoder that the
            model loads with ``T5EncoderModel.from_pretrained``.
        class_to_location: Mapping from a class index (looked up as a string)
            to its location; any falsy value is replaced by an empty dict.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "byt5_for_text_geotagging"

    def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.model_name_or_path = model_name_or_path
        self.class_to_location = class_to_location if class_to_location else {}

    def to_diff_dict(self):
        """Return the entries of ``to_dict()`` that differ from a bare ``PretrainedConfig``.

        A key is kept when the base config does not have it at all, or when
        its value differs from the base default.

        NOTE(review): presumably overridden because the inherited version
        builds this class with no arguments, which fails here since
        ``n_clusters`` and ``model_name_or_path`` have no defaults -- confirm.
        """
        current = self.to_dict()
        baseline = PretrainedConfig().to_dict()

        diff = {}
        for key, value in current.items():
            if key not in baseline or value != baseline[key]:
                diff[key] = value
        return diff
    

def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of *model* whose encoder keeps only its first blocks.

    Args:
        model: A model exposing ``model.encoder.block`` as a
            ``torch.nn.ModuleList`` (here a ``T5EncoderModel``). It is not
            modified.
        num_layers_to_keep: Number of leading encoder blocks to retain.

    Returns:
        An independent copy of *model* with ``encoder.block`` truncated to
        the first *num_layers_to_keep* blocks.

    Raises:
        IndexError: if *num_layers_to_keep* exceeds the number of blocks.
    """
    copy_of_model = copy.deepcopy(model)
    copied_blocks = copy_of_model.encoder.block

    # Take the blocks from the copy (not from *model*), so the result shares
    # no modules or parameters with the input. Explicit indexing (rather
    # than slicing) keeps the IndexError for an out-of-range count.
    copy_of_model.encoder.block = torch.nn.ModuleList(
        [copied_blocks[i] for i in range(num_layers_to_keep)]
    )
    # NOTE(review): the copy's config still reports the original layer
    # count; nothing in this file reads it, but confirm before saving the
    # truncated encoder on its own.
    return copy_of_model

class ByT5ForTextGeotagging(PreTrainedModel):
    """T5 encoder plus a linear head that classifies text into location clusters.

    The encoder is loaded from ``config.model_name_or_path``. When the
    module-level ``keep_layer_count`` is not None, it is truncated to its
    first ``keep_layer_count`` blocks. ``fc3`` maps the first token's hidden
    state to ``config.n_clusters`` logits.
    """

    config_class = ByT5ForTextGeotaggingConfig

    def __init__(self, config):
        super().__init__(config)

        self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
        # NOTE: the layer count comes from the module-level global, not the config.
        if keep_layer_count is not None:
            self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)

        hidden_size = self.byt5.config.d_model
        self.fc3 = nn.Linear(hidden_size, config.n_clusters)

    def forward(self, input, return_coordinates=False):
        """Classify token ids into location clusters.

        Args:
            input: Token ids with an extra dimension at position 1, i.e.
                ``(batch, 1, seq_len)``; only index 0 of that dimension is used.
            return_coordinates: If True, also look up the argmax class in
                ``config.class_to_location``. This uses ``.item()``, so it
                requires a batch of exactly one.

        Returns:
            ``logits`` of shape ``(batch, n_clusters)``. When
            *return_coordinates* is True, returns ``(logits, coordinates)``,
            where ``coordinates`` is None for a class missing from the mapping.
        """
        # (batch, 1, seq_len) -> (batch, seq_len). No .squeeze(1): it would
        # collapse a single-token sequence to shape (batch,) and break the encoder.
        token_ids = input[:, 0, :]
        hidden = self.byt5(token_ids)['last_hidden_state']
        # The first token's hidden state serves as the sequence representation.
        pooled = hidden[:, 0, :]
        logits = self.fc3(pooled)

        if not return_coordinates:
            return logits

        class_idx = torch.argmax(logits, dim=1).item()
        # Mapping keys are looked up as strings.
        coordinates = self.config.class_to_location.get(str(class_idx))
        return logits, coordinates

def geolocate_text_byt5(text, max_length=140):
    """Predict a ``(lat, lon)`` location for *text* with the module-level model.

    Args:
        text: Input text, tokenized with the module-level ``byt5_tokenizer``.
        max_length: Truncation length passed to the tokenizer (default 140).

    Returns:
        ``(lat, lon)`` unpacked from ``model.config.class_to_location`` for
        the predicted class.

    Raises:
        ValueError: if the predicted class has no entry in ``class_to_location``.
    """
    input_ids = byt5_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=max_length
    )['input_ids']

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # unsqueeze(0) adds the extra leading dim; forward reads input[:, 0, :].
        _, coordinates = model(input_ids.unsqueeze(0), return_coordinates=True)

    if coordinates is None:
        raise ValueError("Predicted class has no entry in class_to_location")
    lat, lon = coordinates
    return lat, lon

@st.cache_resource
def _load_geotagging_model():
    """Load the fine-tuned geotagging model once; st.cache_resource reuses it across reruns."""
    # The access token is read from the environment instead of being committed to source.
    # NOTE(review): the previously hard-coded token is exposed and should be revoked.
    return ByT5ForTextGeotagging.from_pretrained(
        "yachay/byt5-geotagging-es", token=os.environ.get("HF_TOKEN")
    )


model = _load_geotagging_model()


# Sample inputs offered as one-click buttons in the UI.
example_texts = [
    "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
    "La arquitectura de #Tokio es realmente algo fuera de este mundo 🌆🇯🇵",
    "Escuchando jazz en un café acogedor en el corazón de #NuevaOrleans 🎷🎶",
    "Los atardeceres en #CapeTown con la vista del Monte Table son inolvidables 🌅🇿🇦",
    "Nada se compara con caminar por las históricas calles de #Roma 🏛️🍕"
]


# Streamlit interface
st.title('GeoTagging using ByT5')

# The text box is bound to session state through its key. Clicking an
# example fills it, and later edits to that text survive the next reruns.
if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ''

# Buttons for example texts
for ex_text in example_texts:
    if st.button(f'Example: "{ex_text[:30]}..."'):
        # Allowed here because the text_input widget is created further below
        # in this run.
        st.session_state['text_input'] = ex_text

text_input = st.text_input('Enter your text:', key='text_input')


if text_input:
    location = geolocate_text_byt5(text_input)
    st.write('Predicted Location: ', location)
    # Render map with pydeck; location is (lat, lon).
    map_data = pd.DataFrame(
        [[location[0], location[1]]],
        columns=["lat", "lon"]
    )

    st.pydeck_chart(pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=pdk.ViewState(
            latitude=location[0],
            longitude=location[1],
            zoom=11,
            pitch=50,
        ),
        layers=[
            pdk.Layer(
                'ScatterplotLayer',
                data=map_data,
                get_position='[lon, lat]',
                get_color='[200, 30, 0, 160]',
                get_radius=200,
            ),
        ],
    ))