Spaces:
Running
Running
Commit
·
41151eb
1
Parent(s):
ec05364
update: evaluation app
Browse files- application_pages/evaluation_app.py +117 -16
application_pages/evaluation_app.py
CHANGED
@@ -1,16 +1,12 @@
|
|
1 |
-
import asyncio
|
2 |
-
import os
|
3 |
from importlib import import_module
|
4 |
|
5 |
import pandas as pd
|
6 |
-
import rich
|
7 |
import streamlit as st
|
8 |
import weave
|
9 |
from dotenv import load_dotenv
|
10 |
|
11 |
from guardrails_genie.guardrails import GuardrailManager
|
12 |
from guardrails_genie.llm import OpenAIModel
|
13 |
-
from guardrails_genie.metrics import AccuracyMetric
|
14 |
|
15 |
|
16 |
def initialize_session_state():
|
@@ -29,6 +25,99 @@ def initialize_session_state():
|
|
29 |
st.session_state.publish_dataset_button = False
|
30 |
if "dataset_ref" not in st.session_state:
|
31 |
st.session_state.dataset_ref = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
|
34 |
initialize_session_state()
|
@@ -54,14 +143,11 @@ if st.session_state.uploaded_file is not None:
|
|
54 |
publish_dataset_button = st.sidebar.button("Publish dataset")
|
55 |
st.session_state.publish_dataset_button = publish_dataset_button
|
56 |
|
57 |
-
if (
|
58 |
-
st.session_state.
|
59 |
-
and
|
60 |
-
st.session_state.dataset_name is not None
|
61 |
-
and st.session_state.dataset_name != ""
|
62 |
-
)
|
63 |
):
|
64 |
-
|
65 |
with st.expander("Evaluation Dataset Preview", expanded=True):
|
66 |
dataframe = pd.read_csv(st.session_state.uploaded_file)
|
67 |
data_list = dataframe.to_dict(orient="records")
|
@@ -74,9 +160,7 @@ if st.session_state.uploaded_file is not None:
|
|
74 |
dataset_name = st.session_state.dataset_name
|
75 |
digest = st.session_state.dataset_ref._digest
|
76 |
dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
|
77 |
-
st.markdown(
|
78 |
-
f"Dataset published to [**Weave**]({dataset_url})"
|
79 |
-
)
|
80 |
|
81 |
if preview_in_app:
|
82 |
st.dataframe(dataframe.head(20))
|
@@ -86,6 +170,23 @@ if st.session_state.uploaded_file is not None:
|
|
86 |
)
|
87 |
|
88 |
st.session_state.is_dataset_published = True
|
89 |
-
|
90 |
if st.session_state.is_dataset_published:
|
91 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from importlib import import_module
|
2 |
|
3 |
import pandas as pd
|
|
|
4 |
import streamlit as st
|
5 |
import weave
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
from guardrails_genie.guardrails import GuardrailManager
|
9 |
from guardrails_genie.llm import OpenAIModel
|
|
|
10 |
|
11 |
|
12 |
def initialize_session_state():
|
|
|
25 |
st.session_state.publish_dataset_button = False
|
26 |
if "dataset_ref" not in st.session_state:
|
27 |
st.session_state.dataset_ref = None
|
28 |
+
if "guardrails" not in st.session_state:
|
29 |
+
st.session_state.guardrails = []
|
30 |
+
if "guardrail_names" not in st.session_state:
|
31 |
+
st.session_state.guardrail_names = []
|
32 |
+
if "start_evaluations_button" not in st.session_state:
|
33 |
+
st.session_state.start_evaluations_button = False
|
34 |
+
|
35 |
+
|
36 |
+
def initialize_guardrails():
    """Rebuild the guardrail instances for the names chosen in the sidebar.

    Resets ``st.session_state.guardrails`` to an empty list, then walks
    ``st.session_state.guardrail_names`` and instantiates the class of the
    same name looked up on the ``guardrails_genie.guardrails`` module.
    Some guardrails render an extra sidebar widget to collect their
    configuration; the two selectbox-driven ones are skipped until a
    non-empty option has been picked. Once the loop finishes, a
    ``GuardrailManager`` wrapping the collected guardrails is stored in
    ``st.session_state.guardrails_manager``.
    """
    # Guardrails that are all constructed with ``should_anonymize=True``.
    anonymizing_guardrails = (
        "PresidioEntityRecognitionGuardrail",
        "RegexEntityRecognitionGuardrail",
        "TransformersEntityRecognitionGuardrail",
        "RestrictedTermsJudge",
    )

    def resolve_guardrail_class(class_name):
        # Look the class up by name on the guardrails package.
        return getattr(import_module("guardrails_genie.guardrails"), class_name)

    st.session_state.guardrails = []

    for selected_name in st.session_state.guardrail_names:
        if selected_name == "PromptInjectionSurveyGuardrail":
            llm_choice = st.sidebar.selectbox(
                "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
            )
            if not llm_choice:
                # Nothing picked yet: do not instantiate this guardrail.
                continue
            guardrail_cls = resolve_guardrail_class(selected_name)
            guardrail = guardrail_cls(llm_model=OpenAIModel(model_name=llm_choice))
        elif selected_name == "PromptInjectionClassifierGuardrail":
            classifier_choice = st.sidebar.selectbox(
                "Classifier Guardrail Model",
                [
                    "",
                    "ProtectAI/deberta-v3-base-prompt-injection-v2",
                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
                ],
            )
            if classifier_choice == "":
                # Nothing picked yet: do not instantiate this guardrail.
                continue
            guardrail_cls = resolve_guardrail_class(selected_name)
            guardrail = guardrail_cls(model_name=classifier_choice)
        elif selected_name in anonymizing_guardrails:
            guardrail_cls = resolve_guardrail_class(selected_name)
            guardrail = guardrail_cls(should_anonymize=True)
        elif selected_name == "PromptInjectionLlamaGuardrail":
            checkpoint_input = st.sidebar.text_input("Checkpoint Name", value="")
            # The typed checkpoint name is kept in session state as well.
            st.session_state.llama_guard_checkpoint_name = checkpoint_input
            guardrail_cls = resolve_guardrail_class(selected_name)
            stored_checkpoint = st.session_state.llama_guard_checkpoint_name
            # An empty text box is passed on as ``checkpoint=None``.
            guardrail = guardrail_cls(
                checkpoint=stored_checkpoint if stored_checkpoint != "" else None
            )
        else:
            # Any other guardrail is constructed without arguments.
            guardrail = resolve_guardrail_class(selected_name)()

        st.session_state.guardrails.append(guardrail)

    st.session_state.guardrails_manager = GuardrailManager(
        guardrails=st.session_state.guardrails
    )
