# Captured from a Hugging Face Spaces page (status shown: "Sleeping").
import html

import spacy
import streamlit as st
import torch
from PIL import Image
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM

from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
# All models in this demo are placed on the CPU.
device = torch.device('cpu')

# spaCy NER pipeline, loaded from a local copy of en_core_web_sm.
nlp = spacy.load("./models/en_core_web_sm")

# Pretrained BERT tokenizer + encoder used to embed review text. The encoder is
# only used for inference in this file, so it is switched to eval mode.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()
# --- SpaBERT initialisation ---
# Geo-entity data to embed.
# TODO: build a JSON with only the geo-entities actually needed; the combined
# file makes start-up very slow.
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
# process_entity() reads outputs.hidden_states, so hidden states must be returned.
config.output_hidden_states = True
spaBERT_model = SpatialBertForMaskedLM(config)

# NOTE(review): torch.load unpickles the checkpoint file; only load trusted files.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))

# Seed with the plain BERT weights first, then overlay the fine-tuned SpaBERT
# checkpoint. strict=False tolerates missing/unexpected keys in both loads.
# NOTE(review): whether BertModel's key names match SpatialBertForMaskedLM's is
# not visible here - confirm the first load actually copies any weights.
for state_dict in (bert_model.state_dict(), pre_trained_model):
    spaBERT_model.load_state_dict(state_dict, strict=False)

spaBERT_model.to(device)
spaBERT_model.eval()
# Dataset over the geo-entity file at data_file_path, tokenised with bert_tokenizer.
spatialDataset = PbfMapDataset(
    data_file_path=data_file_path,
    tokenizer=bert_tokenizer,
    max_token_len=256,  # originally 300
    distance_norm_factor=0.0001,
    spatial_dist_fill=20,
    with_type=False,
    sep_between_neighbors=True,
    label_encoder=None,
    mode=None,  # per the author: None uses the full dataset for MLM
)

