Spaces:
Sleeping
Sleeping
File size: 10,057 Bytes
ac736ed 5c91758 bc50d7d 3d81019 ebf50a4 b7111b8 3d8cd48 b7111b8 83e90d7 ac736ed d914cbe 1aa7dda 3d81019 dc9ff0b 3d81019 d914cbe 3d81019 dc9ff0b bf52bfd d914cbe 3be15aa fbce538 3d81019 3be15aa d914cbe 80744c0 3d81019 a74fa0d 3d8cd48 5914cea 18634d6 5914cea 3d8cd48 bf52bfd 18634d6 3d8cd48 3d81019 4c18d69 a74fa0d 169e7aa a74fa0d 857dba3 a74fa0d 0cea6d5 a74fa0d 857dba3 a74fa0d 857dba3 564da7a 0cea6d5 a74fa0d 4c18d69 a74fa0d f82dac8 1091141 f82dac8 857dba3 3d81019 cbcad17 4c18d69 cbcad17 4c18d69 b9b16e5 4c18d69 18634d6 5c91758 83e90d7 9822204 3577a57 9822204 3577a57 d914cbe 3577a57 9822204 3577a57 9822204 b23060e 4c18d69 b23060e 83e90d7 5c91758 b23060e e72b522 b9b16e5 bf52bfd 18634d6 857dba3 6ed7a92 18634d6 5c91758 b23060e e72b522 b4303dc b23060e b4303dc 83e90d7 3577a57 5c91758 b23060e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import streamlit as st
import spacy
import torch
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from torch.utils.data import DataLoader
from PIL import Image
# --- Shared runtime objects -------------------------------------------------
# All models in this script are placed on this device.
device = torch.device('cpu')

# spaCy pipeline loaded from a local model directory; used for NER on reviews.
nlp = spacy.load("./models/en_core_web_sm")

# Plain BERT tokenizer/encoder. nn.Module.to() and .eval() both return the
# module itself, so the setup calls can be chained.
# NOTE(review): Streamlit re-executes the whole script on every widget
# interaction, so these models are reloaded on each rerun; consider a
# Streamlit resource cache if the installed version provides one.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()
# --- SpaBERT initialization --------------------------------------------------
# Data file and fine-tuned checkpoint. Tip from the original author: build a
# JSON file containing only the geo-entities you need, otherwise loading the
# data takes too long.
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
config.output_hidden_states = True  # process_entity reads outputs.hidden_states
spaBERT_model = SpatialBertForMaskedLM(config)

# torch.load unpickles the file: only ever point it at a trusted checkpoint.
# map_location uses the shared `device` so the target device is set in one place.
pre_trained_model = torch.load(pretrained_model_path, map_location=device)

# Order matters: the plain BERT weights are loaded first and the fine-tuned
# SpaBERT checkpoint second, so keys present in both take the checkpoint's values.
# NOTE(review): strict=False silently ignores missing/unexpected keys. The key
# names of a bare BertModel state dict may not match SpatialBertForMaskedLM's
# parameter names (load_spatial_bert_pretrained_weights is imported but unused);
# confirm the first call actually loads anything.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict=False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)
spaBERT_model.to(device)
spaBERT_model.eval()
# --- Dataset, loader and name -> index lookup -------------------------------
spatialDataset = PbfMapDataset(data_file_path=data_file_path,
                               tokenizer=bert_tokenizer,
                               max_token_len=256,  # originally 300
                               distance_norm_factor=0.0001,
                               spatial_dist_fill=20,
                               with_type=False,
                               sep_between_neighbors=True,
                               label_encoder=None,
                               mode=None)  # None -> use the full dataset for MLM

# TODO: known issue with DataLoader workers not stopping after iteration
# finishes (presumably why num_workers is kept at 0).
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0,
                         shuffle=False, pin_memory=False, drop_last=False)

# Map lower-cased pivot names to their dataset index for case-insensitive
# lookup. Built in a single pass so that when several entities share a name
# (ignoring case), the one appearing last in the dataset consistently wins.
# NOTE(review): iterating the dataset materializes every item just to read
# 'pivot_name'; if PbfMapDataset exposes the raw names, reading them directly
# would be cheaper.
entity_index_dict = {entity['pivot_name'].lower(): i
                     for i, entity in enumerate(spatialDataset)}