|
121 |
|
122 |
|
123 |
initialize_session_state()
|
|
|
143 |
publish_dataset_button = st.sidebar.button("Publish dataset")
|
144 |
st.session_state.publish_dataset_button = publish_dataset_button
|
145 |
|
146 |
+
if st.session_state.publish_dataset_button and (
|
147 |
+
st.session_state.dataset_name is not None
|
148 |
+
and st.session_state.dataset_name != ""
|
|
|
|
|
|
|
149 |
):
|
150 |
+
|
151 |
with st.expander("Evaluation Dataset Preview", expanded=True):
|
152 |
dataframe = pd.read_csv(st.session_state.uploaded_file)
|
153 |
data_list = dataframe.to_dict(orient="records")
|
|
|
160 |
dataset_name = st.session_state.dataset_name
|
161 |
digest = st.session_state.dataset_ref._digest
|
162 |
dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
|
163 |
+
st.markdown(f"Dataset published to [**Weave**]({dataset_url})")
|
|
|
|
|
164 |
|
165 |
if preview_in_app:
|
166 |
st.dataframe(dataframe.head(20))
|
|
|
170 |
)
|
171 |
|
172 |
st.session_state.is_dataset_published = True
|
173 |
+
|
174 |
if st.session_state.is_dataset_published:
|
175 |
+
guardrail_names = st.sidebar.multiselect(
|
176 |
+
"Select Guardrails",
|
177 |
+
options=[
|
178 |
+
cls_name
|
179 |
+
for cls_name, cls_obj in vars(
|
180 |
+
import_module("guardrails_genie.guardrails")
|
181 |
+
).items()
|
182 |
+
if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
|
183 |
+
],
|
184 |
+
)
|
185 |
+
st.session_state.guardrail_names = guardrail_names
|
186 |
+
|
187 |
+
initialize_guardrails()
|
188 |
+
|
189 |
+
start_evaluations_button = st.sidebar.button("Start Evaluations")
|
190 |
+
st.session_state.start_evaluations_button = start_evaluations_button
|
191 |
+
if st.session_state.start_evaluations_button:
|
192 |
+
st.write(len(st.session_state.guardrails))
|