Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -82,6 +82,27 @@ def load_model_and_tokenizer():
|
|
82 |
|
83 |
byt5_tokenizer, model = load_model_and_tokenizer()
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
def geolocate_text_byt5(text):
|
87 |
input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
|
@@ -124,7 +145,8 @@ if st.button('Submit'):
|
|
124 |
|
125 |
|
126 |
if st.session_state.text_input:
|
127 |
-
|
|
|
128 |
st.write('Predicted Location: (', lat, lon, '). Confidence: ', 'High' if confidence > 0.2 else 'Low')
|
129 |
|
130 |
# Render map with pydeck
|
|
|
82 |
|
83 |
byt5_tokenizer, model = load_model_and_tokenizer()
|
84 |
|
85 |
+
def geolocate_text_byt5_multiclass(text):
|
86 |
+
input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
|
87 |
+
logits, (lat, lon), confidence = model(input_tensor.unsqueeze(0), return_coordinates=True)
|
88 |
+
probas = torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()
|
89 |
+
|
90 |
+
# Sort probabilities in descending order and get their indices
|
91 |
+
sorted_indices = np.argsort(-probas[0])
|
92 |
+
|
93 |
+
results = []
|
94 |
+
cumulative_prob = 0.0
|
95 |
+
for class_idx in sorted_indices:
|
96 |
+
prob = probas[0][class_idx]
|
97 |
+
cumulative_prob += prob
|
98 |
+
if cumulative_prob > 0.5:
|
99 |
+
coordinates = model.config.class_to_location.get(str(class_idx))
|
100 |
+
if coordinates:
|
101 |
+
results.append((class_idx, prob, coordinates))
|
102 |
+
break
|
103 |
+
|
104 |
+
return results
|
105 |
+
|
106 |
|
107 |
def geolocate_text_byt5(text):
|
108 |
input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
|
|
|
145 |
|
146 |
|
147 |
if st.session_state.text_input:
|
148 |
+
results = geolocate_text_byt5_multiclass(st.session_state.text_input)
|
149 |
+
_, confidence, (lat, lon) = results[0]
|
150 |
st.write('Predicted Location: (', lat, lon, '). Confidence: ', 'High' if confidence > 0.2 else 'Low')
|
151 |
|
152 |
# Render map with pydeck
|