JasonTPhillipsJr committed
Commit 60a8335 · verified · 1 Parent(s): 530b1c5

Update app.py

Files changed (1): app.py +13 -12
app.py CHANGED
@@ -13,15 +13,18 @@ from PIL import Image
 
 device = torch.device('cpu')
 
+
 #Spacy Initialization Section
 nlp = spacy.load("./models/en_core_web_sm")
 
+
 #BERT Initialization Section
 bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 bert_model = BertModel.from_pretrained("bert-base-uncased")
 bert_model.to(device)
 bert_model.eval()
 
+
 #SpaBERT Initialization Section
 data_file_path = 'models/spabert/datasets/SpaBERTPivots.json' #Sample file otherwise this model will take too long on CPU.
 pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
@@ -37,6 +40,7 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
 spaBERT_model.to(device)
 spaBERT_model.eval()
 
+
 #Load data using SpatialDataset
 spatialDataset = PbfMapDataset(data_file_path = data_file_path,
                                tokenizer = bert_tokenizer,
@@ -57,6 +61,7 @@ entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialD
 # Ensure names are stored in lowercase for case-insensitive matching
 entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}
 
+
 #Pre-aquire the SpaBERT embeddings for all geo-entities within our dataset
 def process_entity(batch, model, device):
     input_ids = batch['masked_input'].to(device)
@@ -91,8 +96,6 @@ for batch in (data_loader):
     spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
     spaBERT_embeddings.append(spaBERT_embedding)
 
-    #st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
-    #st.write("SpaBERT Embedding:", spaBERT_embedding[0])
 embedding_cache = {}
 
 
@@ -109,16 +112,18 @@ def get_bert_embedding(review_text):
     bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
     return bert_embedding
 
+
 #Get SpaBERT Embedding for geo-entity
 def get_spaBert_embedding(entity):
     entity_index = entity_index_dict.get(entity.lower(), None)
     if entity_index is None:
         st.write("Got Bert embedding for: ", entity)
-        return get_bert_embedding(entity)
+        return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
     else:
         st.write("Got SpaBert embedding for: ", entity)
         return spaBERT_embeddings[entity_index]
 
+
 #Go through each review, identify all geo-entities, then extract their SpaBERT embedings
 def processSpatialEntities(review, nlp):
     doc = nlp(review)
@@ -131,7 +136,11 @@ def processSpatialEntities(review, nlp):
             spaBert_emb = get_spaBert_embedding(text)
             token_embeddings.append((text, spaBert_emb))
             st.write("Geo-Entity Found in review: ", text)
-    return token_embeddings
+
+    processed_embedding = torch.cat(token_embeddings, dim=0)
+    st.write("processed embedding shape: ", processed_embedding.shape)
+    return processed_embedding
+
 
 # Function to read reviews from a text file
 def load_reviews_from_file(file_path):
@@ -163,14 +172,6 @@ st.write("**Color Key:**")
 for label, (color, description) in COLOR_MAP.items():
     st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)
 
-# Text input
-#user_input = st.text_area("Input Text", height=200)
-
-# Define example reviews for testing
-#example_reviews = {
-#    "Review 1": "Meh. My brother lives near the Italian Market in South Philly. I went for a visit. Luckily for me, my brother and his girlfriend are foodies. I was able to taste many different cuisines in Philly. Coming from San Francisco, there are places I don't go due to the tourist trap aura and the non-authenticity of it all (Fisherman’s Wharf, Chinatown, etc.). But when I was in Philly, I had to have a cheesesteak... and I had to go to the two most famous places, which of course are right across the street from one another, in a big rivalry, and featured on the Food Network! How cheesy, but essential. We split two, both "wit whiz"? (cheese whiz) one from Geno's and one from Pat's. Pat's was much tastier than Geno's. The meat was seasoned, and the bun and cheese had much more flavor... better of the two... it seems.",
-#    "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
-#}
 review_file_path = "models/spabert/datasets/SampleReviews.txt"
 example_reviews = load_reviews_from_file(review_file_path)
 
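
For reference, the substantive change in this commit is that processSpatialEntities now concatenates the per-entity embeddings into a single tensor before returning it. Below is a minimal, self-contained sketch of that concatenation step, not the committed code: it assumes each stored embedding is a (1, hidden_size) tensor and that only the tensor half of each (text, embedding) pair is concatenated, and the helper name combine_entity_embeddings is illustrative only.

import torch

# Sketch only: token_embeddings holds (text, tensor) pairs as in the committed code,
# so the tensor half is pulled out before concatenating (each assumed to be (1, hidden_size)).
def combine_entity_embeddings(token_embeddings):
    entity_tensors = [emb for _, emb in token_embeddings]  # drop the text labels
    return torch.cat(entity_tensors, dim=0)                # -> (num_entities, hidden_size)

# Example with dummy 768-dimensional embeddings:
demo = [("Philly", torch.zeros(1, 768)), ("Mountain View", torch.zeros(1, 768))]
print(combine_entity_embeddings(demo).shape)               # torch.Size([2, 768])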