yachay commited on
Commit
7341b3b
·
1 Parent(s): 77415bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -68,7 +68,8 @@ class ByT5ForTextGeotagging(PreTrainedModel):
68
  if return_coordinates:
69
  class_idx = torch.argmax(logits, dim=1).item()
70
  coordinates = self.config.class_to_location.get(str(class_idx))
71
- return logits, coordinates
 
72
  else:
73
  return logits
74
 
@@ -84,8 +85,8 @@ byt5_tokenizer, model = load_model_and_tokenizer()
84
 
85
  def geolocate_text_byt5(text):
86
  input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
87
- logits, (lat, lon) = model(input_tensor.unsqueeze(0), return_coordinates=True)
88
- return lat, lon
89
 
90
  if 'text_input' not in st.session_state:
91
  st.session_state.text_input = ""
@@ -123,8 +124,8 @@ if st.button('Submit'):
123
 
124
 
125
  if st.session_state.text_input:
126
- location = geolocate_text_byt5(st.session_state.text_input)
127
- st.write('Predicted Location: ', location)
128
 
129
  # Render map with pydeck
130
  map_data = pd.DataFrame(
 
68
  if return_coordinates:
69
  class_idx = torch.argmax(logits, dim=1).item()
70
  coordinates = self.config.class_to_location.get(str(class_idx))
71
+ confidence = torch.max(logits).item()
72
+ return logits, coordinates, confidence
73
  else:
74
  return logits
75
 
 
85
 
86
  def geolocate_text_byt5(text):
87
  input_tensor = byt5_tokenizer(text, return_tensors="pt", truncation=True, max_length=140)['input_ids']
88
+ logits, (lat, lon), confidence = model(input_tensor.unsqueeze(0), return_coordinates=True)
89
+ return lat, lon, confidence
90
 
91
  if 'text_input' not in st.session_state:
92
  st.session_state.text_input = ""
 
124
 
125
 
126
  if st.session_state.text_input:
127
+ location, confidence = geolocate_text_byt5(st.session_state.text_input)
128
+ st.write('Predicted Location: ', location, 'confidence: ', 'High' if confidence > 0.2 else 'Low')
129
 
130
  # Render map with pydeck
131
  map_data = pd.DataFrame(