lorenpe2 committed on
Commit
5585321
·
1 Parent(s): c9566b5

FEAT: Diagnostic mode

Browse files
Files changed (1) hide show
  1. app.py +33 -4
app.py CHANGED
@@ -56,7 +56,7 @@ model = get_model(model_path)
56
  inference_tokenizer = get_tokenizer(model_path)
57
 
58
 
59
- def get_evaluation_data(_context: List) -> List[Tuple[List, str, str]]:
60
  output_data = []
61
  for _dict in _context:
62
  _dict: Dict
@@ -67,11 +67,24 @@ def get_evaluation_data(_context: List) -> List[Tuple[List, str, str]]:
67
  return output_data
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  option = st.selectbox("Choose type of input:",
71
  ["01 - String (one turn per line)",
72
  "02 - JSON (aggregated)",
73
  "03 - JSON (example CA-OOD)",
74
- "04 - JSON (example Elysai)"])
 
75
 
76
 
77
  with st.form("input_text"):
@@ -115,13 +128,13 @@ with st.form("input_text"):
115
  context = st.text_area("Insert JSON here:", value=str(text))
116
 
117
  if "{" in context:
118
- evaluation_data = get_evaluation_data(_context=json.loads(context))
119
  results = []
120
  accuracy = []
121
 
122
  submitted = st.form_submit_button("Submit")
123
  if submitted:
124
- for datapoint in evaluation_data:
125
  c, s, human_label = datapoint
126
  input_tensor = inference_tokenizer.get_item(context=c, actual_sentence=s)
127
  output_model = model(**input_tensor.data).logits
@@ -138,6 +151,22 @@ with st.form("input_text"):
138
  df = pandas.DataFrame(results, columns=["Context", "Query", "Human Label", "Probability (follow)",
139
  "Probability (not-follow)"])
140
  st.dataframe(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  st.markdown("## Description of models:")
143
  for x in sorted(models.values(), key=lambda x: x["model"]):
 
56
  inference_tokenizer = get_tokenizer(model_path)
57
 
58
 
59
+ def get_evaluation_data_from_json(_context: List) -> List[Tuple[List, str, str]]:
60
  output_data = []
61
  for _dict in _context:
62
  _dict: Dict
 
67
  return output_data
68
 
69
 
70
def get_evaluation_data_from_dialogue(_context: List, max_context: int = 5) -> List[Tuple[List, str, Union[str, None]]]:
    """Expand a dialogue into (context, sentence, label) evaluation triples.

    Every turn except the first is treated as the sentence under test. For
    each such turn, one triple is produced per context length, from the
    single preceding turn up to the ``max_context`` preceding turns
    (shortest context first).

    Args:
        _context: Dialogue turns in order, one entry per turn.
        max_context: Maximum number of preceding turns used as context.
            Values of zero or less produce no triples.

    Returns:
        A list of ``(context_turns, sentence, None)`` tuples. The label slot
        is always ``None`` because a raw dialogue carries no human label.
    """
    output_data = []
    for idx, actual_sentence in enumerate(_context):
        if idx == 0:
            # The first turn has no preceding context to evaluate against.
            continue
        window = _context[max(0, idx - max_context):idx]
        # Sizes start at 1: slicing with ``[-0:]`` equals ``[0:]`` and would
        # return the whole window instead of an empty/shortest context.
        for size in range(1, len(window) + 1):
            output_data.append((window[-size:], actual_sentence, None))
    return output_data
80
+
81
+
82
  option = st.selectbox("Choose type of input:",
83
  ["01 - String (one turn per line)",
84
  "02 - JSON (aggregated)",
85
  "03 - JSON (example CA-OOD)",
86
+ "04 - JSON (example Elysai)",
87
+ "05 - Diagnostic mode"])
88
 
89
 
90
  with st.form("input_text"):
 
128
  context = st.text_area("Insert JSON here:", value=str(text))
129
 
130
  if "{" in context:
131
+ data_for_evaluation = get_evaluation_data_from_json(_context=json.loads(context))
132
  results = []
133
  accuracy = []
134
 
135
  submitted = st.form_submit_button("Submit")
136
  if submitted:
137
+ for datapoint in data_for_evaluation:
138
  c, s, human_label = datapoint
139
  input_tensor = inference_tokenizer.get_item(context=c, actual_sentence=s)
140
  output_model = model(**input_tensor.data).logits
 
151
  df = pandas.DataFrame(results, columns=["Context", "Query", "Human Label", "Probability (follow)",
152
  "Probability (not-follow)"])
153
  st.dataframe(df)
154
+ elif "05" in option:
155
+ context = st.text_area("Insert dialogue here (one turn per line):")
156
+ submitted = st.form_submit_button("Submit")
157
+ if submitted:
158
+ aggregated_result = []
159
+ data_for_evaluation = get_evaluation_data_from_dialogue(context.split("\n"))
160
+ for datapoint in data_for_evaluation:
161
+ c, s, _ = datapoint
162
+ input_tensor = inference_tokenizer.get_item(context=c, actual_sentence=s)
163
+ output_model = model(**input_tensor.data).logits
164
+ output_model = torch.softmax(output_model, dim=-1).detach().numpy()[0]
165
+ prop_follow = output_model[0]
166
+ prop_not_follow = output_model[1]
167
+
168
+ aggregated_result.append((c, s, prop_follow))
169
+ st.table(aggregated_result)
170
 
171
  st.markdown("## Description of models:")
172
  for x in sorted(models.values(), key=lambda x: x["model"]):