"""SpaGAN demo: highlight geo-entities in a review and combine BERT/SpaBERT embeddings."""
import html
import itertools

import streamlit as st
import spacy
import torch
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from torch.utils.data import DataLoader
from PIL import Image

# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so every model load and the embedding precompute below run
# again on each button click. Consider wrapping them in @st.cache_resource.

device = torch.device('cpu')

# spaCy Initialization Section
nlp = spacy.load("./models/en_core_web_sm")

# BERT Initialization Section
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)
bert_model.eval()

# SpaBERT Initialization Section
# Make a new json file with only the geo entities needed, or it takes too long to run.
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
config.output_hidden_states = True  # needed: process_entity reads outputs.hidden_states
spaBERT_model = SpatialBertForMaskedLM(config)

pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# NOTE(review): BertModel's state-dict keys may not carry the prefix that
# SpatialBertForMaskedLM expects; with strict=False any mismatched keys are
# silently skipped. Verify this load actually copies weights.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict=False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)
spaBERT_model.to(device)
spaBERT_model.eval()

# Load data using SpatialDataset
spatialDataset = PbfMapDataset(data_file_path=data_file_path,
                               tokenizer=bert_tokenizer,
                               max_token_len=256,  # Originally 300
                               # max_token_len=max_seq_length,  # Originally 300
                               distance_norm_factor=0.0001,
                               spatial_dist_fill=20,
                               with_type=False,
                               sep_between_neighbors=True,
                               label_encoder=None,
                               mode=None)  # If set to None it will use the full dataset for mlm

# Issue: needs to be fixed with num_workers not stopping after finished.
# batch_size=1 and shuffle=False keep batch i == dataset item i, which
# get_spaBert_embedding relies on.
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False,
                         pin_memory=False, drop_last=False)

# Map lower-cased pivot names to their dataset index (case-insensitive lookup).
# For duplicate names the last occurrence wins.
# NOTE(review): this loads every dataset item just to read 'pivot_name'.
entity_index_dict = {entity['pivot_name'].lower(): i for i, entity in enumerate(spatialDataset)}


# Pre-acquire the SpaBERT embeddings for geo-entities within our dataset
def process_entity(batch, model, device):
    """Run *model* on one DataLoader batch and return its [CLS] embedding.

    Returns a tuple ``(cls_embedding, input_ids)``:
    - ``cls_embedding`` is the first-token slice of the last hidden layer,
      shape [batch_size, hidden_size].
    - ``input_ids`` is the batch's ``masked_input`` moved to *device*.
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    with torch.no_grad():
        # NOTE(review): the unmasked pseudo_sentence is fed as input_ids instead
        # of masked_input. Verify that this is correct.
        outputs = model(input_ids=pseudo_sentence,
                        attention_mask=attention_mask,
                        sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x,
                        position_list_y=position_list_y)

    last_hidden = outputs.hidden_states[-1].to(device)
    # Extract the [CLS] token embedding (first token)
    spaBERT_embedding = last_hidden[:, 0, :].detach()  # [batch_size, hidden_size]
    return spaBERT_embedding, input_ids


# Only the first few batches are embedded to keep start-up time down. Entities
# beyond them have no cached embedding (get_spaBert_embedding returns None).
NUM_PRECOMPUTED_BATCHES = 2

spaBERT_embeddings = []
for batch in itertools.islice(data_loader, NUM_PRECOMPUTED_BATCHES):
    spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(spaBERT_embedding)

embedding_cache = {}


# Get BERT Embedding for review
def get_bert_embedding(review_text):
    """Return BERT's [CLS] vector (last hidden layer, first token) for *review_text*."""
    inputs = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    return outputs.last_hidden_state[:, 0, :].detach()  # CLS token


# Get SpaBERT Embedding for geo-entity
def get_spaBert_embedding(entity):
    """Return the precomputed SpaBERT embedding for *entity*, or None.

    Lookup is case-insensitive. Returns None when the entity is not in the
    dataset, or when its embedding was not precomputed (only the first
    NUM_PRECOMPUTED_BATCHES dataset items are embedded).
    """
    entity_index = entity_index_dict.get(entity.lower())
    if entity_index is None or entity_index >= len(spaBERT_embeddings):
        return None
    return spaBERT_embeddings[entity_index]


