Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
|
|
12 |
from PIL import Image
|
13 |
|
14 |
device = torch.device('cpu')
|
|
|
15 |
|
16 |
|
17 |
#Spacy Initialization Section
|
@@ -117,10 +118,12 @@ def get_bert_embedding(review_text):
|
|
117 |
def get_spaBert_embedding(entity):
|
118 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
119 |
if entity_index is None:
|
120 |
-
|
|
|
121 |
return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
|
122 |
else:
|
123 |
-
|
|
|
124 |
return spaBERT_embeddings[entity_index]
|
125 |
|
126 |
|
@@ -135,7 +138,8 @@ def processSpatialEntities(review, nlp):
|
|
135 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
136 |
spaBert_emb = get_spaBert_embedding(text)
|
137 |
token_embeddings.append(spaBert_emb)
|
138 |
-
|
|
|
139 |
|
140 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
141 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
@@ -246,14 +250,14 @@ selected_review = example_reviews[user_input]
|
|
246 |
if st.button("Highlight Geo-Entities"):
|
247 |
if selected_review.strip():
|
248 |
bert_embedding = get_bert_embedding(selected_review)
|
249 |
-
st.write("Review Embedding Shape:", bert_embedding.shape)
|
250 |
-
|
251 |
spaBert_embedding = processSpatialEntities(selected_review,nlp)
|
252 |
-
st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
|
253 |
-
|
254 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
255 |
-
|
256 |
-
|
|
|
|
|
|
|
|
|
257 |
|
258 |
prediction = get_prediction(combined_embedding)
|
259 |
|
|
|
12 |
from PIL import Image
|
13 |
|
14 |
device = torch.device('cpu')
|
15 |
+
dev_mode = True
|
16 |
|
17 |
|
18 |
#Spacy Initialization Section
|
|
|
118 |
def get_spaBert_embedding(entity):
|
119 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
120 |
if entity_index is None:
|
121 |
+
if(dev_mode == True):
|
122 |
+
st.write("Got Bert embedding for: ", entity)
|
123 |
return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
|
124 |
else:
|
125 |
+
if(dev_mode == True):
|
126 |
+
st.write("Got SpaBert embedding for: ", entity)
|
127 |
return spaBERT_embeddings[entity_index]
|
128 |
|
129 |
|
|
|
138 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
139 |
spaBert_emb = get_spaBert_embedding(text)
|
140 |
token_embeddings.append(spaBert_emb)
|
141 |
+
if(dev_mode == True)
|
142 |
+
st.write("Geo-Entity Found in review: ", text)
|
143 |
|
144 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
145 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
|
|
250 |
if st.button("Highlight Geo-Entities"):
|
251 |
if selected_review.strip():
|
252 |
bert_embedding = get_bert_embedding(selected_review)
|
|
|
|
|
253 |
spaBert_embedding = processSpatialEntities(selected_review,nlp)
|
|
|
|
|
254 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
255 |
+
|
256 |
+
if(dev_mode == True):
|
257 |
+
st.write("Review Embedding Shape:", bert_embedding.shape)
|
258 |
+
st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
|
259 |
+
st.write("Concatenated Embedding Shape:", combined_embedding.shape)
|
260 |
+
st.write("Concatenated Embedding:", combined_embedding)
|
261 |
|
262 |
prediction = get_prediction(combined_embedding)
|
263 |
|