yachay commited on
Commit
ef8e2fe
·
1 Parent(s): 48860ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -8,7 +8,6 @@ import copy
8
  import pydeck as pdk
9
 
10
  keep_layer_count=6
11
- byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token="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")
12
 
13
 
14
  class ByT5ForTextGeotaggingConfig(PretrainedConfig):
@@ -73,13 +72,27 @@ class ByT5ForTextGeotagging(PreTrainedModel):
73
  else:
74
  return logits
75
 
 
 
 
 
 
 
 
 
 
 
76
  def geolocate_text_byt5(text):
77
  input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
78
  logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
79
  return lat, lon
80
 
81
- model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token="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")
 
82
 
 
 
 
83
 
84
  example_texts = [
85
  "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
@@ -93,17 +106,18 @@ example_texts = [
93
  # Streamlit interface
94
  st.title('GeoTagging using ByT5')
95
 
96
- # Buttons for example texts
97
- for ex_text in example_texts:
98
- if st.button(f'Example: "{ex_text[:30]}..."'):
99
- text_input = ex_text
100
-
101
- text_input = st.text_input('Enter your text:', value=text_input if 'text_input' in locals() else '')
102
 
 
 
 
103
 
104
  if text_input:
105
  location = geolocate_text_byt5(text_input)
106
  st.write('Predicted Location: ', location)
 
107
  # Render map with pydeck
108
  map_data = pd.DataFrame(
109
  [[location[0], location[1]]],
 
8
  import pydeck as pdk
9
 
10
  keep_layer_count=6
 
11
 
12
 
13
  class ByT5ForTextGeotaggingConfig(PretrainedConfig):
 
72
  else:
73
  return logits
74
 
75
+
76
+ @st.cache(allow_output_mutation=True)
77
+ def load_model_and_tokenizer():
78
+ byt5_tokenizer = AutoTokenizer.from_pretrained("yachay/byt5-geotagging-es", token="x")
79
+ model = ByT5ForTextGeotagging.from_pretrained("yachay/byt5-geotagging-es", token="x")
80
+ return byt5_tokenizer, model
81
+
82
+ byt5_tokenizer, model = load_model_and_tokenizer()
83
+
84
+
85
  def geolocate_text_byt5(text):
86
  input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
87
  logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
88
  return lat, lon
89
 
90
+ if 'text_input' not in st.session_state:
91
+ st.session_state.text_input = ""
92
 
93
+ # When an example button is clicked, update the session state
94
+ def set_example_text(example_text):
95
+ st.session_state.text_input = example_text
96
 
97
  example_texts = [
98
  "Disfrutando de una paella deliciosa en las playas de #Valencia 🥘☀️",
 
106
  # Streamlit interface
107
  st.title('GeoTagging using ByT5')
108
 
109
+ for example in example_texts:
110
+ if st.button(f"Use example: {example}"):
111
+ set_example_text(example)
 
 
 
112
 
113
+ # Get text input and update session state when it's modified
114
+ text_input = st.text_input('Enter your text:', value=st.session_state.text_input)
115
+ st.session_state.text_input = text_input
116
 
117
  if text_input:
118
  location = geolocate_text_byt5(text_input)
119
  st.write('Predicted Location: ', location)
120
+
121
  # Render map with pydeck
122
  map_data = pd.DataFrame(
123
  [[location[0], location[1]]],