def process_entity(batch, model, device):
    """Run one DataLoader batch through SpaBERT and return its first-token embedding.

    Args:
        batch: Dict of tensors from the DataLoader over PbfMapDataset. It must
            contain 'masked_input', 'attention_mask', 'norm_lng_list',
            'norm_lat_list', 'sent_position_ids' and 'pseudo_sentence'.
        model: The SpaBERT model to run. Its output must expose
            ``hidden_states`` (a sequence whose last entry is the final layer).
        device: torch device that the batch tensors and result are moved to.

    Returns:
        ``(embedding, input_ids)``. ``embedding`` is position 0 ([CLS]) of
        the last hidden state, shape ``[batch_size, hidden_size]``.
        ``input_ids`` is ``batch['masked_input']`` moved to ``device``.
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    with torch.no_grad():
        # The pseudo-sentence (not the masked input) is what the model sees.
        # NOTE(review): presumably the unmasked tokens are wanted for an
        # embedding; confirm this is the intended input.
        outputs = model(input_ids=pseudo_sentence,
                        attention_mask=attention_mask,
                        sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x,
                        position_list_y=position_list_y)

    last_hidden = outputs.hidden_states[-1].to(device)
    # Position 0 is the [CLS] token.
    spaBERT_embedding = last_hidden[:, 0, :].detach()  # [batch_size, hidden_size]
    return spaBERT_embedding, input_ids
# Pre-compute SpaBERT embeddings for only the first few dataset entries to
# keep start-up time down (batch_size=1, so one entry per batch).
NUM_PRECOMPUTED_BATCHES = 2
spaBERT_embeddings = []
for i, batch in enumerate(data_loader):
    if i >= NUM_PRECOMPUTED_BATCHES:
        break
    spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(spaBERT_embedding)

# NOTE(review): only the first NUM_PRECOMPUTED_BATCHES entities get an
# embedding, while entity_index_dict covers the whole dataset, so lookups for
# later entities find no embedding.
# NOTE(review): embedding_cache is not read or written anywhere in the
# visible code.
embedding_cache = {}
def get_bert_embedding(review_text):
    """Return BERT's [CLS] vector for ``review_text``.

    The text is tokenized with the module-level ``bert_tokenizer`` (padded,
    truncated, PyTorch tensors), moved to ``device`` and encoded by
    ``bert_model`` without gradient tracking. The result is the first-token
    slice of the last hidden state, shape ``[batch_size, hidden_size]``.
    """
    encoded = bert_tokenizer(review_text, return_tensors='pt',
                             padding=True, truncation=True).to(device)
    with torch.no_grad():
        last_hidden = bert_model(**encoded).last_hidden_state
    cls_vector = last_hidden[:, 0, :].detach()
    return cls_vector
def get_spaBert_embedding(entity):
    """Return the precomputed SpaBERT embedding for a geo-entity name.

    The name is matched case-insensitively against ``entity_index_dict``.

    Returns:
        The entry of ``spaBERT_embeddings`` for that entity, or ``None`` when
        the name is unknown or its dataset index has no precomputed embedding
        (only a prefix of the dataset is embedded at start-up).
    """
    entity_index = entity_index_dict.get(entity.lower())
    # Without these checks an unknown name indexed the list with None
    # (TypeError) and an index past the precomputed prefix raised IndexError.
    if entity_index is None or not 0 <= entity_index < len(spaBERT_embeddings):
        return None
    return spaBERT_embeddings[entity_index]
def processSpatialEntities(review, nlp):
    """Report every geo-entity spaCy finds in ``review``.

    Entities labelled FAC, ORG, LOC or GPE are echoed to the Streamlit page.
    The per-entity SpaBERT lookup is not wired in, so the returned list of
    ``(text, embedding)`` pairs is always empty.

    Args:
        review: Review text to analyse.
        nlp: A loaded spaCy pipeline.

    Returns:
        list: Currently always ``[]``.
    """
    geo_labels = ('FAC', 'ORG', 'LOC', 'GPE')
    token_embeddings = []
    for ent in nlp(review).ents:
        if ent.label_ not in geo_labels:
            continue
        # TODO: append (ent.text, get_spaBert_embedding(ent.text)) here once
        # the SpaBERT lookup is re-enabled.
        st.write("Geo-Entity Found in review: ", ent.text)
    return token_embeddings
def load_reviews_from_file(file_path):
    """Read one review per line from a UTF-8 text file.

    Each non-blank line (after stripping surrounding whitespace) becomes an
    entry keyed ``"Review <n>"``, where ``n`` is its 1-based line number in
    the file, so blank lines leave gaps in the numbering.

    If the file does not exist, a Streamlit error is shown and an empty dict
    is returned.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            numbered_lines = [(line_no, raw.strip())
                              for line_no, raw in enumerate(handle, start=1)]
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
        return {}
    return {f"Review {line_no}": text for line_no, text in numbered_lines if text}
# --- Page header and colour legend -------------------------------------------
st.title("SpaGAN Demo")
# The page offers a dropdown of example reviews, not free-text entry.
st.write("Select an example review, and the system will highlight the geo-entities within it.")

# spaCy entity label -> (CSS colour used for highlighting, legend description).
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)'),
}

# Render one legend line per entity type, with the colour name shown in its colour.
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}",
                unsafe_allow_html=True)
# --- Review selection ----------------------------------------------------------
review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

# With no reviews (missing or blank file) the dropdown has no options, and
# looking up its selection below would fail; stop this script run cleanly instead.
if not example_reviews:
    st.warning("No example reviews are available to select.")
    st.stop()

# Dropdown of review labels; the chosen label is mapped back to its text.
user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))
selected_review = example_reviews[user_input]
# --- Analysis triggered by the button ----------------------------------------
if st.button("Highlight Geo-Entities"):
    if selected_review.strip():
        # Debug output: BERT [CLS] embedding of the whole review.
        bert_embedding = get_bert_embedding(selected_review)
        st.write("Embedding Shape:", bert_embedding.shape)

        # Echoes the geo-entities found in the review; the returned list is
        # not used below.
        spaBert_embedding = processSpatialEntities(selected_review, nlp)

        # NOTE(review): placeholder fusion. This always concatenates the
        # embedding of the first precomputed dataset entity, not one matching
        # an entity in the review.
        if spaBERT_embeddings:
            combined_embedding = torch.cat((bert_embedding, spaBERT_embeddings[0]), dim=-1)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)
        else:
            # Indexing an empty list would raise IndexError here.
            st.warning("No SpaBERT embeddings were precomputed; skipping concatenation.")

        # NOTE(review): the review is run through spaCy a second time here
        # (processSpatialEntities already parsed it).
        doc = nlp(selected_review)

        # Wrap each entity whose label has a colour in a styled span. Splicing
        # from the last entity backwards keeps earlier character offsets valid.
        highlighted_text = selected_review
        for ent in reversed(doc.ents):
            if ent.label_ in COLOR_MAP:
                color = COLOR_MAP[ent.label_][0]
                highlighted_text = (
                    highlighted_text[:ent.start_char] +
                    f"<span style='color:{color}; font-weight:bold'>{ent.text}</span>" +
                    highlighted_text[ent.end_char:]
                )
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please select a review.")