anamargarida committed
Commit f26658a · verified · 1 Parent(s): 456dc1d

Create app.py

Files changed (1): app.py +266 -0
app.py ADDED
@@ -0,0 +1,266 @@
import streamlit as st
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModel
from huggingface_hub import login, hf_hub_download
import re
import copy
from modeling_st2 import ST2ModelV2, SignalDetector
from safetensors.torch import load_file

hf_token = st.secrets["HUGGINGFACE_TOKEN"]
login(token=hf_token)


# Load model & tokenizer once (cached for efficiency)
@st.cache_resource
def load_model():
    config = AutoConfig.from_pretrained("roberta-large")
    tokenizer = AutoTokenizer.from_pretrained("roberta-large", use_fast=True, add_prefix_space=True)

    class Args:
        def __init__(self):
            self.dropout = 0.1
            self.signal_classification = True
            self.pretrained_signal_detector = False

    args = Args()
    model = ST2ModelV2(args)

    repo_id = "anamargarida/SpanExtractionWithSignalCls_2"
    filename = "model.safetensors"

    # Download the model file
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Load the model weights
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict)

    return tokenizer, model


# Load the model and tokenizer
tokenizer, model = load_model()
model.eval()  # Set model to evaluation mode

def extract_arguments(text, tokenizer, model, beam_search=True):

    class Args:
        def __init__(self):
            self.signal_classification = True
            self.pretrained_signal_detector = False

    args = Args()
    inputs = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")

    # Get tokenized words (for reconstruction later)
    word_ids = inputs.word_ids()

    with torch.no_grad():
        outputs = model(**inputs)

    # Extract logits
    start_cause_logits = outputs["start_arg0_logits"][0]
    end_cause_logits = outputs["end_arg0_logits"][0]
    start_effect_logits = outputs["start_arg1_logits"][0]
    end_effect_logits = outputs["end_arg1_logits"][0]
    start_signal_logits = outputs["start_sig_logits"][0]
    end_signal_logits = outputs["end_sig_logits"][0]

    # Mask the first and last (special) tokens so they are never selected
    seq_len = len(inputs["input_ids"][0])
    start_cause_logits[0] = -1e4
    end_cause_logits[0] = -1e4
    start_effect_logits[0] = -1e4
    end_effect_logits[0] = -1e4
    start_cause_logits[seq_len - 1] = -1e4
    end_cause_logits[seq_len - 1] = -1e4
    start_effect_logits[seq_len - 1] = -1e4
    end_effect_logits[seq_len - 1] = -1e4

    # Beam search for position selection
    if beam_search:
        indices1, indices2, _, _, _ = model.beam_search_position_selector(
            start_cause_logits=start_cause_logits,
            end_cause_logits=end_cause_logits,
            start_effect_logits=start_effect_logits,
            end_effect_logits=end_effect_logits,
            topk=5
        )
        start_cause1, end_cause1, start_effect1, end_effect1 = indices1
        start_cause2, end_cause2, start_effect2, end_effect2 = indices2
    else:
        start_cause1 = start_cause_logits.argmax().item()
        end_cause1 = end_cause_logits.argmax().item()
        start_effect1 = start_effect_logits.argmax().item()
        end_effect1 = end_effect_logits.argmax().item()
        start_cause2, end_cause2, start_effect2, end_effect2 = None, None, None, None

    # Signal classification: does the sentence contain an explicit causal signal?
    has_signal = 1
    if args.signal_classification:
        if not args.pretrained_signal_detector:
            has_signal = outputs["signal_classification_logits"].argmax().item()
        else:
            # Only reached when a pretrained signal detector is configured;
            # a SignalDetector instance would have to be created beforehand.
            has_signal = signal_detector.predict(text=text)

    if has_signal:
        # Mask the special tokens for the signal logits as well
        start_signal_logits[0] = -1e4
        end_signal_logits[0] = -1e4
        start_signal_logits[seq_len - 1] = -1e4
        end_signal_logits[seq_len - 1] = -1e4

        start_signal = start_signal_logits.argmax().item()
        # Constrain the signal end to at most 5 tokens after its start
        end_signal_logits[:start_signal] = -1e4
        end_signal_logits[start_signal + 5:] = -1e4
        end_signal = end_signal_logits.argmax().item()
    else:
        start_signal, end_signal = None, None

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    token_ids = inputs["input_ids"][0]
    offset_mapping = inputs["offset_mapping"][0].tolist()

    # Debug output: tokens, word IDs, offsets and positions
    for i, (token, word_id) in enumerate(zip(tokens, word_ids)):
        st.write(f"Token {i}: {token}, Word ID: {word_id}")

    st.write("Token & offset:")
    for i, (token, offset) in enumerate(zip(tokens, offset_mapping)):
        st.write(f"Token {i}: {token}, Offset: {offset}")

    st.write("Token Positions, IDs, and Corresponding Tokens:")
    for position, (token_id, token) in enumerate(zip(token_ids, tokens)):
        st.write(f"Position: {position}, ID: {token_id}, Token: {token}")

    st.write(f"Start Cause 1: {start_cause1}, End Cause 1: {end_cause1}")
    st.write(f"Start Effect 1: {start_effect1}, End Effect 1: {end_effect1}")
    st.write(f"Start Signal: {start_signal}, End Signal: {end_signal}")

    def extract_span(start, end):
        return tokenizer.convert_tokens_to_string(tokens[start:end + 1]) if start is not None and end is not None else ""

    cause1 = extract_span(start_cause1, end_cause1)
    cause2 = extract_span(start_cause2, end_cause2)
    effect1 = extract_span(start_effect1, end_effect1)
    effect2 = extract_span(start_effect2, end_effect2)
    signal = extract_span(start_signal, end_signal) if has_signal else 'NA'
    list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
    list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
    # return cause1, cause2, effect1, effect2, signal, list1, list2
    return start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal


def mark_text_by_position(original_text, start_idx, end_idx, color):
    """Marks text in the original string based on character positions."""
    if start_idx is not None and end_idx is not None and start_idx <= end_idx:
        return (
            original_text[:start_idx]
            + f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>"
            + original_text[start_idx:end_idx]
            + "</mark>"
            + original_text[end_idx:]
        )
    return original_text  # Return the text unchanged if the indices are invalid


def mark_text_by_tokens(tokenizer, tokens, start_idx, end_idx, color):
    """Highlights a token span using HTML."""
    highlighted_tokens = copy.deepcopy(tokens)  # Avoid modifying the original tokens
    if start_idx is not None and end_idx is not None and start_idx <= end_idx:
        highlighted_tokens[start_idx] = f"<span style='background-color:{color}; padding:2px; border-radius:4px;'>{highlighted_tokens[start_idx]}"
        highlighted_tokens[end_idx] = f"{highlighted_tokens[end_idx]}</span>"
    return tokenizer.convert_tokens_to_string(highlighted_tokens)


def mark_text_by_word_ids(original_text, token_ids, start_word_id, end_word_id, color):
    """Marks words in the original text based on word IDs from the tokenized input."""
    words = original_text.split()  # Split text into words
    if (start_word_id is not None and end_word_id is not None
            and start_word_id <= end_word_id and end_word_id < len(words)):
        words[start_word_id] = f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>{words[start_word_id]}"
        words[end_word_id] = f"{words[end_word_id]}</mark>"
    return " ".join(words)

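# A minimal sketch, assuming the offset mapping comes from the same encoding
# that produced the token indices (inputs["offset_mapping"][0].tolist()):
# entry i of the offset mapping holds the (char_start, char_end) of token i in
# the raw text, so an inclusive token span maps to the characters between the
# start of its first token and the end of its last one. A helper like this
# (illustrative only, not wired into the handlers below) would let
# mark_text_by_position highlight the original text exactly.
def token_span_to_char_span(offsets, start_tok, end_tok):
    """Convert an inclusive token span into character positions in the raw text."""
    if start_tok is None or end_tok is None:
        return None, None
    return offsets[start_tok][0], offsets[end_tok][1]

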
st.title("Causal Relation Extraction")
input_text = st.text_area("Enter your text here:", height=300)
beam_search = st.radio("Enable Beam Search?", ('No', 'Yes')) == 'Yes'


if st.button("Extract"):
    if input_text:
        start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)

        # Note: these spans are token indices, so character-position marking is
        # only approximate unless they are first converted to character offsets
        # (e.g. via the tokenizer's offset mapping).
        cause_text = mark_text_by_position(input_text, start_cause1, end_cause1, "#FFD700")     # Gold for cause
        effect_text = mark_text_by_position(input_text, start_effect1, end_effect1, "#90EE90")  # Light green for effect
        signal_text = mark_text_by_position(input_text, start_signal, end_signal, "#FF6347")    # Tomato red for signal

        st.markdown(f"**Cause:**<br>{cause_text}", unsafe_allow_html=True)
        st.markdown(f"**Effect:**<br>{effect_text}", unsafe_allow_html=True)
        st.markdown(f"**Signal:**<br>{signal_text}", unsafe_allow_html=True)
    else:
        st.warning("Please enter some text before extracting.")

if st.button("Extract1"):
    if input_text:
        start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)

        # Re-tokenize here so the token IDs are available in this scope.
        # Note: the spans are token indices rather than word IDs, so word-level
        # marking is only approximate.
        inputs = tokenizer(input_text, return_tensors="pt")
        cause_text = mark_text_by_word_ids(input_text, inputs["input_ids"][0], start_cause1, end_cause1, "#FFD700")     # Gold for cause
        effect_text = mark_text_by_word_ids(input_text, inputs["input_ids"][0], start_effect1, end_effect1, "#90EE90")  # Light green for effect
        signal_text = mark_text_by_word_ids(input_text, inputs["input_ids"][0], start_signal, end_signal, "#FF6347")    # Tomato red for signal

        st.markdown(f"**Cause:**<br>{cause_text}", unsafe_allow_html=True)
        st.markdown(f"**Effect:**<br>{effect_text}", unsafe_allow_html=True)
        st.markdown(f"**Signal:**<br>{signal_text}", unsafe_allow_html=True)
    else:
        st.warning("Please enter some text before extracting.")


if st.button("Extract2"):  # distinct label: duplicate st.button labels collide in Streamlit
    if input_text:
        start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)

        # Convert text to tokenized format
        tokenized_input = tokenizer.tokenize(input_text)

        def shift(idx):
            # The spans from extract_arguments count the leading <s> token,
            # so shift them by one to index into tokenize()'s output.
            return idx - 1 if idx is not None else None

        cause_text1 = mark_text_by_tokens(tokenizer, tokenized_input, shift(start_cause1), shift(end_cause1), "#FFD700")     # Gold for cause
        effect_text1 = mark_text_by_tokens(tokenizer, tokenized_input, shift(start_effect1), shift(end_effect1), "#90EE90")  # Light green for effect
        signal_text = mark_text_by_tokens(tokenizer, tokenized_input, shift(start_signal), shift(end_signal), "#FF6347")     # Tomato red for signal

        # Display first relation
        st.markdown("<strong>Relation 1:</strong>", unsafe_allow_html=True)
        st.markdown(f"**Cause:** {cause_text1}", unsafe_allow_html=True)
        st.markdown(f"**Effect:** {effect_text1}", unsafe_allow_html=True)
        st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)

        # Display second relation if beam search is enabled
        if beam_search:
            cause_text2 = mark_text_by_tokens(tokenizer, tokenized_input, shift(start_cause2), shift(end_cause2), "#FFD700")
            effect_text2 = mark_text_by_tokens(tokenizer, tokenized_input, shift(start_effect2), shift(end_effect2), "#90EE90")

            st.markdown("<strong>Relation 2:</strong>", unsafe_allow_html=True)
            st.markdown(f"**Cause:** {cause_text2}", unsafe_allow_html=True)
            st.markdown(f"**Effect:** {effect_text2}", unsafe_allow_html=True)
            st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)
    else:
        st.warning("Please enter some text before extracting.")
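

# Usage sketch, assuming a standard Streamlit setup: the app reads its Hugging
# Face token from Streamlit secrets, e.g. in .streamlit/secrets.toml:
#   HUGGINGFACE_TOKEN = "hf_..."
# With the secret in place, it can be launched with:
#   streamlit run app.py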