# Spaces: Sleeping  (page-status text captured along with the source; not part of the program)
import streamlit as st | |
import spacy | |
import torch | |
from transformers import BertTokenizer, BertModel | |
from transformers.models.bert.modeling_bert import BertForMaskedLM | |
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel | |
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights | |
from PIL import Image | |
# All inference in this demo runs on CPU.
device = torch.device('cpu')

# Locations of the SpaBERT fine-tuning data and the fine-tuned checkpoint.
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'


# Streamlit re-executes this script from the top on every widget interaction.
# Wrapping the loaders in st.cache_resource keeps one copy of each model alive
# across reruns instead of re-reading the weights on every button click.
@st.cache_resource
def _load_spacy_pipeline():
    """Load the spaCy pipeline stored under ./models/en_core_web_sm."""
    return spacy.load("./models/en_core_web_sm")


@st.cache_resource
def _load_bert():
    """Return the bert-base-uncased (tokenizer, model), with the model on `device` in eval mode."""
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    model.to(device)
    model.eval()
    return tokenizer, model


@st.cache_resource
def _load_spabert():
    """Build SpatialBertForMaskedLM and load its weights.

    Returns:
        (config, checkpoint_state_dict, model) with the model on `device`
        and in eval mode.
    """
    _, base_bert = _load_bert()

    spatial_config = SpatialBertConfig()
    spatial_config.output_hidden_states = True
    model = SpatialBertForMaskedLM(spatial_config)

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be shipped here; consider weights_only=True if the installed
    # torch version supports it and the file is a plain state dict.
    checkpoint = torch.load(pretrained_model_path, map_location=device)

    # Order matters: base BERT weights first, then the fine-tuned checkpoint
    # overrides whatever keys it provides.
    # NOTE(review): with strict=False, keys that do not match the target
    # model's parameter names are silently skipped. base_bert is a bare
    # BertModel -- confirm its key names actually line up with
    # SpatialBertForMaskedLM (BertForMaskedLM may be the intended source).
    model.load_state_dict(base_bert.state_dict(), strict=False)
    model.load_state_dict(checkpoint, strict=False)
    model.to(device)
    model.eval()
    return spatial_config, checkpoint, model


# Spacy Initialization Section
nlp = _load_spacy_pipeline()

# BERT Initialization Section
bert_tokenizer, bert_model = _load_bert()

# SpaBERT Initialization Section
config, pre_trained_model, spaBERT_model = _load_spabert()
# Page header and a one-line usage hint.
st.title("SpaGAN Demo")
st.write("Enter a text, and the system will highlight the geo-entities within it.")

# Entity label -> (CSS color used for highlighting, human-readable description).
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
}

# Legend: one markdown bullet per label, with the color name drawn in that color.
st.write("**Color Key:**")
for label in COLOR_MAP:
    color, description = COLOR_MAP[label]
    legend_entry = f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}"
    st.markdown(legend_entry, unsafe_allow_html=True)
def _highlight_entities(text, entities, color_map):
    """Return `text` as HTML with selected entities wrapped in colored spans.

    Args:
        text: The raw user text the entities were extracted from.
        entities: Entity spans in document order, each exposing `label_`,
            `text`, `start_char` and `end_char` (character offsets into
            `text`), e.g. spaCy's `doc.ents`.
        color_map: Mapping of entity label -> (css_color, description).
            Entities whose label is not a key are left unhighlighted.

    Returns:
        An HTML string. Everything that originates from `text` is
        HTML-escaped; only the generated <span> wrappers are live markup.
    """
    from html import escape

    pieces = []
    cursor = 0  # offset in `text` up to which output has been emitted
    for ent in entities:
        if ent.label_ not in color_map:
            continue
        color = color_map[ent.label_][0]
        # The result is rendered with unsafe_allow_html=True, so both the
        # text between entities and the entity text itself must be escaped
        # or user input could inject arbitrary HTML into the page.
        # quote=False: the text is only used as element content, never
        # inside an attribute value.
        pieces.append(escape(text[cursor:ent.start_char], quote=False))
        pieces.append(
            f"<span style='color:{color}; font-weight:bold'>"
            f"{escape(ent.text, quote=False)}</span>"
        )
        cursor = ent.end_char
    pieces.append(escape(text[cursor:], quote=False))
    return "".join(pieces)


# Text input
user_input = st.text_area("Input Text", height=200)

# Process the text when the button is clicked
if st.button("Highlight Geo-Entities"):
    if user_input.strip():
        # Run named-entity recognition with spaCy
        doc = nlp(user_input)

        # Highlight geo-entities with different colors
        highlighted_text = _highlight_entities(user_input, doc.ents, COLOR_MAP)

        # Display the highlighted text with HTML support
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please enter some text.")