File size: 3,207 Bytes
ac736ed
5c91758
bc50d7d
3d81019
ebf50a4
 
b7111b8
 
 
83e90d7
ac736ed
ebf50a4
1aa7dda
3d81019
dc9ff0b
3d81019
 
 
 
dc9ff0b
393b919
 
3d81019
fbce538
 
3d81019
fbce538
 
3d81019
fbce538
 
3d81019
811f100
3d81019
80744c0
 
dc9ff0b
80744c0
 
 
 
 
 
 
 
 
 
3d81019
 
 
 
 
 
 
 
 
 
 
 
ac736ed
5c91758
 
83e90d7
9822204
3577a57
9822204
 
 
 
3577a57
 
9822204
3577a57
9822204
 
3577a57
9822204
5c91758
83e90d7
5c91758
 
 
 
 
83e90d7
b4303dc
 
 
 
 
 
 
 
 
 
83e90d7
3577a57
 
5c91758
3577a57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import html

import streamlit as st
import spacy
import torch
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM
from PIL import Image

from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights

##LOAD MODEL SECTION##
# spaCy pipeline used further down for named-entity recognition; it is loaded
# from a local directory path rather than an installed package name.
nlp = spacy.load("./models/en_core_web_sm")

#BERT Initialization Section
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so these model loads repeat on each click. Consider moving them
# into a cached loader (e.g. st.cache_resource) if the installed Streamlit
# version provides it.
_BERT_CHECKPOINT = "bert-base-uncased"
bert_tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
# nn.Module.eval() returns the module itself, so the encoder is switched to
# inference mode in the same expression that loads it.
bert_model = BertModel.from_pretrained(_BERT_CHECKPOINT).eval()

#SpaBERT Initialization Section
data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
pretrained_model = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

# Device the SpaBERT model is moved to. Previously `device` was referenced
# below without ever being defined, which raised NameError at startup. The
# checkpoint itself is always deserialised onto the CPU (map_location below),
# so this is safe whether or not a GPU is present.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stage 1: start from the stock BERT masked-LM weights. strict=False because
# SpatialBertForMaskedLM has parameters that plain BERT does not.
b_model =  BertForMaskedLM.from_pretrained('bert-base-uncased')
b_tokenizer =  BertTokenizer.from_pretrained('bert-base-uncased')

config = SpatialBertConfig()
config.output_hidden_states = True

spaBERT_model = SpatialBertForMaskedLM(config)
spaBERT_model.load_state_dict(b_model.state_dict(), strict = False)

# Stage 2: overlay the fine-tuned SpaBERT checkpoint.
# NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))

# Replace every state-dict entry that the checkpoint provides; entries it lacks
# keep their stage-1 values and are reported. `cnt_layers` counts state-dict
# entries (parameters/buffers), not network layers.
model_keys = spaBERT_model.state_dict()
cnt_layers = 0
for key in model_keys:
    if key in pre_trained_model:
        model_keys[key] = pre_trained_model[key]
        cnt_layers += 1
    else:
        print("No weight for", key)
print(cnt_layers, 'layers loaded')

spaBERT_model.load_state_dict(model_keys)
spaBERT_model.to(device)
spaBERT_model.eval()













st.title("SpaGAN Demo")
st.write("Enter a text, and the system will highlight the geo-entities within it.")

# spaCy entity label -> (CSS colour name, legend description).
# Only labels listed here are highlighted; other entity types are ignored.
COLOR_MAP = dict(
    FAC=('red', 'Facilities (e.g., buildings, airports)'),
    ORG=('blue', 'Organizations (e.g., companies, institutions)'),
    LOC=('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    GPE=('green', 'Geopolitical Entities (e.g., countries, cities)'),
)

# Legend: one bullet per label, with the colour name drawn in that colour.
st.write("**Color Key:**")
for entity_label in COLOR_MAP:
    swatch, blurb = COLOR_MAP[entity_label]
    legend_line = (
        f"- **{entity_label}**: "
        f"<span style='color:{swatch}'>{swatch}</span> - {blurb}"
    )
    st.markdown(legend_line, unsafe_allow_html=True)

# Text input
user_input = st.text_area("Input Text", height=200)

# Process the text when the button is clicked
if st.button("Highlight Geo-Entities"):
    if user_input.strip():
        # Run spaCy NER over the raw input text.
        doc = nlp(user_input)

        # Rebuild the text left-to-right, wrapping recognised geo-entities in
        # coloured spans. Every user-derived segment is HTML-escaped because
        # the result is rendered with unsafe_allow_html=True; without escaping,
        # input such as "<img src=x onerror=...>" would be injected into the
        # page as live HTML.
        pieces = []
        cursor = 0  # end offset of the last text already emitted
        for ent in doc.ents:
            if ent.label_ not in COLOR_MAP:
                continue
            color = COLOR_MAP[ent.label_][0]
            pieces.append(html.escape(user_input[cursor:ent.start_char]))
            pieces.append(
                f"<span style='color:{color}; font-weight:bold'>"
                f"{html.escape(ent.text)}</span>"
            )
            cursor = ent.end_char
        pieces.append(html.escape(user_input[cursor:]))
        highlighted_text = "".join(pieces)

        # Display the highlighted text with HTML support
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please enter some text.")