app.py CHANGED
@@ -133,9 +133,11 @@ def main():
     # col1, col2, col3 = st.columns(3, gap="medium")
     col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
     sentiment = col1.slider(
-        "Sentiment
+        "Sentiment (the larger the more positive)",
+        -steer_range, steer_range, 3.0, steer_interval)
     detoxification = col2.slider(
-        "Detoxification Strength
+        "Detoxification Strength (the larger the less toxic)",
+        -steer_range, steer_range, 0.0,
         steer_interval)
     max_length = col3.number_input("Max length", 50, 300, 50, 50)
     col1, col2, col3, _ = st.columns(4)
@@ -144,15 +146,16 @@ def main():
     if "output" not in st.session_state:
         st.session_state.output = ""
     if col1.button("Steer and generate!", type="primary"):
-
-
-        st.session_state.
-
-
-
-
-
-
+        with st.spinner("Generating..."):
+            steer_values = [detoxification, 0, sentiment, 0]
+            st.session_state.output = model.generate(
+                st.session_state.prompt,
+                steer_values,
+                seed=None if randomness else 0,
+                min_length=0,
+                max_length=max_length,
+                do_sample=True,
+            )
     analyzed_text = \
         st.text_area("Generated text:", st.session_state.output, height=200)

@@ -176,46 +179,51 @@ def main():
         [2, 0],
         ["#ff7f0e", "#1f77b4"],
     ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with st.spinner(f"Analyzing {name}..."):
+            col.subheader(name)
+            # classification
+            col.markdown(
+                "##### Dimension-Wise Classification Distribution")
+            _, dist_list, _ = model.steer_analysis(
+                analyzed_text,
+                dim, -steer_range, steer_range,
+                bins=2*int(steer_range)+1,
+            )
+            dist_list = np.array(dist_list)
+            col.bar_chart(
+                pd.DataFrame(
+                    {
+                        "Value": dist_list[:, 0],
+                        "Probability": dist_list[:, 1],
+                    }
+                ), x="Value", y="Probability",
+                color=color,
+            )
+
+            # key tokens
+            pos_steer, neg_steer = np.zeros((2, 4))
+            pos_steer[dim] = 1
+            neg_steer[dim] = -1
+            _, token_evidence = model.evidence_words(
+                analyzed_text,
+                [pos_steer, neg_steer],
+            )
+            tokens = tokenizer(analyzed_text).input_ids
+            tokens = [f"{i:3d}: {tokenizer.decode([t])}"
+                      for i, t in enumerate(tokens)]
+            col.markdown("##### Token's Evidence Score in the Dimension")
+            col.write("The polarity of the token's evidence score "
+                      "which aligns with sliding bar directions."
+                      )
+            col.bar_chart(
+                pd.DataFrame(
+                    {
+                        "Token": tokens[1:],
+                        "Evidence": token_evidence,
+                    }
+                ), x="Token", y="Evidence",
+                horizontal=True, color=color,
+            )

     st.divider()
     st.divider()
@@ -234,7 +242,8 @@ def main():
         ["Sentiment", "Detoxification"],
     )
     dim = 2 if dimension == "Sentiment" else 0
-
+    with st.spinner("Analyzing..."):
+        word_embedding_space_analysis(model, tokenizer, dim)


 if __name__ == "__main__":
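The Streamlit pattern this change applies is: keep the generated text in st.session_state so it survives reruns triggered by widget interaction, and wrap the slow model call in st.spinner so the UI shows progress while it runs. Below is a minimal, self-contained sketch of that pattern only; fake_generate is a hypothetical stand-in for the Space's model.generate, and the widget labels and steer values are illustrative.

import time

import streamlit as st


def fake_generate(prompt: str, steer_values: list) -> str:
    # Hypothetical stand-in for model.generate; any slow call fits here.
    time.sleep(2)
    return f"(steered with {steer_values}) {prompt} ..."


# Persist output across Streamlit reruns.
if "output" not in st.session_state:
    st.session_state.output = ""

prompt = st.text_input("Prompt", "The movie was")
if st.button("Steer and generate!", type="primary"):
    with st.spinner("Generating..."):  # spinner is shown while this block runs
        st.session_state.output = fake_generate(prompt, [0.0, 0, 3.0, 0])

st.text_area("Generated text:", st.session_state.output, height=200)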