SpaGAN / app.py
JasonTPhillipsJr's picture
Update app.py
857dba3 verified
raw
history blame
7.59 kB
import streamlit as st
import spacy
import torch
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from torch.utils.data import DataLoader
from PIL import Image
device = torch.device('cpu')

# NOTE(review): Streamlit re-executes this script top to bottom on every widget
# interaction, so every load below repeats on each click. Consider wrapping the
# loaders in st.cache_resource if the installed Streamlit version provides it.

# --- spaCy: NER pipeline used later to find and highlight entities in a review ---
nlp = spacy.load("./models/en_core_web_sm")

# --- BERT: review text encoder, inference only (Module.to/.eval return the module) ---
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()

# --- SpaBERT: spatial encoder for geo-entities ---
# TODO: build a JSON holding only the geo-entities actually needed; the combined file is slow to process.
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
config.output_hidden_states = True  # process_entity reads outputs.hidden_states
spaBERT_model = SpatialBertForMaskedLM(config)
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# Seed with the generic BERT weights, then overlay the fine-tuned SpaBERT checkpoint.
# strict=False tolerates missing/unexpected keys instead of raising.
# NOTE(review): BertModel's keys may not carry the prefix SpatialBertForMaskedLM
# expects, in which case the first load silently matches nothing — verify.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict=False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)
spaBERT_model = spaBERT_model.to(device)
spaBERT_model.eval()

