Spaces:
Sleeping
Sleeping
File size: 3,018 Bytes
ac736ed 5c91758 bc50d7d 3d81019 ebf50a4 b7111b8 83e90d7 ac736ed d914cbe 1aa7dda 3d81019 dc9ff0b 3d81019 d914cbe 3d81019 dc9ff0b 393b919 d914cbe 3be15aa 3d81019 d914cbe 3d81019 fbce538 3d81019 3be15aa d914cbe 80744c0 3d81019 ac736ed 5c91758 83e90d7 9822204 3577a57 9822204 3577a57 d914cbe 3577a57 9822204 3577a57 9822204 5c91758 83e90d7 5c91758 83e90d7 b4303dc 83e90d7 3577a57 5c91758 3577a57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import html
import logging

import spacy
import streamlit as st
import torch
from PIL import Image
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM

from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
device = torch.device('cpu')

# Paths to the SpaBERT assets.
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

logger = logging.getLogger(__name__)

# Streamlit re-executes this whole script on every widget interaction, so models
# loaded directly at module level would be reloaded from disk on every click.
# st.cache_resource keeps one shared instance across reruns; on Streamlit
# versions that lack it, fall back to loading without caching.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_spacy():
    """Load the local spaCy pipeline used for entity recognition."""
    return spacy.load("./models/en_core_web_sm")


@_cache_resource
def _load_bert():
    """Load the bert-base-uncased tokenizer and encoder, moved to ``device`` in eval mode."""
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    model.to(device)
    model.eval()
    return tokenizer, model


@_cache_resource
def _load_spabert(_bert_model):
    """Build SpatialBertForMaskedLM and load BERT weights, then the fine-tuned checkpoint.

    The leading underscore on ``_bert_model`` tells st.cache_resource not to
    try to hash the model object when computing the cache key.

    Returns:
        ``(config, model, checkpoint_state_dict)``. The model is on ``device``
        and in eval mode.
    """
    spa_config = SpatialBertConfig()
    spa_config.output_hidden_states = True
    model = SpatialBertForMaskedLM(spa_config)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(pretrained_model_path, map_location=device)
    # Seed with generic BERT weights first, then overlay the fine-tuned checkpoint.
    # strict=False silently skips keys that do not match, so log how many were skipped.
    # TODO(review): confirm BertModel's key names line up with
    # SpatialBertForMaskedLM's; if they do not, the first load is a no-op.
    for source, weights in (
        ("bert-base-uncased", _bert_model.state_dict()),
        (pretrained_model_path, checkpoint),
    ):
        incompatible = model.load_state_dict(weights, strict=False)
        logger.info(
            "Loaded %s into SpaBERT: %d missing keys, %d unexpected keys",
            source,
            len(incompatible.missing_keys),
            len(incompatible.unexpected_keys),
        )
    model.to(device)
    model.eval()
    return spa_config, model, checkpoint


#Spacy Initialization Section
nlp = _load_spacy()
#BERT Initialization Section
bert_tokenizer, bert_model = _load_bert()
#SpaBERT Initialization Section
config, spaBERT_model, pre_trained_model = _load_spabert(bert_model)
st.title("SpaGAN Demo")
st.write("Enter a text, and the system will highlight the geo-entities within it.")

# Entity label -> (CSS color name used for highlighting, human-readable description).
COLOR_MAP = dict(
    FAC=('red', 'Facilities (e.g., buildings, airports)'),
    ORG=('blue', 'Organizations (e.g., companies, institutions)'),
    LOC=('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    GPE=('green', 'Geopolitical Entities (e.g., countries, cities)'),
)

# Legend: one markdown bullet per entity type, with the color name drawn in that color.
st.write("**Color Key:**")
for entity_label, entry in COLOR_MAP.items():
    swatch, blurb = entry
    legend_line = f"- **{entity_label}**: <span style='color:{swatch}'>{swatch}</span> - {blurb}"
    st.markdown(legend_line, unsafe_allow_html=True)
def _render_highlighted_html(text, entities, color_map):
    """Return *text* as HTML with each mapped entity wrapped in a colored, bold span.

    Every piece of user-supplied text is HTML-escaped before being embedded.
    The result is rendered with ``unsafe_allow_html=True``, so unescaped input
    such as ``<img src=x onerror=...>`` would otherwise become live markup.

    Args:
        text: The raw string the entities were extracted from.
        entities: Spans exposing ``start_char``, ``end_char``, ``text`` and
            ``label_`` (e.g. spaCy ``doc.ents``), non-overlapping and in
            document order.
        color_map: Maps an entity label to a ``(color, description)`` tuple.
            Entities whose label is not a key are left unhighlighted.

    Returns:
        The escaped HTML string.
    """
    pieces = []
    cursor = 0  # end offset of the last emitted entity
    for ent in entities:
        if ent.label_ not in color_map:
            continue
        color = color_map[ent.label_][0]
        pieces.append(html.escape(text[cursor:ent.start_char]))
        pieces.append(
            f"<span style='color:{color}; font-weight:bold'>{html.escape(ent.text)}</span>"
        )
        cursor = ent.end_char
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


# Text input
user_input = st.text_area("Input Text", height=200)
# Process the text when the button is clicked
if st.button("Highlight Geo-Entities"):
    if user_input.strip():
        # Run spaCy NER, then color the entity types listed in COLOR_MAP.
        doc = nlp(user_input)
        highlighted_text = _render_highlighted_html(user_input, doc.ents, COLOR_MAP)
        # HTML is allowed so the color spans render; all user text has been escaped.
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please enter some text.")