Spaces:
Sleeping
Sleeping
Rename app_7.py to app_8.py
Browse files- app_7.py → app_8.py +23 -26
app_7.py → app_8.py
RENAMED
@@ -65,7 +65,6 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
|
|
65 |
with torch.no_grad():
|
66 |
outputs = model(**inputs)
|
67 |
|
68 |
-
#st.write("Model output keys:", outputs.keys())
|
69 |
|
70 |
# Extract logits
|
71 |
start_cause_logits = outputs["start_arg0_logits"][0]
|
@@ -132,9 +131,8 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
|
|
132 |
end_signal_logits[start_signal + 5:] = -1e4
|
133 |
end_signal = end_signal_logits.argmax().item()
|
134 |
|
135 |
-
|
136 |
-
start_signal =
|
137 |
-
end_signal = 'NA'
|
138 |
|
139 |
|
140 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
@@ -161,21 +159,20 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
|
|
161 |
signal = 'NA'
|
162 |
list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
|
163 |
list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
|
164 |
-
return cause1, cause2, effect1, effect2, signal, list1, list2
|
165 |
-
|
166 |
-
|
167 |
-
def
|
168 |
-
"""
|
169 |
-
if
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
return original_text # Return unchanged text if no span is found
|
179 |
|
180 |
st.title("Causal Relation Extraction")
|
181 |
input_text = st.text_area("Enter your text here:", height=300)
|
@@ -184,11 +181,11 @@ beam_search = st.radio("Enable Beam Search?", ('No', 'Yes')) == 'Yes'
|
|
184 |
|
185 |
if st.button("Extract1"):
|
186 |
if input_text:
|
187 |
-
|
188 |
|
189 |
-
cause_text1 = mark_text(input_text,
|
190 |
-
effect_text1 = mark_text(input_text,
|
191 |
-
signal_text = mark_text(input_text,
|
192 |
|
193 |
st.markdown(f"<span style='font-size: 24px;'><strong>Relation 1:</strong></span>", unsafe_allow_html=True)
|
194 |
st.markdown(f"**Cause:**<br>{cause_text1}", unsafe_allow_html=True)
|
@@ -199,9 +196,9 @@ if st.button("Extract1"):
|
|
199 |
|
200 |
if beam_search:
|
201 |
|
202 |
-
cause_text2 = mark_text(input_text,
|
203 |
-
effect_text2 = mark_text(input_text,
|
204 |
-
signal_text = mark_text(input_text,
|
205 |
|
206 |
st.markdown(f"<span style='font-size: 24px;'><strong>Relation 2:</strong></span>", unsafe_allow_html=True)
|
207 |
st.markdown(f"**Cause:**<br>{cause_text2}", unsafe_allow_html=True)
|
|
|
65 |
with torch.no_grad():
|
66 |
outputs = model(**inputs)
|
67 |
|
|
|
68 |
|
69 |
# Extract logits
|
70 |
start_cause_logits = outputs["start_arg0_logits"][0]
|
|
|
131 |
end_signal_logits[start_signal + 5:] = -1e4
|
132 |
end_signal = end_signal_logits.argmax().item()
|
133 |
|
134 |
+
If not has_signal:
|
135 |
+
start_signal, end_signal = None, None
|
|
|
136 |
|
137 |
|
138 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
|
|
159 |
signal = 'NA'
|
160 |
list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
|
161 |
list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
|
162 |
+
#return cause1, cause2, effect1, effect2, signal, list1, list2
|
163 |
+
return start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal
|
164 |
+
|
165 |
+
def mark_text_by_position(original_text, start_idx, end_idx, color):
|
166 |
+
"""Marks text in the original string based on character positions."""
|
167 |
+
if start_idx is not None and end_idx is not None and start_idx < end_idx:
|
168 |
+
return (
|
169 |
+
original_text[:start_idx]
|
170 |
+
+ f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>"
|
171 |
+
+ original_text[start_idx:end_idx]
|
172 |
+
+ "</mark>"
|
173 |
+
+ original_text[end_idx:]
|
174 |
+
)
|
175 |
+
return original_text # Return unchanged if indices are invalidt # Return unchanged text if no span is found
|
|
|
176 |
|
177 |
st.title("Causal Relation Extraction")
|
178 |
input_text = st.text_area("Enter your text here:", height=300)
|
|
|
181 |
|
182 |
if st.button("Extract1"):
|
183 |
if input_text:
|
184 |
+
start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)
|
185 |
|
186 |
+
cause_text1 = mark_text(input_text, start_cause1, end_cause1, "#FFD700") # Gold for cause
|
187 |
+
effect_text1 = mark_text(input_text, start_effect1, end_effect1, "#90EE90") # Light green for effect
|
188 |
+
signal_text = mark_text(input_text, start_signal, end_signal, "#FF6347") # Tomato red for signal
|
189 |
|
190 |
st.markdown(f"<span style='font-size: 24px;'><strong>Relation 1:</strong></span>", unsafe_allow_html=True)
|
191 |
st.markdown(f"**Cause:**<br>{cause_text1}", unsafe_allow_html=True)
|
|
|
196 |
|
197 |
if beam_search:
|
198 |
|
199 |
+
cause_text2 = mark_text(input_text, start_cause2, end_cause2, "#FFD700") # Gold for cause
|
200 |
+
effect_text2 = mark_text(input_text, start_effect2, end_effect2, "#90EE90") # Light green for effect
|
201 |
+
signal_text = mark_text(input_text, start_signal, end_signal, "#FF6347") # Tomato red for signal
|
202 |
|
203 |
st.markdown(f"<span style='font-size: 24px;'><strong>Relation 2:</strong></span>", unsafe_allow_html=True)
|
204 |
st.markdown(f"**Cause:**<br>{cause_text2}", unsafe_allow_html=True)
|