Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ import copy
|
|
8 |
import pydeck as pdk
|
9 |
|
10 |
keep_layer_count=6
|
11 |
-
byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token="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")
|
12 |
|
13 |
|
14 |
class ByT5ForTextGeotaggingConfig(PretrainedConfig):
|
@@ -73,13 +72,27 @@ class ByT5ForTextGeotagging(PreTrainedModel):
|
|
73 |
else:
|
74 |
return logits
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
def geolocate_text_byt5(text):
|
77 |
input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
|
78 |
logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
|
79 |
return lat, lon
|
80 |
|
81 |
-
|
|
|
82 |
|
|
|
|
|
|
|
83 |
|
84 |
example_texts = [
|
85 |
"Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
|
@@ -93,17 +106,18 @@ example_texts = [
|
|
93 |
# Streamlit interface
|
94 |
st.title('GeoTagging using ByT5')
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
text_input = ex_text
|
100 |
-
|
101 |
-
text_input = st.text_input('Enter your text:', value=text_input if 'text_input' in locals() else '')
|
102 |
|
|
|
|
|
|
|
103 |
|
104 |
if text_input:
|
105 |
location = geolocate_text_byt5(text_input)
|
106 |
st.write('Predicted Location: ', location)
|
|
|
107 |
# Render map with pydeck
|
108 |
map_data = pd.DataFrame(
|
109 |
[[location[0], location[1]]],
|
|
|
8 |
import pydeck as pdk
|
9 |
|
10 |
keep_layer_count=6
|
|
|
11 |
|
12 |
|
13 |
class ByT5ForTextGeotaggingConfig(PretrainedConfig):
|
|
|
72 |
else:
|
73 |
return logits
|
74 |
|
75 |
+
|
76 |
+
@st.cache(allow_output_mutation=True)
|
77 |
+
def load_model_and_tokenizer():
|
78 |
+
byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token="x")
|
79 |
+
model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token="x")
|
80 |
+
return byt5_tokenizer, model
|
81 |
+
|
82 |
+
byt5_tokenizer, model = load_model_and_tokenizer()
|
83 |
+
|
84 |
+
|
85 |
def geolocate_text_byt5(text):
|
86 |
input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
|
87 |
logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
|
88 |
return lat, lon
|
89 |
|
90 |
+
if 'text_input' not in st.session_state:
|
91 |
+
st.session_state.text_input = ""
|
92 |
|
93 |
+
# When an example button is clicked, update the session state
|
94 |
+
def set_example_text(example_text):
|
95 |
+
st.session_state.text_input = example_text
|
96 |
|
97 |
example_texts = [
|
98 |
"Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
|
|
|
106 |
# Streamlit interface
|
107 |
st.title('GeoTagging using ByT5')
|
108 |
|
109 |
+
for example in example_texts:
|
110 |
+
if st.button(f"Use example: {example}"):
|
111 |
+
set_example_text(example)
|
|
|
|
|
|
|
112 |
|
113 |
+
# Get text input and update session state when it's modified
|
114 |
+
text_input = st.text_input('Enter your text:', value=st.session_state.text_input)
|
115 |
+
st.session_state.text_input = text_input
|
116 |
|
117 |
if text_input:
|
118 |
location = geolocate_text_byt5(text_input)
|
119 |
st.write('Predicted Location: ', location)
|
120 |
+
|
121 |
# Render map with pydeck
|
122 |
map_data = pd.DataFrame(
|
123 |
[[location[0], location[1]]],
|