Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -68,7 +68,8 @@ class ByT5ForTextGeotagging(PreTrainedModel):
|
|
68 |
if return_coordinates:
|
69 |
class_idx = torch.argmax(logits, dim=1).item()
|
70 |
coordinates = self.config.class_to_location.get(str(class_idx))
|
71 |
-
|
|
|
72 |
else:
|
73 |
return logits
|
74 |
|
@@ -84,8 +85,8 @@ byt5_tokenizer, model = load_model_and_tokenizer()
|
|
84 |
|
85 |
def geolocate_text_byt5(text):
|
86 |
input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
|
87 |
-
logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
|
88 |
-
return lat, lon
|
89 |
|
90 |
if 'text_input' not in st.session_state:
|
91 |
st.session_state.text_input = ""
|
@@ -123,8 +124,8 @@ if st.button('Submit'):
|
|
123 |
|
124 |
|
125 |
if st.session_state.text_input:
|
126 |
-
location = geolocate_text_byt5(st.session_state.text_input)
|
127 |
-
st.write('Predicted Location: ', location)
|
128 |
|
129 |
# Render map with pydeck
|
130 |
map_data = pd.DataFrame(
|
|
|
68 |
if return_coordinates:
|
69 |
class_idx = torch.argmax(logits, dim=1).item()
|
70 |
coordinates = self.config.class_to_location.get(str(class_idx))
|
71 |
+
confidence = torch.max(logits).item()
|
72 |
+
return logits, coordinates, confidence
|
73 |
else:
|
74 |
return logits
|
75 |
|
|
|
85 |
|
86 |
def geolocate_text_byt5(text):
|
87 |
input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
|
88 |
+
logits, (lat, lon), confidence = model(input_tensor.unsqueeze(0), return_coordinates=True)
|
89 |
+
return lat, lon, confidence
|
90 |
|
91 |
if 'text_input' not in st.session_state:
|
92 |
st.session_state.text_input = ""
|
|
|
124 |
|
125 |
|
126 |
if st.session_state.text_input:
|
127 |
+
location, confidence = geolocate_text_byt5(st.session_state.text_input)
|
128 |
+
st.write('Predicted Location: ', location, 'confidence: ', 'High' if confidence > 0.2 else 'Low')
|
129 |
|
130 |
# Render map with pydeck
|
131 |
map_data = pd.DataFrame(
|