# One sample per batch, in file order. num_workers=0 loads data in the main
# process.
# TODO: with worker processes, loading did not stop after finishing; investigate
# before raising num_workers.
data_loader = DataLoader(
    spatialDataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    drop_last=False,
)
# Pre-acquire the SpaBERT embeddings for the geo-entities within our dataset.
def process_entity(batch, model, device):
    """Run one dataset batch through SpaBERT and return its [CLS] embedding.

    Args:
        batch: A batch dict yielded by ``data_loader``. It must contain the
            tensors ``masked_input``, ``attention_mask``, ``norm_lng_list``,
            ``norm_lat_list``, ``sent_position_ids`` and ``pseudo_sentence``.
        model: The SpaBERT model to run. It is called with the keyword
            arguments below and must return an object exposing
            ``hidden_states`` (``config.output_hidden_states`` enables this).
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        ``(cls_embedding, input_ids)``. ``cls_embedding`` is the first-token
        vector of the last hidden layer, shaped ``[batch_size, hidden_size]``.
        ``input_ids`` is the batch's ``masked_input`` tensor, returned for
        inspection only; it is not fed to the model.
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    with torch.no_grad():
        # The token input is ``pseudo_sentence``; ``masked_input`` is
        # deliberately not passed.
        # NOTE(review): verify this is the intended input for inference.
        outputs = model(
            input_ids=pseudo_sentence,
            attention_mask=attention_mask,
            sent_position_ids=sent_position_ids,
            position_list_x=position_list_x,
            position_list_y=position_list_y,
        )

    last_hidden = outputs.hidden_states[-1].to(device)
    # The [CLS] token sits at position 0:
    # [batch_size, seq_len, hidden] -> [batch_size, hidden]
    spaBERT_embedding = last_hidden[:, 0, :].detach()
    return spaBERT_embedding, input_ids
# Pre-compute SpaBERT [CLS] embeddings for only the first few batches, to keep
# start-up time manageable. The button handler reads spaBERT_embeddings[0].
MAX_SPABERT_BATCHES = 2  # batches 0 and 1 are processed
spaBERT_embeddings = []
for i, batch in enumerate(data_loader):
    if i >= MAX_SPABERT_BATCHES:
        break
    embeddings, input_ids = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(embeddings)

# Debug output; skipped when the dataset yielded no batches.
if spaBERT_embeddings:
    st.write("SpaBERT Embedding shape:", spaBERT_embeddings[0].shape)
    st.write("SpaBERT Embedding:", spaBERT_embeddings[0])
def get_bert_embedding(review_text):
    """Return the BERT [CLS] embedding for ``review_text``.

    The text is tokenised with the module-level ``bert_tokenizer`` (padding and
    truncation enabled). It is then run through ``bert_model`` without gradient
    tracking, and the first-token vector of the last hidden state is returned,
    shaped ``[batch_size, hidden_size]``.
    """
    encoded = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True)
    encoded = encoded.to(device)

    with torch.no_grad():
        last_hidden = bert_model(**encoded).last_hidden_state

    cls_vector = last_hidden[:, 0, :]  # position 0 is the [CLS] token
    return cls_vector.detach()
st.title("SpaGAN Demo")
st.write("Enter a text, and the system will highlight the geo-entities within it.")

# spaCy entity label -> (display colour, human-readable description).
# Only labels listed here are highlighted in the output.
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
}

# Legend: one line per entity type, with its colour name rendered in that colour.
st.write("**Color Key:**")
for entity_label in COLOR_MAP:
    entity_color, entity_description = COLOR_MAP[entity_label]
    legend_line = (
        f"- **{entity_label}**: "
        f"<span style='color:{entity_color}'>{entity_color}</span>"
        f" - {entity_description}"
    )
    st.markdown(legend_line, unsafe_allow_html=True)
# Input selection. A free-text alternative, st.text_area("Input Text", height=200),
# is currently disabled in favour of fixed sample reviews.
example_reviews = {
    "Review 1": "I visited the Empire State Building in New York last summer, and it was amazing!",
    "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
}

# The dropdown shows the review names; look up the chosen review's text.
user_input = st.selectbox("Select an example review", options=[*example_reviews])
selected_review = example_reviews[user_input]
# Process the selected review when the button is clicked.
if st.button("Highlight Geo-Entities"):
    if selected_review.strip():
        bert_embedding = get_bert_embedding(selected_review)
        # Debug: show the BERT [CLS] embedding and its shape.
        st.write("Embedding Shape:", bert_embedding.shape)
        st.write("Embeddings:", bert_embedding)

        # Combine the review embedding with the first pre-computed SpaBERT
        # embedding. (NOTE: placeholder pairing - revisit after testing.)
        # Guard against an empty list so the UI does not crash with IndexError.
        if spaBERT_embeddings:
            combined_embedding = torch.cat((bert_embedding, spaBERT_embeddings[0]), dim=-1)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)
        else:
            st.warning("No SpaBERT embeddings are available; skipping concatenation.")

        # Highlight geo-entities found by spaCy. The output is rendered with
        # unsafe_allow_html=True, so every piece of review text is HTML-escaped.
        # Only the spans we generate remain markup. doc.ents are in document
        # order, so the output is built front-to-back from character offsets.
        doc = nlp(selected_review)
        pieces = []
        cursor = 0
        for ent in doc.ents:
            if ent.label_ not in COLOR_MAP:
                continue
            color = COLOR_MAP[ent.label_][0]
            pieces.append(html.escape(selected_review[cursor:ent.start_char], quote=False))
            pieces.append(
                f"<span style='color:{color}; font-weight:bold'>"
                f"{html.escape(ent.text, quote=False)}</span>"
            )
            cursor = ent.end_char
        pieces.append(html.escape(selected_review[cursor:], quote=False))
        highlighted_text = "".join(pieces)

        # Display the highlighted text with HTML support.
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please select a review.")