yachay commited on
Commit
66ae132
·
1 Parent(s): 6f8ed7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -1
app.py CHANGED
@@ -82,6 +82,27 @@ def load_model_and_tokenizer():
82
 
83
  byt5_tokenizer, model = load_model_and_tokenizer()
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  def geolocate_text_byt5(text):
87
  input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
@@ -124,7 +145,8 @@ if st.button('Submit'):
124
 
125
 
126
  if st.session_state.text_input:
127
- lat, lon, confidence = geolocate_text_byt5(st.session_state.text_input)
 
128
  st.write('Predicted Location: (', lat, lon, '). Confidence: ', 'High' if confidence > 0.2 else 'Low')
129
 
130
  # Render map with pydeck
 
82
 
83
  byt5_tokenizer, model = load_model_and_tokenizer()
84
 
85
+ def geolocate_text_byt5_multiclass(text):
86
+ input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
87
+ logits, (lat, lon), confidence = model(input_tensor.unsqueeze(0), return_coordinates=True)
88
+ probas = torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()
89
+
90
+ # Sort probabilities in descending order and get their indices
91
+ sorted_indices = np.argsort(-probas[0])
92
+
93
+ results = []
94
+ cumulative_prob = 0.0
95
+ for class_idx in sorted_indices:
96
+ prob = probas[0][class_idx]
97
+ cumulative_prob += prob
98
+ if cumulative_prob > 0.5:
99
+ coordinates = model.config.class_to_location.get(str(class_idx))
100
+ if coordinates:
101
+ results.append((class_idx, prob, coordinates))
102
+ break
103
+
104
+ return results
105
+
106
 
107
  def geolocate_text_byt5(text):
108
  input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
 
145
 
146
 
147
  if st.session_state.text_input:
148
+ results = geolocate_text_byt5_multiclass(st.session_state.text_input)
149
+ _, confidence, (lat, lon) = results[0]
150
  st.write('Predicted Location: (', lat, lon, '). Confidence: ', 'High' if confidence > 0.2 else 'Low')
151
 
152
  # Render map with pydeck