GEO_ENTITY_LABELS = ('FAC', 'ORG', 'LOC', 'GPE')


# Go through each review, identify all geo-entities, then extract their SpaBERT embeddings
def processSpatialEntities(review, nlp):
    """Report each geo-entity spaCy finds in *review* via st.write.

    Returns a list of (entity_text, embedding) pairs. The embedding lookup is
    currently disabled, so the list is always empty.
    """
    doc = nlp(review)
    token_embeddings = []
    for ent in doc.ents:
        if ent.label_ in GEO_ENTITY_LABELS:  # Filter to geo-entities
            # spaBert_emb = get_spaBert_embedding(ent.text)
            # token_embeddings.append((ent.text, spaBert_emb))
            st.write("Geo-Entity Found in review: ", ent.text)
    return token_embeddings


# Function to read reviews from a text file
def load_reviews_from_file(file_path):
    """Read one review per non-empty line into {"Review N": text}, N counting from 1.

    Shows an st.error and returns an empty dict if the file does not exist.
    """
    reviews = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            stripped_lines = (line.strip() for line in file)
            # filter(None, ...) drops empty lines so numbering has no gaps
            for number, line in enumerate(filter(None, stripped_lines), start=1):
                reviews[f"Review {number}"] = line
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return reviews


# Define a color map and descriptions for different entity types
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
}


def _highlight_entities(text, doc):
    """Return *text* as HTML with COLOR_MAP entities wrapped in colored spans.

    All text is HTML-escaped because the result is rendered with
    unsafe_allow_html=True. Assumes doc.ents are sorted and non-overlapping.
    """
    pieces = []
    cursor = 0
    for ent in doc.ents:
        if ent.label_ not in COLOR_MAP:
            continue
        color = COLOR_MAP[ent.label_][0]
        pieces.append(html.escape(text[cursor:ent.start_char], quote=False))
        pieces.append(f"<span style='color:{color}; font-weight:bold'>"
                      f"{html.escape(ent.text, quote=False)}</span>")
        cursor = ent.end_char
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)


st.title("SpaGAN Demo")
st.write("Enter a text, and the system will highlight the geo-entities within it.")

# Display the color key
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}",
                unsafe_allow_html=True)

# Text input
# user_input = st.text_area("Input Text", height=200)

# Define example reviews for testing
# example_reviews = {
#     "Review 1": "Meh. My brother lives near the Italian Market in South Philly. I went for a visit. Luckily for me, my brother and his girlfriend are foodies. I was able to taste many different cuisines in Philly. Coming from San Francisco, there are places I don't go due to the tourist trap aura and the non-authenticity of it all (Fisherman’s Wharf, Chinatown, etc.). But when I was in Philly, I had to have a cheesesteak... and I had to go to the two most famous places, which of course are right across the street from one another, in a big rivalry, and featured on the Food Network! How cheesy, but essential. We split two, both "wit whiz"? (cheese whiz) one from Geno's and one from Pat's. Pat's was much tastier than Geno's. The meat was seasoned, and the bun and cheese had much more flavor... better of the two... it seems.",
#     "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
# }

review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

# Dropdown for selecting an example review
user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))

# selectbox returns None when there are no options (e.g. missing file); fall
# back to "" so the button shows "Please select a review." instead of crashing.
selected_review = example_reviews.get(user_input, "")

# Process the text when the button is clicked
if st.button("Highlight Geo-Entities"):
    if selected_review.strip():
        bert_embedding = get_bert_embedding(selected_review)

        # Debug: Print the shape of the embeddings
        st.write("Embedding Shape:", bert_embedding.shape)

        geo_entity_embeddings = processSpatialEntities(selected_review, nlp)

        # Combine the embeddings (NOTE: come back and update after testing)
        if spaBERT_embeddings:
            combined_embedding = torch.cat((bert_embedding, spaBERT_embeddings[0]), dim=-1)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)
        else:
            st.warning("No SpaBERT embeddings were precomputed; skipping concatenation.")

        # Process the text using spaCy and highlight geo-entities with different colors
        doc = nlp(selected_review)
        highlighted_text = _highlight_entities(selected_review, doc)

        # Display the highlighted text with HTML support
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please select a review.")