geekyrakshit committed on
Commit
41151eb
·
1 Parent(s): ec05364

update: evaluation app

Browse files
Files changed (1) hide show
  1. application_pages/evaluation_app.py +117 -16
application_pages/evaluation_app.py CHANGED
@@ -1,16 +1,12 @@
1
- import asyncio
2
- import os
3
  from importlib import import_module
4
 
5
  import pandas as pd
6
- import rich
7
  import streamlit as st
8
  import weave
9
  from dotenv import load_dotenv
10
 
11
  from guardrails_genie.guardrails import GuardrailManager
12
  from guardrails_genie.llm import OpenAIModel
13
- from guardrails_genie.metrics import AccuracyMetric
14
 
15
 
16
  def initialize_session_state():
@@ -29,6 +25,99 @@ def initialize_session_state():
29
  st.session_state.publish_dataset_button = False
30
  if "dataset_ref" not in st.session_state:
31
  st.session_state.dataset_ref = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  initialize_session_state()
@@ -54,14 +143,11 @@ if st.session_state.uploaded_file is not None:
54
  publish_dataset_button = st.sidebar.button("Publish dataset")
55
  st.session_state.publish_dataset_button = publish_dataset_button
56
 
57
- if (
58
- st.session_state.publish_dataset_button
59
- and (
60
- st.session_state.dataset_name is not None
61
- and st.session_state.dataset_name != ""
62
- )
63
  ):
64
-
65
  with st.expander("Evaluation Dataset Preview", expanded=True):
66
  dataframe = pd.read_csv(st.session_state.uploaded_file)
67
  data_list = dataframe.to_dict(orient="records")
@@ -74,9 +160,7 @@ if st.session_state.uploaded_file is not None:
74
  dataset_name = st.session_state.dataset_name
75
  digest = st.session_state.dataset_ref._digest
76
  dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
77
- st.markdown(
78
- f"Dataset published to [**Weave**]({dataset_url})"
79
- )
80
 
81
  if preview_in_app:
82
  st.dataframe(dataframe.head(20))
@@ -86,6 +170,23 @@ if st.session_state.uploaded_file is not None:
86
  )
87
 
88
  st.session_state.is_dataset_published = True
89
-
90
  if st.session_state.is_dataset_published:
91
- st.write("Maza Ayega")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from importlib import import_module
2
 
3
  import pandas as pd
 
4
  import streamlit as st
5
  import weave
6
  from dotenv import load_dotenv
7
 
8
  from guardrails_genie.guardrails import GuardrailManager
9
  from guardrails_genie.llm import OpenAIModel
 
10
 
11
 
12
  def initialize_session_state():
 
25
  st.session_state.publish_dataset_button = False
26
  if "dataset_ref" not in st.session_state:
27
  st.session_state.dataset_ref = None
28
+ if "guardrails" not in st.session_state:
29
+ st.session_state.guardrails = []
30
+ if "guardrail_names" not in st.session_state:
31
+ st.session_state.guardrail_names = []
32
+ if "start_evaluations_button" not in st.session_state:
33
+ st.session_state.start_evaluations_button = False
34
+
35
+
36
def initialize_guardrails():
    """Rebuild the guardrail instances for the names selected in the sidebar.

    Resets ``st.session_state.guardrails`` to an empty list, then walks
    ``st.session_state.guardrail_names`` and appends one instance per name.
    Each class is looked up by name on the ``guardrails_genie.guardrails``
    module. Some guardrails first render a sidebar widget to collect a
    constructor argument; the two selectbox-driven guardrails are skipped
    while their selection is still the empty placeholder option.

    Finally a ``GuardrailManager`` wrapping the collected list is stored in
    ``st.session_state.guardrails_manager``.
    """
    # Guardrails that are all constructed the same way: should_anonymize=True.
    anonymizing_guardrails = (
        "PresidioEntityRecognitionGuardrail",
        "RegexEntityRecognitionGuardrail",
        "TransformersEntityRecognitionGuardrail",
        "RestrictedTermsJudge",
    )

    def lookup(class_name):
        # Resolve lazily so the module is only touched when an instance is
        # actually about to be built, as in the per-branch lookups.
        return getattr(import_module("guardrails_genie.guardrails"), class_name)

    st.session_state.guardrails = []

    for selected in st.session_state.guardrail_names:
        if selected == "PromptInjectionSurveyGuardrail":
            llm_choice = st.sidebar.selectbox(
                "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
            )
            if not llm_choice:
                # Nothing picked yet; do not build this guardrail.
                continue
            instance = lookup(selected)(
                llm_model=OpenAIModel(model_name=llm_choice)
            )
        elif selected == "PromptInjectionClassifierGuardrail":
            classifier_choice = st.sidebar.selectbox(
                "Classifier Guardrail Model",
                [
                    "",
                    "ProtectAI/deberta-v3-base-prompt-injection-v2",
                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
                ],
            )
            if classifier_choice == "":
                # Nothing picked yet; do not build this guardrail.
                continue
            instance = lookup(selected)(model_name=classifier_choice)
        elif selected in anonymizing_guardrails:
            instance = lookup(selected)(should_anonymize=True)
        elif selected == "PromptInjectionLlamaGuardrail":
            st.session_state.llama_guard_checkpoint_name = st.sidebar.text_input(
                "Checkpoint Name", value=""
            )
            checkpoint_name = st.session_state.llama_guard_checkpoint_name
            # An empty text box is passed on as checkpoint=None.
            if checkpoint_name == "":
                checkpoint_name = None
            instance = lookup(selected)(checkpoint=checkpoint_name)
        else:
            # Any other guardrail is built with no constructor arguments.
            instance = lookup(selected)()

        st.session_state.guardrails.append(instance)

    # NOTE(review): the diff view stripped indentation; this assignment is
    # assumed to sit after the loop rather than inside it -- confirm.
    st.session_state.guardrails_manager = GuardrailManager(
        guardrails=st.session_state.guardrails
    )
121
 
122
 
123
  initialize_session_state()
 
143
  publish_dataset_button = st.sidebar.button("Publish dataset")
144
  st.session_state.publish_dataset_button = publish_dataset_button
145
 
146
+ if st.session_state.publish_dataset_button and (
147
+ st.session_state.dataset_name is not None
148
+ and st.session_state.dataset_name != ""
 
 
 
149
  ):
150
+
151
  with st.expander("Evaluation Dataset Preview", expanded=True):
152
  dataframe = pd.read_csv(st.session_state.uploaded_file)
153
  data_list = dataframe.to_dict(orient="records")
 
160
  dataset_name = st.session_state.dataset_name
161
  digest = st.session_state.dataset_ref._digest
162
  dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
163
+ st.markdown(f"Dataset published to [**Weave**]({dataset_url})")
 
 
164
 
165
  if preview_in_app:
166
  st.dataframe(dataframe.head(20))
 
170
  )
171
 
172
  st.session_state.is_dataset_published = True
173
+
174
  if st.session_state.is_dataset_published:
175
+ guardrail_names = st.sidebar.multiselect(
176
+ "Select Guardrails",
177
+ options=[
178
+ cls_name
179
+ for cls_name, cls_obj in vars(
180
+ import_module("guardrails_genie.guardrails")
181
+ ).items()
182
+ if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
183
+ ],
184
+ )
185
+ st.session_state.guardrail_names = guardrail_names
186
+
187
+ initialize_guardrails()
188
+
189
+ start_evaluations_button = st.sidebar.button("Start Evaluations")
190
+ st.session_state.start_evaluations_button = start_evaluations_button
191
+ if st.session_state.start_evaluations_button:
192
+ st.write(len(st.session_state.guardrails))