# --- Geo-entity dataset and loader feeding SpaBERT ---
spatialDataset = PbfMapDataset(
    data_file_path=data_file_path,
    tokenizer=bert_tokenizer,
    max_token_len=256,  # originally 300
    distance_norm_factor=0.0001,
    spatial_dist_fill=20,
    with_type=False,
    sep_between_neighbors=True,
    label_encoder=None,
    mode=None,  # None -> use the full dataset for MLM
)
data_loader = DataLoader(
    spatialDataset,
    batch_size=1,
    num_workers=0,  # kept at 0: worker processes did not stop after finishing (open issue)
    shuffle=False,
    pin_memory=False,
    drop_last=False,
)
# Pre-acquire SpaBERT embeddings: process_entity embeds a single batch from data_loader.
def process_entity(batch, model, device):
    """Embed one SpaBERT batch and return its [CLS] vector.

    Args:
        batch: dict of tensors yielded by ``data_loader``. The keys read are
            'masked_input', 'attention_mask', 'norm_lng_list', 'norm_lat_list',
            'sent_position_ids' and 'pseudo_sentence'.
        model: the SpaBERT model to run. Its output must expose
            ``hidden_states`` (the module-level ``config`` sets
            ``output_hidden_states = True`` for this).
        device: torch device the batch tensors are moved to before the forward pass.

    Returns:
        tuple ``(cls_embedding, input_ids)``. ``cls_embedding`` is position 0 of
        the last hidden layer, shape ``[batch_size, hidden_size]``. ``input_ids``
        is the batch's 'masked_input' tensor on ``device``; it is returned to the
        caller but not fed to the model.
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    with torch.no_grad():
        # NOTE(review): the unmasked pseudo_sentence is fed as input_ids instead of
        # masked_input, which looks intended for embedding extraction — confirm.
        outputs = model(
            input_ids=pseudo_sentence,
            attention_mask=attention_mask,
            sent_position_ids=sent_position_ids,
            position_list_x=position_list_x,
            position_list_y=position_list_y,
        )

    # Last hidden layer; index 0 along the sequence axis is the [CLS] token.
    last_hidden = outputs.hidden_states[-1]
    spaBERT_embedding = last_hidden[:, 0, :]
    return spaBERT_embedding, input_ids
# Pre-compute SpaBERT [CLS] embeddings for the first few dataset batches.
# Only MAX_SPABERT_BATCHES batches are processed so start-up stays fast.
MAX_SPABERT_BATCHES = 2
spaBERT_embeddings = []
for i, batch in enumerate(data_loader):
    if i >= MAX_SPABERT_BATCHES:
        break
    embeddings, input_ids = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(embeddings)

# Debug output for the first embedding; the list is empty if the loader yielded nothing.
if spaBERT_embeddings:
    st.write("SpaBERT Embedding shape:", spaBERT_embeddings[0].shape)
    st.write("SpaBERT Embedding:", spaBERT_embeddings[0])
# BERT embedding of a review's text.
def get_bert_embedding(review_text):
    """Return BERT's [CLS] vector for ``review_text``.

    The text is tokenized with the module-level ``bert_tokenizer`` (padding and
    truncation enabled), moved to ``device``, and run through ``bert_model``
    without gradient tracking. The first token of the last hidden state is
    returned.

    Returns:
        torch.Tensor indexed as ``[:, 0, :]`` from ``last_hidden_state``, i.e.
        shape ``[batch_size, hidden_size]``.
    """
    encoded = bert_tokenizer(
        review_text,
        return_tensors='pt',
        padding=True,
        truncation=True,
    ).to(device)
    with torch.no_grad():
        last_hidden = bert_model(**encoded).last_hidden_state
    cls_vector = last_hidden[:, 0, :]
    return cls_vector.detach()
st.title("SpaGAN Demo")
# The only input is the example-review dropdown below (free-text entry is disabled),
# so the instructions describe selecting a review rather than entering text.
st.write("Select an example review, and the system will highlight the geo-entities within it.")

# spaCy entity labels that get highlighted: label -> (CSS color, human-readable description).
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
}

# Legend: one line per highlighted label, with the color name rendered in that color.
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)
# Review input. Free-text entry is currently disabled in favour of fixed examples:
#   user_input = st.text_area("Input Text", height=200)

# Example reviews offered in the dropdown, keyed by their display label.
example_reviews = {
    "Review 1": "I visited the Empire State Building in New York last summer, and it was amazing!",
    "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
}
review_labels = list(example_reviews)
user_input = st.selectbox("Select an example review", options=review_labels)

# Full text of the chosen review.
selected_review = example_reviews[user_input]
# On click: embed the selected review, combine it with a precomputed SpaBERT
# embedding, and highlight the entities spaCy finds in the text.
if st.button("Highlight Geo-Entities"):
    if selected_review.strip():
        bert_embedding = get_bert_embedding(selected_review)
        # Debug output: shape and values of the review's BERT [CLS] embedding.
        st.write("Embedding Shape:", bert_embedding.shape)
        st.write("Embeddings:", bert_embedding)

        # Concatenate along the last (feature) axis with the first precomputed SpaBERT
        # embedding (NOTE: placeholder pairing — revisit after testing). The lookup is
        # guarded because the list is empty when no SpaBERT batches were processed.
        if spaBERT_embeddings:
            combined_embedding = torch.cat((bert_embedding, spaBERT_embeddings[0]), dim=-1)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)
        else:
            st.warning("No SpaBERT embeddings are available; skipping concatenation.")

        # Run spaCy NER and wrap every entity whose label is in COLOR_MAP in a colored span.
        doc = nlp(selected_review)
        highlighted_text = selected_review
        # Walk the entities in reverse so inserting markup never shifts the character
        # offsets (start_char/end_char) of entities that are still to be processed.
        for ent in reversed(doc.ents):
            if ent.label_ in COLOR_MAP:
                color = COLOR_MAP[ent.label_][0]
                highlighted_text = (
                    highlighted_text[:ent.start_char]
                    + f"<span style='color:{color}; font-weight:bold'>{ent.text}</span>"
                    + highlighted_text[ent.end_char:]
                )
        # unsafe_allow_html is needed to render the <span> markup. The review text is not
        # HTML-escaped, which is only acceptable while reviews come from the fixed examples.
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please select a review.")