yachay commited on
Commit
151c0f1
·
1 Parent(s): ae294a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries
2
+ import streamlit as st
3
+ from transformers import PretrainedConfig, PreTrainedModel, T5EncoderModel, AutoTokenizer
4
+ import torch
5
+ import torch.nn as nn
6
+ import copy
7
+
8
+ keep_layer_count=6
9
+
10
+ class ByT5ForTextGeotaggingConfig(PretrainedConfig):
11
+ model_type = "byt5_for_text)geotagging"
12
+
13
+ def __init__(self, n_clusters, model_name_or_path, class_to_location=None, **kwargs):
14
+ super(ByT5ForTextGeotaggingConfig, self).__init__(**kwargs)
15
+ self.n_clusters = n_clusters
16
+ self.model_name_or_path = model_name_or_path
17
+ self.class_to_location = class_to_location or {}
18
+
19
+
20
+ def to_diff_dict(self):
21
+ # Convert the configuration to a dictionary
22
+ config_dict = self.to_dict()
23
+
24
+ # Get the default configuration for comparison
25
+ default_config_dict = PretrainedConfig().to_dict()
26
+
27
+ # Return the differences
28
+ diff_dict = {k: v for k, v in config_dict.items() if k not in default_config_dict or v != default_config_dict[k]}
29
+
30
+ return diff_dict
31
+
32
+
33
+ def deleteEncodingLayers(model, num_layers_to_keep): # must pass in the full bert model
34
+ oldModuleList = model.encoder.block
35
+ newModuleList = torch.nn.ModuleList()
36
+
37
+ # Now iterate over all layers, only keepign only the relevant layers.
38
+ for i in range(0, num_layers_to_keep):
39
+ newModuleList.append(oldModuleList[i])
40
+
41
+ # create a copy of the model, modify it with the new list, and return
42
+ copyOfModel = copy.deepcopy(model)
43
+ copyOfModel.encoder.block = newModuleList
44
+
45
+ return copyOfModel
46
+
47
+ class ByT5ForTextGeotagging(PreTrainedModel):
48
+ config_class = ByT5ForTextGeotaggingConfig
49
+
50
+ def __init__(self, config):
51
+ super(ByT5ForTextGeotagging, self).__init__(config)
52
+
53
+ self.byt5 = T5EncoderModel.from_pretrained(config.model_name_or_path)
54
+ if keep_layer_count is not None:
55
+ self.byt5 = deleteEncodingLayers(self.byt5, keep_layer_count)
56
+
57
+ hidden_size = self.byt5.config.d_model
58
+ self.fc3 = nn.Linear(hidden_size, config.n_clusters)
59
+
60
+ def forward(self, input, return_coordinates=False):
61
+ input = self.byt5(input[:, 0, :].squeeze(1))['last_hidden_state']
62
+ input = input[:, 0, :].squeeze(1)
63
+ logits = self.fc3(input)
64
+
65
+ if return_coordinates:
66
+ class_idx = torch.argmax(logits, dim=1).item()
67
+ coordinates = self.config.class_to_location.get(str(class_idx))
68
+ return logits, coordinates
69
+ else:
70
+ return logits
71
+
72
+ def geolocate_text_byt5(text):
73
+ input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
74
+ logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
75
+ return lat, lon
76
+
77
+
78
+ model = ByT5ForTextGeotagging.from_pretrained("byt5-geotagging-spanish")
79
+
80
+ #text = "¡Barcelona es increíble! #VacacionesEnEspaña"
81
+
82
+ # Streamlit interface
83
+ st.title('GeoTagging using ByT5')
84
+ text_input = st.text_input('Enter your text:')
85
+ if text_input:
86
+ location = geolocate_text_byt5(text_input)
87
+ st.write('Predicted Location: ', location)