"""SpaGAN Streamlit demo: highlight geo-entities in an example review.

The script loads a spaCy pipeline, a BERT encoder and a fine-tuned SpaBERT
model (plus its dataset), lets the user pick an example review, and on a
button press shows the review's BERT [CLS] embedding (debug output) and the
review text with spaCy entities of the types listed in ``COLOR_MAP``
highlighted in colour.
"""

import html

import streamlit as st
import spacy
import torch
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from torch.utils.data import DataLoader
from PIL import Image

# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the models and dataset below are reloaded each time.
# Consider wrapping the loaders in st.cache_resource if the installed
# Streamlit version provides it.

device = torch.device('cpu')

# --- spaCy initialisation ---
nlp = spacy.load("./models/en_core_web_sm")

# --- BERT initialisation ---
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)
bert_model.eval()

# --- SpaBERT initialisation ---
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

# Maximum token length for dataset samples. This name was referenced but never
# defined (NameError at startup). 256 is the previously used, commented-out
# setting; the original comments mention 300 as the earlier value — confirm.
max_seq_length = 256

config = SpatialBertConfig()
config.output_hidden_states = True
spaBERT_model = SpatialBertForMaskedLM(config)

# torch.load unpickles the file; only load checkpoints from a trusted source.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# Start from the BERT weights, then overlay the fine-tuned SpaBERT checkpoint.
# NOTE(review): strict=False silently skips keys that don't match. If
# SpatialBertForMaskedLM nests its encoder under a prefix (e.g. "bert."),
# BertModel's un-prefixed keys may match nothing in the first call — check
# the missing/unexpected keys returned by load_state_dict.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict=False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)
spaBERT_model.to(device)
spaBERT_model.eval()

# Load data using the spatial dataset. The tokenizer is the BERT tokenizer
# created above (the original passed an undefined name `tokenizer`).
spatialDataset = PbfMapDataset(
    data_file_path=data_file_path,
    tokenizer=bert_tokenizer,
    max_token_len=max_seq_length,
    distance_norm_factor=0.0001,
    spatial_dist_fill=20,
    with_type=False,
    sep_between_neighbors=True,  # originally False; may be worth experimenting with
    label_encoder=None,  # originally None; could be used for the real/fake review labels
    mode=None,  # per the original author, None uses the full dataset for MLM
)
# TODO: the original notes an issue with worker processes not stopping after
# finishing, so num_workers is kept at 0.
data_loader = DataLoader(
    spatialDataset,
    batch_size=1,
    num_workers=0,
    shuffle=False,
    pin_memory=False,
    drop_last=False,
)


def get_bert_embedding(review_text):
    """Return the BERT [CLS] embedding for *review_text*.

    Args:
        review_text: the text to encode.

    Returns:
        A tensor of shape (batch, hidden_size): the first-token ([CLS])
        vector of BERT's last hidden state. For a single string, batch is 1.
    """
    inputs = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    # Position 0 of every sequence is the [CLS] token.
    return outputs.last_hidden_state[:, 0, :].detach()


st.title("SpaGAN Demo")
st.write("Enter a text, and the system will highlight the geo-entities within it.")

# Entity label -> (display colour, description) for the labels we highlight.
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)'),
}

# Display the colour key, showing each colour name in its own colour.
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(
        f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}",
        unsafe_allow_html=True,
    )


def _highlight_entities(text, entities, color_map=COLOR_MAP):
    """Return *text* as HTML with entities whose label is in *color_map* coloured.

    Args:
        text: the original text the entities were extracted from.
        entities: spaCy-style spans (``label_``, ``start_char``, ``end_char``)
            in document order, e.g. ``doc.ents``.
        color_map: label -> (colour, description); labels not present are
            left unhighlighted.

    Returns:
        An HTML string. All text is HTML-escaped so characters such as ``<``
        or ``&`` cannot inject markup when rendered with
        ``unsafe_allow_html=True``.
    """
    parts = []
    cursor = 0
    for ent in entities:
        if ent.label_ not in color_map or ent.start_char < cursor:
            # Skip unmapped labels and (defensively) overlapping spans.
            continue
        color = color_map[ent.label_][0]
        parts.append(html.escape(text[cursor:ent.start_char], quote=False))
        entity_text = html.escape(text[ent.start_char:ent.end_char], quote=False)
        parts.append(f"<span style='color:{color}; font-weight:bold'>{entity_text}</span>")
        cursor = ent.end_char
    parts.append(html.escape(text[cursor:], quote=False))
    return "".join(parts)


# Example reviews offered for testing.
example_reviews = {
    "Review 1": "I visited the Empire State Building in New York last summer, and it was amazing!",
    "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
}

# Dropdown for selecting an example review (the value is the dict key).
user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))
selected_review = example_reviews[user_input]

# Process the text when the button is clicked.
if st.button("Highlight Geo-Entities"):
    if selected_review.strip():
        bert_embedding = get_bert_embedding(selected_review)

        # Debug output: embedding shape and values.
        st.write("Embedding Shape:", bert_embedding.shape)
        st.write("Embeddings:", bert_embedding)

        doc = nlp(selected_review)
        highlighted_text = _highlight_entities(selected_review, doc.ents)
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please select a review.")