anamargarida commited on
Commit
c7b9f8a
·
verified ·
1 Parent(s): 04ec2d1

Rename app_7.py to app_8.py

Browse files
Files changed (1) hide show
  1. app_7.py → app_8.py +23 -26
app_7.py → app_8.py RENAMED
@@ -65,7 +65,6 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
65
  with torch.no_grad():
66
  outputs = model(**inputs)
67
 
68
- #st.write("Model output keys:", outputs.keys())
69
 
70
  # Extract logits
71
  start_cause_logits = outputs["start_arg0_logits"][0]
@@ -132,9 +131,8 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
132
  end_signal_logits[start_signal + 5:] = -1e4
133
  end_signal = end_signal_logits.argmax().item()
134
 
135
- if not has_signal:
136
- start_signal = 'NA'
137
- end_signal = 'NA'
138
 
139
 
140
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
@@ -161,21 +159,20 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
161
  signal = 'NA'
162
  list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
163
  list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
164
- return cause1, cause2, effect1, effect2, signal, list1, list2
165
-
166
-
167
- def mark_text(original_text, span, color):
168
- """Replace extracted span with a colored background marker, starting at the exact first token."""
169
- if span:
170
- # Get the starting index of the span in the original text
171
- start_idx = original_text.find(span)
172
-
173
- # Ensure the span is found and replace it with <mark> tags directly at the first token's position
174
- if start_idx != -1:
175
- # Replace the span with the mark tag at the exact first token position
176
- return re.sub(re.escape(span), f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>{span}</mark>", original_text, flags=re.IGNORECASE)
177
-
178
- return original_text # Return unchanged text if no span is found
179
 
180
  st.title("Causal Relation Extraction")
181
  input_text = st.text_area("Enter your text here:", height=300)
@@ -184,11 +181,11 @@ beam_search = st.radio("Enable Beam Search?", ('No', 'Yes')) == 'Yes'
184
 
185
  if st.button("Extract1"):
186
  if input_text:
187
- cause1, cause2, effect1, effect2, signal, list1, list2 = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)
188
 
189
- cause_text1 = mark_text(input_text, cause1, "#FFD700") # Gold for cause
190
- effect_text1 = mark_text(input_text, effect1, "#90EE90") # Light green for effect
191
- signal_text = mark_text(input_text, signal, "#FF6347") # Tomato red for signal
192
 
193
  st.markdown(f"<span style='font-size: 24px;'><strong>Relation 1:</strong></span>", unsafe_allow_html=True)
194
  st.markdown(f"**Cause:**<br>{cause_text1}", unsafe_allow_html=True)
@@ -199,9 +196,9 @@ if st.button("Extract1"):
199
 
200
  if beam_search:
201
 
202
- cause_text2 = mark_text(input_text, cause2, "#FFD700") # Gold for cause
203
- effect_text2 = mark_text(input_text, effect2, "#90EE90") # Light green for effect
204
- signal_text = mark_text(input_text, signal, "#FF6347") # Tomato red for signal
205
 
206
  st.markdown(f"<span style='font-size: 24px;'><strong>Relation 2:</strong></span>", unsafe_allow_html=True)
207
  st.markdown(f"**Cause:**<br>{cause_text2}", unsafe_allow_html=True)
 
65
  with torch.no_grad():
66
  outputs = model(**inputs)
67
 
 
68
 
69
  # Extract logits
70
  start_cause_logits = outputs["start_arg0_logits"][0]
 
131
  end_signal_logits[start_signal + 5:] = -1e4
132
  end_signal = end_signal_logits.argmax().item()
133
 
134
+ If not has_signal:
135
+ start_signal, end_signal = None, None
 
136
 
137
 
138
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
 
159
  signal = 'NA'
160
  list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
161
  list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
162
+ #return cause1, cause2, effect1, effect2, signal, list1, list2
163
+ return start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal
164
+
165
+ def mark_text_by_position(original_text, start_idx, end_idx, color):
166
+ """Marks text in the original string based on character positions."""
167
+ if start_idx is not None and end_idx is not None and start_idx < end_idx:
168
+ return (
169
+ original_text[:start_idx]
170
+ + f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>"
171
+ + original_text[start_idx:end_idx]
172
+ + "</mark>"
173
+ + original_text[end_idx:]
174
+ )
175
+ return original_text # Return unchanged if indices are invalidt # Return unchanged text if no span is found
 
176
 
177
  st.title("Causal Relation Extraction")
178
  input_text = st.text_area("Enter your text here:", height=300)
 
181
 
182
  if st.button("Extract1"):
183
  if input_text:
184
+ start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)
185
 
186
+ cause_text1 = mark_text(input_text, start_cause1, end_cause1, "#FFD700") # Gold for cause
187
+ effect_text1 = mark_text(input_text, start_effect1, end_effect1, "#90EE90") # Light green for effect
188
+ signal_text = mark_text(input_text, start_signal, end_signal, "#FF6347") # Tomato red for signal
189
 
190
  st.markdown(f"<span style='font-size: 24px;'><strong>Relation 1:</strong></span>", unsafe_allow_html=True)
191
  st.markdown(f"**Cause:**<br>{cause_text1}", unsafe_allow_html=True)
 
196
 
197
  if beam_search:
198
 
199
+ cause_text2 = mark_text(input_text, start_cause2, end_cause2, "#FFD700") # Gold for cause
200
+ effect_text2 = mark_text(input_text, start_effect2, end_effect2, "#90EE90") # Light green for effect
201
+ signal_text = mark_text(input_text, start_signal, end_signal, "#FF6347") # Tomato red for signal
202
 
203
  st.markdown(f"<span style='font-size: 24px;'><strong>Relation 2:</strong></span>", unsafe_allow_html=True)
204
  st.markdown(f"**Cause:**<br>{cause_text2}", unsafe_allow_html=True)