Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -13,15 +13,18 @@ from PIL import Image
|
|
13 |
|
14 |
device = torch.device('cpu')
|
15 |
|
|
|
16 |
#Spacy Initialization Section
|
17 |
nlp = spacy.load("./models/en_core_web_sm")
|
18 |
|
|
|
19 |
#BERT Initialization Section
|
20 |
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
21 |
bert_model = BertModel.from_pretrained("bert-base-uncased")
|
22 |
bert_model.to(device)
|
23 |
bert_model.eval()
|
24 |
|
|
|
25 |
#SpaBERT Initialization Section
|
26 |
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json' #Sample file otherwise this model will take too long on CPU.
|
27 |
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
|
@@ -37,6 +40,7 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
|
|
37 |
spaBERT_model.to(device)
|
38 |
spaBERT_model.eval()
|
39 |
|
|
|
40 |
#Load data using SpatialDataset
|
41 |
spatialDataset = PbfMapDataset(data_file_path = data_file_path,
|
42 |
tokenizer = bert_tokenizer,
|
@@ -57,6 +61,7 @@ entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialD
|
|
57 |
# Ensure names are stored in lowercase for case-insensitive matching
|
58 |
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}
|
59 |
|
|
|
60 |
#Pre-aquire the SpaBERT embeddings for all geo-entities within our dataset
|
61 |
def process_entity(batch, model, device):
|
62 |
input_ids = batch['masked_input'].to(device)
|
@@ -91,8 +96,6 @@ for batch in (data_loader):
|
|
91 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
92 |
spaBERT_embeddings.append(spaBERT_embedding)
|
93 |
|
94 |
-
#st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
|
95 |
-
#st.write("SpaBERT Embedding:", spaBERT_embedding[0])
|
96 |
embedding_cache = {}
|
97 |
|
98 |
|
@@ -109,16 +112,18 @@ def get_bert_embedding(review_text):
|
|
109 |
bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
|
110 |
return bert_embedding
|
111 |
|
|
|
112 |
#Get SpaBERT Embedding for geo-entity
|
113 |
def get_spaBert_embedding(entity):
|
114 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
115 |
if entity_index is None:
|
116 |
st.write("Got Bert embedding for: ", entity)
|
117 |
-
return get_bert_embedding(entity)
|
118 |
else:
|
119 |
st.write("Got SpaBert embedding for: ", entity)
|
120 |
return spaBERT_embeddings[entity_index]
|
121 |
|
|
|
122 |
#Go through each review, identify all geo-entities, then extract their SpaBERT embedings
|
123 |
def processSpatialEntities(review, nlp):
|
124 |
doc = nlp(review)
|
@@ -131,7 +136,11 @@ def processSpatialEntities(review, nlp):
|
|
131 |
spaBert_emb = get_spaBert_embedding(text)
|
132 |
token_embeddings.append((text, spaBert_emb))
|
133 |
st.write("Geo-Entity Found in review: ", text)
|
134 |
-
|
|
|
|
|
|
|
|
|
135 |
|
136 |
# Function to read reviews from a text file
|
137 |
def load_reviews_from_file(file_path):
|
@@ -163,14 +172,6 @@ st.write("**Color Key:**")
|
|
163 |
for label, (color, description) in COLOR_MAP.items():
|
164 |
st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)
|
165 |
|
166 |
-
# Text input
|
167 |
-
#user_input = st.text_area("Input Text", height=200)
|
168 |
-
|
169 |
-
# Define example reviews for testing
|
170 |
-
#example_reviews = {
|
171 |
-
# "Review 1": "Meh. My brother lives near the Italian Market in South Philly. I went for a visit. Luckily for me, my brother and his girlfriend are foodies. I was able to taste many different cuisines in Philly. Coming from San Francisco, there are places I don't go due to the tourist trap aura and the non-authenticity of it all (Fisherman’s Wharf, Chinatown, etc.). But when I was in Philly, I had to have a cheesesteak... and I had to go to the two most famous places, which of course are right across the street from one another, in a big rivalry, and featured on the Food Network! How cheesy, but essential. We split two, both "wit whiz"? (cheese whiz) one from Geno's and one from Pat's. Pat's was much tastier than Geno's. The meat was seasoned, and the bun and cheese had much more flavor... better of the two... it seems.",
|
172 |
-
# "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
|
173 |
-
#}
|
174 |
review_file_path = "models/spabert/datasets/SampleReviews.txt"
|
175 |
example_reviews = load_reviews_from_file(review_file_path)
|
176 |
|
|
|
13 |
|
14 |
# Run all models on CPU (this Space has no GPU).
device = torch.device('cpu')

#Spacy Initialization Section
# NOTE(review): loads a locally bundled en_core_web_sm model from ./models,
# not the pip-installed package — confirm the path exists at deploy time.
nlp = spacy.load("./models/en_core_web_sm")

#BERT Initialization Section
# Standard uncased BERT used as a fallback embedder for unknown entities.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)
bert_model.eval()  # inference only — disables dropout

#SpaBERT Initialization Section
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json' #Sample file otherwise this model will take too long on CPU.
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
|
|
|
40 |
spaBERT_model.to(device)
|
41 |
spaBERT_model.eval()
|
42 |
|
43 |
+
|
44 |
#Load data using SpatialDataset
|
45 |
spatialDataset = PbfMapDataset(data_file_path = data_file_path,
|
46 |
tokenizer = bert_tokenizer,
|
|
|
61 |
# Ensure names are stored in lowercase for case-insensitive matching
|
62 |
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}
|
63 |
|
64 |
+
|
65 |
#Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
|
66 |
def process_entity(batch, model, device):
|
67 |
input_ids = batch['masked_input'].to(device)
|
|
|
96 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
97 |
spaBERT_embeddings.append(spaBERT_embedding)
|
98 |
|
|
|
|
|
99 |
embedding_cache = {}
|
100 |
|
101 |
|
|
|
112 |
bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
|
113 |
return bert_embedding
|
114 |
|
115 |
+
|
116 |
#Get SpaBERT Embedding for geo-entity
def get_spaBert_embedding(entity):
    """Return an embedding for *entity*.

    Prefers the pre-computed SpaBERT embedding, looked up by lowercased
    entity name; falls back to an on-the-fly BERT embedding when the
    entity is absent from the dataset index (rare cases only).
    """
    index = entity_index_dict.get(entity.lower())
    if index is not None:
        st.write("Got SpaBert embedding for: ", entity)
        return spaBERT_embeddings[index]
    st.write("Got Bert embedding for: ", entity)
    #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
    return get_bert_embedding(entity)
|
125 |
|
126 |
+
|
127 |
#Go through each review, identify all geo-entities, then extract their SpaBERT embeddings
|
128 |
def processSpatialEntities(review, nlp):
|
129 |
doc = nlp(review)
|
|
|
136 |
spaBert_emb = get_spaBert_embedding(text)
|
137 |
token_embeddings.append((text, spaBert_emb))
|
138 |
st.write("Geo-Entity Found in review: ", text)
|
139 |
+
|
140 |
+
processed_embedding = torch.cat(token_embeddings, dim=0)
|
141 |
+
st.write("processed embedding shape: ", processed_embedding.shape)
|
142 |
+
return processed_embedding
|
143 |
+
|
144 |
|
145 |
# Function to read reviews from a text file
|
146 |
def load_reviews_from_file(file_path):
|
|
|
172 |
# Render a legend entry for every highlight label: the label name, a swatch
# of its color (via inline HTML, hence unsafe_allow_html), and a description.
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)

# Load the example reviews shown in the UI from a bundled text file.
review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)
|
177 |
|